mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/parent-child-retrieval
commit c4aa98e609
@@ -1,5 +1,5 @@
-FROM mcr.microsoft.com/devcontainers/python:3.10
+FROM mcr.microsoft.com/devcontainers/python:3.12
 
 # [Optional] Uncomment this section to install additional OS packages.
 # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
 #     && apt-get -y install --no-install-recommends <your-package-list-here>
@@ -1,7 +1,7 @@
 // For format details, see https://aka.ms/devcontainer.json. For config options, see the
 // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
 {
-    "name": "Python 3.10",
+    "name": "Python 3.12",
     "build": {
         "context": "..",
         "dockerfile": "Dockerfile"
@@ -4,7 +4,7 @@ inputs:
   python-version:
     description: Python version to use and the Poetry installed with
     required: true
-    default: '3.10'
+    default: '3.11'
   poetry-version:
     description: Poetry version to set up
     required: true
@@ -20,7 +20,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
          - "3.11"
          - "3.12"
 
@@ -8,6 +8,8 @@ on:
       - api/core/rag/datasource/**
       - docker/**
       - .github/workflows/vdb-tests.yml
+      - api/poetry.lock
+      - api/pyproject.toml
 
 concurrency:
   group: vdb-tests-${{ github.head_ref || github.run_id }}
@@ -20,7 +22,6 @@ jobs:
     strategy:
       matrix:
        python-version:
-          - "3.10"
          - "3.11"
          - "3.12"
 
@@ -1,6 +1,8 @@
 # CONTRIBUTING
 
 So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
 
 We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
 
 This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
 
@@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
 
 [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
 
-### Feature requests:
+### Feature requests
 
 * If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
 
 * If you want to pick one up from the existing issues, simply drop a comment below it saying so.
 
-
-
 A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
 
 Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's a rundown of the areas each of our team members is working on at the moment:
@@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
   | Non-core features and minor enhancements | Low Priority |
   | Valuable but not immediate | Future-Feature |
 
-### Anything else (e.g. bug report, performance optimization, typo correction):
+### Anything else (e.g. bug report, performance optimization, typo correction)
 
 * Start coding right away.
 
@@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
   | Non-critical bugs, performance boosts | Medium Priority |
   | Minor fixes (typos, confusing but working UI) | Low Priority |
 
-
 ## Installing
 
 Here are the steps to set up Dify for development:
@@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
 
 Clone the forked repository from your terminal:
 
-```
+```shell
 git clone git@github.com:<github_username>/dify.git
 ```
 
@@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
 
 Dify requires the following dependencies to build, make sure they're installed on your system:
 
-- [Docker](https://www.docker.com/)
-- [Docker Compose](https://docs.docker.com/compose/install/)
-- [Node.js v18.x (LTS)](http://nodejs.org)
-- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+* [Docker](https://www.docker.com/)
+* [Docker Compose](https://docs.docker.com/compose/install/)
+* [Node.js v18.x (LTS)](http://nodejs.org)
+* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
+* [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. Installations
 
@@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
 
 ### 5. Visit dify in your browser
 
 To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
 
 ## Developing
 
@@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
 
 ### Backend
 
 Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
 
-```
+```text
 [api/]
 ├── constants // Constant settings used throughout code base.
 ├── controllers // API route definitions and request handling logic.
@@ -121,7 +120,7 @@ Dify’s backend is written in Python using [Flask](https://flask.palletsproject
 
 The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
 
-```
+```text
 [web/]
 ├── app // layouts, pages, and components
 │   ├── (commonLayout) // common layout used throughout the app
@@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
 
 ## Submitting your PR
 
 At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
 
 And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
 
 ## Getting Help
 
 If you ever get stuck or have a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
@@ -71,7 +71,7 @@ Dify depends on the following tools and libraries:
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. Installation
@@ -74,7 +74,7 @@ The following dependencies are required to build Dify; make sure they are installed on your system:
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. Installation
@@ -73,7 +73,7 @@ Dify requires the following dependencies to build, make sure they're installed on your system:
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. Installation
@@ -153,4 +153,4 @@ And that's it! Once your PR is merged, you will be featured as a contributor in
 
 ## Getting Help
 
 If you ever get stuck or have a burning question while contributing, simply ask in the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
@@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
 REDIS_SENTINEL_PASSWORD=
 REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
 
+# redis Cluster configuration.
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=
+
 # PostgreSQL database configuration
 DB_USERNAME=postgres
 DB_PASSWORD=difyai123456
@@ -1,5 +1,5 @@
 # base image
-FROM python:3.10-slim-bookworm AS base
+FROM python:3.12-slim-bookworm AS base
 
 WORKDIR /app/api
 
@@ -18,12 +18,17 @@
    ```
 
 2. Copy `.env.example` to `.env`
+
+   ```cli
+   cp .env.example .env
+   ```
+
 3. Generate a `SECRET_KEY` in the `.env` file.
 
-   bash for Linux
+   ```bash for Linux
    sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
    ```
 
-   bash for Mac
+   ```bash for Mac
    secret_key=$(openssl rand -base64 42)
    sed -i '' "/^SECRET_KEY=/c\\
@@ -37,18 +42,10 @@
 5. Install dependencies
 
    ```bash
-   poetry env use 3.10
+   poetry env use 3.12
    poetry install
    ```
 
-   In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
-
-   ```bash
-   poetry shell # activate current environment
-   poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
-   poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
-   ```
-
 6. Run migrate
 
    Before the first launch, migrate the database to the latest version.
@@ -84,5 +81,3 @@
 ```bash
 poetry run -C api bash dev/pytest/pytest_all_tests.sh
 ```
-
-
@@ -1,6 +1,11 @@
 import os
 import sys
 
+python_version = sys.version_info
+if not ((3, 11) <= python_version < (3, 13)):
+    print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
+    raise SystemExit(1)
+
 from configs import dify_config
 
 if not dify_config.DEBUG:
@@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web # no
 
 # DO NOT REMOVE ABOVE
 
-if sys.version_info[:2] == (3, 10):
-    print("Warning: Python 3.10 will not be supported in the next version.")
-
 
 warnings.simplefilter("ignore", ResourceWarning)
 
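Note on the startup guard above: `sys.version_info` is a named tuple that compares element-wise against plain tuples, so one chained comparison enforces an inclusive 3.11 floor and an exclusive 3.13 ceiling. A standalone sketch of the idiom (illustrative, not Dify's code):

```python
import sys

# Chained tuple comparison: inclusive lower bound, exclusive upper bound.
supported = (3, 11) <= sys.version_info < (3, 13)
print(f"Python {sys.version_info.major}.{sys.version_info.minor} supported: {supported}")
```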
@@ -27,7 +27,6 @@ class DifyConfig(
        # read from dotenv format config file
        env_file=".env",
        env_file_encoding="utf-8",
-        frozen=True,
        # ignore extra attributes
        extra="ignore",
    )
@@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
         description="Socket timeout in seconds for Redis Sentinel connections",
         default=0.1,
     )
+
+    REDIS_USE_CLUSTERS: bool = Field(
+        description="Enable Redis Clusters mode for high availability",
+        default=False,
+    )
+
+    REDIS_CLUSTERS: Optional[str] = Field(
+        description="Comma-separated list of Redis Clusters nodes (host:port)",
+        default=None,
+    )
+
+    REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
+        description="Password for Redis Clusters authentication (if required)",
+        default=None,
+    )
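For context on the three new settings: one plausible way they could be consumed is with redis-py's cluster client, parsing the comma-separated `REDIS_CLUSTERS` value into startup nodes. This is a hedged sketch under that assumption — not necessarily how Dify's `ext_redis` actually wires it up:

```python
from redis.cluster import ClusterNode, RedisCluster

# Hypothetical values mirroring the new configuration fields above.
redis_clusters = "10.0.0.1:6379,10.0.0.2:6379,10.0.0.3:6379"
redis_clusters_password = "changeme"

# Split each "host:port" pair into the startup nodes the cluster client expects.
startup_nodes = [
    ClusterNode(host, int(port))
    for host, port in (node.split(":") for node in redis_clusters.split(","))
]
client = RedisCluster(startup_nodes=startup_nodes, password=redis_clusters_password)
client.set("ping", "pong")
```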
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.11.2",
+        default="0.12.0",
     )
 
     COMMIT_SHA: str = Field(
@@ -2,6 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
+from .app.app_import import AppImportApi, AppImportConfirmApi
 from .files import FileApi, FilePreviewApi, FileSupportTypeApi
 from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
 
@@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
 api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
 api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
 
+# Import App
+api.add_resource(AppImportApi, "/apps/imports")
+api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
+
 # Import other controllers
 from . import admin, apikey, extension, feature, ping, setup, version
 
@@ -1,7 +1,10 @@
 import uuid
+from typing import cast
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, abort
 
 from controllers.console import api
@@ -13,13 +16,15 @@ from controllers.console.wraps import (
     setup_required,
 )
 from core.ops.ops_trace_manager import OpsTraceManager
+from extensions.ext_database import db
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
     app_pagination_fields,
 )
 from libs.login import login_required
-from services.app_dsl_service import AppDslService
+from models import Account, App
+from services.app_dsl_service import AppDslService, ImportMode
 from services.app_service import AppService
 
 ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@@ -92,61 +97,6 @@ class AppListApi(Resource):
         return app, 201
 
 
-class AppImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon_type", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
-class AppImportFromUrlApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app from url"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app_from_url(
-            tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
 class AppApi(Resource):
     @setup_required
     @login_required
@@ -224,10 +174,24 @@ class AppCopyApi(Resource):
         parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
-        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
-        )
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=ImportMode.YAML_CONTENT.value,
+                yaml_content=yaml_content,
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+            )
+            session.commit()
+
+            stmt = select(App).where(App.id == result.app_id)
+            app = session.scalar(stmt)
 
         return app, 201
@@ -368,8 +332,6 @@ class AppTraceApi(Resource):
 
 
 api.add_resource(AppListApi, "/apps")
-api.add_resource(AppImportApi, "/apps/import")
-api.add_resource(AppImportFromUrlApi, "/apps/import/url")
 api.add_resource(AppApi, "/apps/<uuid:app_id>")
 api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
 api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
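The rewritten `AppCopyApi` above exports the DSL and re-imports it inside one SQLAlchemy session, then re-reads the created row with the 2.0-style `select()`/`scalar()` API. A self-contained sketch of that session-plus-select pattern with a toy model (the model and engine here are stand-ins, not Dify's):

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class App(Base):  # toy stand-in for Dify's App model
    __tablename__ = "apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(App(name="copied-app"))
    session.commit()  # persist first, as AppCopyApi commits after import_app()
    app = session.scalar(select(App).where(App.name == "copied-app"))
    print(app.id, app.name)
```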
@@ -0,0 +1,90 @@
+from typing import cast
+
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.console.wraps import (
+    account_initialization_required,
+    setup_required,
+)
+from extensions.ext_database import db
+from fields.app_fields import app_import_fields
+from libs.login import login_required
+from models import Account
+from services.app_dsl_service import AppDslService, ImportStatus
+
+
+class AppImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_import_fields)
+    def post(self):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("mode", type=str, required=True, location="json")
+        parser.add_argument("yaml_content", type=str, location="json")
+        parser.add_argument("yaml_url", type=str, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
+        parser.add_argument("app_id", type=str, location="json")
+        args = parser.parse_args()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            # Import app
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=args["mode"],
+                yaml_content=args.get("yaml_content"),
+                yaml_url=args.get("yaml_url"),
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+                app_id=args.get("app_id"),
+            )
+            session.commit()
+
+        # Return appropriate status code based on result
+        status = result.status
+        if status == ImportStatus.FAILED.value:
+            return result.model_dump(mode="json"), 400
+        elif status == ImportStatus.PENDING.value:
+            return result.model_dump(mode="json"), 202
+        return result.model_dump(mode="json"), 200
+
+
+class AppImportConfirmApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_import_fields)
+    def post(self, import_id):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            # Confirm import
+            account = cast(Account, current_user)
+            result = import_service.confirm_import(import_id=import_id, account=account)
+            session.commit()
+
+        # Return appropriate status code based on result
+        if result.status == ImportStatus.FAILED.value:
+            return result.model_dump(mode="json"), 400
+        return result.model_dump(mode="json"), 200
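The new controller replaces the old `/apps/import` and `/apps/import/url` routes with `/apps/imports` plus a confirm step, and signals the outcome through HTTP status codes (200 imported, 202 pending confirmation, 400 failed). A hedged sketch of calling it; the `/console/api` prefix, the bearer-token header, and the literal `"yaml-content"` mode string (assumed to match `ImportMode.YAML_CONTENT`) are assumptions about the deployment:

```python
import requests

BASE = "http://localhost:5001/console/api"  # hypothetical local console API
TOKEN = "<console-access-token>"            # hypothetical session token

resp = requests.post(
    f"{BASE}/apps/imports",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={"mode": "yaml-content", "yaml_content": "app:\n  name: demo\n"},
)
print(resp.status_code, resp.json())

# A 202 response would then be confirmed via:
# POST {BASE}/apps/imports/<import_id>/confirm
```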
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pytz
 from flask_login import current_user
@@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
         raise NotFound("Conversation Not Exists.")
 
     if not conversation.read_at:
-        conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
         conversation.read_account_id = current_user.id
         db.session.commit()
 
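The `timezone.utc` → `UTC` swap in this hunk repeats across the files below. `datetime.UTC` is simply an alias of `datetime.timezone.utc` introduced in Python 3.11, which is one reason the codebase now floors at 3.11. A quick equivalence check:

```python
from datetime import UTC, datetime, timezone

assert UTC is timezone.utc  # same object, shorter spelling (Python 3.11+)
# The naive-UTC pattern used throughout these hunks:
naive_utc_now = datetime.now(UTC).replace(tzinfo=None)
print(naive_utc_now)
```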
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
@@ -75,7 +75,7 @@ class AppSite(Resource):
             setattr(site, attr_name, value)
 
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
@@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
 
         site.code = Site.generate_code(16)
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
@@ -20,7 +20,6 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models import App
 from models.model import AppMode
-from services.app_dsl_service import AppDslService
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
 from services.workflow_service import WorkflowService
@@ -126,31 +125,6 @@ class DraftWorkflowApi(Resource):
         }
 
 
-class DraftWorkflowImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    @marshal_with(workflow_fields)
-    def post(self, app_model: App):
-        """
-        Import draft workflow
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        workflow = AppDslService.import_and_overwrite_workflow(
-            app_model=app_model, data=args["data"], account=current_user
-        )
-
-        return workflow
-
-
 class AdvancedChatDraftWorkflowRunApi(Resource):
     @setup_required
     @login_required
@@ -453,7 +427,6 @@ class ConvertToWorkflowApi(Resource):
 
 
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
-api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
@@ -65,7 +65,7 @@ class ActivateApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional
 
 import requests
@@ -106,7 +106,7 @@ class OAuthCallback(Resource):
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
-            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         try:
@@ -83,7 +83,7 @@ class DataSourceApi(Resource):
         if action == "enable":
             if data_source_binding.disabled:
                 data_source_binding.disabled = False
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
@@ -92,7 +92,7 @@ class DataSourceApi(Resource):
         if action == "disable":
             if not data_source_binding.disabled:
                 data_source_binding.disabled = True
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
@@ -1,6 +1,6 @@
 import logging
 from argparse import ArgumentTypeError
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask import request
 from flask_login import current_user
@@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
             raise InvalidActionError("Document not in indexing state.")
 
         document.paused_by = current_user.id
-        document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.paused_at = datetime.now(UTC).replace(tzinfo=None)
         document.is_paused = True
         db.session.commit()
 
@@ -745,7 +745,7 @@ class DocumentMetadataApi(DocumentResource):
                 document.doc_metadata[key] = value
 
         document.doc_type = doc_type
-        document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success", "message": "Document metadata updated."}, 200
@@ -787,7 +787,7 @@ class DocumentStatusApi(DocumentResource):
             document.enabled = True
             document.disabled_at = None
             document.disabled_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -804,9 +804,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already disabled.")
 
             document.enabled = False
-            document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
             document.disabled_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -821,9 +821,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already archived.")
 
             document.archived = True
-            document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.archived_at = datetime.now(UTC).replace(tzinfo=None)
             document.archived_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             if document.enabled:
@@ -840,7 +840,7 @@ class DocumentStatusApi(DocumentResource):
             document.archived = False
             document.archived_at = None
             document.archived_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -1,5 +1,5 @@
 import uuid
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pandas as pd
 from flask import request
@@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
                 raise InvalidActionError("Segment is already disabled.")
 
             segment.enabled = False
-            segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
             segment.disabled_by = current_user.id
             db.session.commit()
 
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import reqparse
@@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
@@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
 
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
@@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
+                last_used_at=datetime.now(UTC).replace(tzinfo=None),
             )
             db.session.add(new_installed_app)
             db.session.commit()
@@ -60,7 +60,7 @@ class AccountInitApi(Resource):
                 raise InvalidInvitationCodeError()
 
             invitation_code.status = "used"
-            invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
 
@@ -68,7 +68,7 @@ class AccountInitApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = "active"
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success"}
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
     if not api_token:
         raise Unauthorized("Access token is invalid")
 
-    api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
     db.session.commit()
 
     return api_token
@@ -2,7 +2,7 @@ import json
 import logging
 import uuid
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union, cast
 
 from core.agent.entities import AgentEntity, AgentToolEntity
@@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):
         # check if model supports stream tool call
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
-        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
-            self.stream_tool_call = True
-        else:
-            self.stream_tool_call = False
-
-        # check if model supports vision
-        if model_schema and ModelFeature.VISION in (model_schema.features or []):
-            self.files = application_generate_entity.files
-        else:
-            self.files = []
+        features = model_schema.features if model_schema and model_schema.features else []
+        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
+        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
         self.query = None
         self._current_thoughts: list[PromptMessage] = []
 
@@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):
         update prompt message tool
         """
         # try to get tool runtime parameters
-        tool_runtime_parameters = tool.get_runtime_parameters() or []
+        tool_runtime_parameters = tool.get_runtime_parameters()
 
         for parameter in tool_runtime_parameters:
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
@@ -419,7 +412,7 @@ class BaseAgentRunner(AppRunner):
             .first()
         )
 
-        db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
         db.session.close()
@@ -1,3 +1,4 @@
 import uuid
+from typing import Optional
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
@@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
 
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@ class ModelConfigConverter:
         )
 
         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
 
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
 
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
 
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ class ModelConfigConverter:
 
         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
 
-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
 
         return ModelConfigWithCredentialsEntity(
@@ -1,4 +1,5 @@
 from core.app.app_config.entities import (
+    AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
     AdvancedCompletionPromptTemplateEntity,
     PromptTemplateEntity,
@@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
             chat_prompt_messages = []
             for message in chat_prompt_config.get("prompt", []):
                 chat_prompt_messages.append(
-                    {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    AdvancedChatMessageEntity(
+                        **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )
                 )
 
             advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, Field, field_validator
@@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
-class VariableEntityType(str, Enum):
+class VariableEntityType(StrEnum):
     TEXT_INPUT = "text-input"
     SELECT = "select"
     PARAGRAPH = "paragraph"
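The `class VariableEntityType(str, Enum)` → `StrEnum` change (and `QueueEvent` below) is not just cosmetic: Python 3.11 changed how mixin enums format, so f"{member}" on a `str, Enum` subclass yields `Class.MEMBER` instead of the raw value, while `StrEnum` keeps the plain string. A standalone demonstration:

```python
from enum import Enum, StrEnum


class OldStyle(str, Enum):
    SELECT = "select"


class NewStyle(StrEnum):
    SELECT = "select"


# Under Python 3.11+, the str/Enum mixin formats as the qualified member name,
# while StrEnum formats as the bare value.
print(f"{OldStyle.SELECT}")  # OldStyle.SELECT
print(f"{NewStyle.SELECT}")  # select
```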
@@ -127,7 +127,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
@@ -134,7 +134,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional
 
 from core.app.app_config.entities import VariableEntityType
@@ -6,7 +6,7 @@ from core.file import File, FileUploadConfig
 from factories import file_factory
 
 if TYPE_CHECKING:
-    from core.app.app_config.entities import AppConfig, VariableEntity
+    from core.app.app_config.entities import VariableEntity
 
 
 class BaseAppGenerator:
@@ -14,23 +14,23 @@ class BaseAppGenerator:
         self,
         *,
         user_inputs: Optional[Mapping[str, Any]],
-        app_config: "AppConfig",
+        variables: Sequence["VariableEntity"],
+        tenant_id: str,
     ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
-        variables = app_config.variables
         user_inputs = {
             var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
             for var in variables
         }
         user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
         # Convert files in inputs to File
-        entity_dictionary = {item.variable: item for item in app_config.variables}
+        entity_dictionary = {item.variable: item for item in variables}
         # Convert single file to File
         files_inputs = {
             k: file_factory.build_from_mapping(
                 mapping=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@@ -44,7 +44,7 @@ class BaseAppGenerator:
         file_list_inputs = {
             k: file_factory.build_from_mappings(
                 mappings=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@@ -132,7 +132,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
@@ -113,7 +113,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             user_id=user.id,
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union
 
 from sqlalchemy import and_
@@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             db.session.commit()
             db.session.refresh(conversation)
         else:
-            conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         message = Message(
@@ -96,7 +96,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             files=system_files,
             user_id=user.id,
             stream=stream,
@@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.iteration import IterationNodeData
 from core.workflow.nodes.node_mapping import node_type_classes_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
@@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
             user_inputs=user_inputs,
             variable_pool=variable_pool,
             tenant_id=workflow.tenant_id,
-            node_type=node_type,
-            node_data=IterationNodeData(**iteration_node_config.get("data", {})),
         )
 
         return graph, variable_pool
@@ -1,5 +1,5 @@
 from datetime import datetime
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, field_validator
@@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
 from core.workflow.nodes.base import BaseNodeData
 
 
-class QueueEvent(str, Enum):
+class QueueEvent(StrEnum):
     """
     QueueEvent enum
     """
@@ -1,8 +1,9 @@
 import json
 import time
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
+from uuid import uuid4
 
 from sqlalchemy.orm import Session
@@ -80,38 +81,38 @@ class WorkflowCycleManage:
 
             inputs[f"sys.{key.value}"] = value
 
+        inputs = WorkflowEntry.handle_special_values(inputs)
+
         triggered_from = (
             WorkflowRunTriggeredFrom.DEBUGGING
             if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
             else WorkflowRunTriggeredFrom.APP_RUN
         )
 
-        # init workflow run
-        workflow_run = WorkflowRun()
-        workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
-        if workflow_run_id:
-            workflow_run.id = workflow_run_id
-        workflow_run.tenant_id = self._workflow.tenant_id
-        workflow_run.app_id = self._workflow.app_id
-        workflow_run.sequence_number = new_sequence_number
-        workflow_run.workflow_id = self._workflow.id
-        workflow_run.type = self._workflow.type
-        workflow_run.triggered_from = triggered_from.value
-        workflow_run.version = self._workflow.version
-        workflow_run.graph = self._workflow.graph
-        workflow_run.inputs = json.dumps(inputs)
-        workflow_run.status = WorkflowRunStatus.RUNNING.value
-        workflow_run.created_by_role = (
-            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
-        )
-        workflow_run.created_by = self._user.id
-        # handle special values
-        inputs = WorkflowEntry.handle_special_values(inputs)
-
-        db.session.add(workflow_run)
-        db.session.commit()
-        db.session.refresh(workflow_run)
-        db.session.close()
+        # init workflow run
+        with Session(db.engine, expire_on_commit=False) as session:
+            workflow_run = WorkflowRun()
+            system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
+            workflow_run.id = system_id or str(uuid4())
+            workflow_run.tenant_id = self._workflow.tenant_id
+            workflow_run.app_id = self._workflow.app_id
+            workflow_run.sequence_number = new_sequence_number
+            workflow_run.workflow_id = self._workflow.id
+            workflow_run.type = self._workflow.type
+            workflow_run.triggered_from = triggered_from.value
+            workflow_run.version = self._workflow.version
+            workflow_run.graph = self._workflow.graph
+            workflow_run.inputs = json.dumps(inputs)
+            workflow_run.status = WorkflowRunStatus.RUNNING
+            workflow_run.created_by_role = (
+                CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
+            )
+            workflow_run.created_by = self._user.id
+            workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+
+            session.add(workflow_run)
+            session.commit()
 
         return workflow_run
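In the rewritten run initialization above, `workflow_run.id = system_id or str(uuid4())` collapses the old conditional assignment: when no run id arrives via system variables, a fresh UUID is generated client-side instead of relying on a database default. The idiom in isolation:

```python
from uuid import uuid4


def resolve_run_id(system_id: str | None) -> str:
    # Fall back to a freshly generated UUID when the caller supplied none.
    return system_id or str(uuid4())


print(resolve_run_id(None))        # e.g. 'd9f1...' (random)
print(resolve_run_id("fixed-id"))  # 'fixed-id'
```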
@ -144,7 +145,7 @@ class WorkflowCycleManage:
|
|||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
|
|
@ -191,7 +192,7 @@ class WorkflowCycleManage:
|
|||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -211,15 +212,18 @@ class WorkflowCycleManage:
|
|||
for workflow_node_execution in running_workflow_node_executions:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (
|
||||
workflow_node_execution.finished_at - workflow_node_execution.created_at
|
||||
).total_seconds()
|
||||
db.session.commit()
|
||||
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
session.add(workflow_run)
|
||||
session.refresh(workflow_run)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
|
|
@ -259,7 +263,7 @@ class WorkflowCycleManage:
|
|||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
session.commit()
|
||||
|
|
@ -282,7 +286,7 @@ class WorkflowCycleManage:
|
|||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
|
|
@ -326,7 +330,7 @@ class WorkflowCycleManage:
|
|||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
|
|
@ -381,7 +385,7 @@ class WorkflowCycleManage:
|
|||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=workflow_run.inputs_dict or {},
|
||||
inputs=workflow_run.inputs_dict,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
)
|
||||
|
|
@ -428,7 +432,7 @@ class WorkflowCycleManage:
|
|||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -654,7 +658,7 @@ class WorkflowCycleManage:
|
|||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
|
|
|
|||
|
|
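Note: the hunks above swap `datetime.now(timezone.utc)` for `datetime.now(UTC)`. `UTC` is the module-level alias added in Python 3.11 (`datetime.UTC is datetime.timezone.utc`), so the two calls are equivalent; this merge drops Python 3.10 support, which makes the alias safe. A minimal sketch of the pattern used throughout, including the `.replace(tzinfo=None)` step that keeps stored timestamps timezone-naive:

    from datetime import UTC, datetime  # the UTC alias requires Python 3.11+

    # Aware "now" in UTC, then stripped to a naive datetime for storage.
    naive_utc_now = datetime.now(UTC).replace(tzinfo=None)
    assert naive_utc_now.tzinfo is None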
@@ -240,7 +240,7 @@ class ProviderConfiguration(BaseModel):
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            provider_record = Provider(

@@ -394,7 +394,7 @@ class ProviderConfiguration(BaseModel):
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            provider_model_record = ProviderModel(

@@ -468,7 +468,7 @@ class ProviderConfiguration(BaseModel):

        if model_setting:
            model_setting.enabled = True
            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            model_setting = ProviderModelSetting(

@@ -503,7 +503,7 @@ class ProviderConfiguration(BaseModel):

        if model_setting:
            model_setting.enabled = False
            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            model_setting = ProviderModelSetting(

@@ -570,7 +570,7 @@ class ProviderConfiguration(BaseModel):

        if model_setting:
            model_setting.load_balancing_enabled = True
            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            model_setting = ProviderModelSetting(

@@ -605,7 +605,7 @@ class ProviderConfiguration(BaseModel):

        if model_setting:
            model_setting.load_balancing_enabled = False
            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            model_setting = ProviderModelSetting(
@@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class FileType(str, Enum):
class FileType(StrEnum):
    IMAGE = "image"
    DOCUMENT = "document"
    AUDIO = "audio"

@@ -16,7 +16,7 @@ class FileType(str, Enum):
        raise ValueError(f"No matching enum found for value '{value}'")


class FileTransferMethod(str, Enum):
class FileTransferMethod(StrEnum):
    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"

@@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
        raise ValueError(f"No matching enum found for value '{value}'")


class FileBelongsTo(str, Enum):
class FileBelongsTo(StrEnum):
    USER = "user"
    ASSISTANT = "assistant"

@@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
        raise ValueError(f"No matching enum found for value '{value}'")


class FileAttribute(str, Enum):
class FileAttribute(StrEnum):
    TYPE = "type"
    SIZE = "size"
    NAME = "name"

@@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
    EXTENSION = "extension"


class ArrayFileAttribute(str, Enum):
class ArrayFileAttribute(StrEnum):
    LENGTH = "length"
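Note: `StrEnum` (Python 3.11+) makes each member a real `str`, which is what the old `class X(str, Enum)` idiom emulated. The visible difference on 3.11+ is that `str()` and f-string formatting of a mixed-in `str, Enum` member yield the member name, not the value, while `StrEnum` yields the value. A small sketch of the behavior, assuming the same member names as above:

    from enum import Enum, StrEnum

    class OldStyle(str, Enum):
        IMAGE = "image"

    class NewStyle(StrEnum):
        IMAGE = "image"

    print(f"{OldStyle.IMAGE}")        # "OldStyle.IMAGE" on Python 3.11+
    print(f"{NewStyle.IMAGE}")        # "image" -- safer for logging and serialization
    print(NewStyle.IMAGE == "image")  # True; members are still plain strings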
@@ -3,7 +3,12 @@ import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
from core.model_runtime.entities import (
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage

@@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
            return file.remote_url
        case FileAttribute.EXTENSION:
            return file.extension
        case _:
            raise ValueError(f"Invalid file attribute: {attr}")


def to_prompt_message_content(
    f: File,
    /,
    *,
    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
    """
    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.

    This function takes a File object and converts it to an appropriate PromptMessageContent
    object, which can be used as a prompt for image or audio-based AI models.

    Args:
        f (File): The File object to convert.
        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.

    Returns:
        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level

    Raises:
        ValueError: If the file type is not supported or if required data is missing.
    """
    match f.type:
        case FileType.IMAGE:
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                data = _to_url(f)
            else:

@@ -65,7 +52,7 @@ def to_prompt_message_content(

            return ImagePromptMessageContent(data=data, detail=image_detail_config)
        case FileType.AUDIO:
            encoded_string = _file_to_encoded_string(f)
            encoded_string = _get_encoded_string(f)
            if f.extension is None:
                raise ValueError("Missing file extension")
            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))

@@ -74,9 +61,20 @@ def to_prompt_message_content(
                data = _to_url(f)
            else:
                data = _to_base64_data_string(f)
            if f.extension is None:
                raise ValueError("Missing file extension")
            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
        case FileType.DOCUMENT:
            data = _get_encoded_string(f)
            if f.mime_type is None:
                raise ValueError("Missing file mime_type")
            return DocumentPromptMessageContent(
                encode_format="base64",
                mime_type=f.mime_type,
                data=data,
            )
        case _:
            raise ValueError("file type f.type is not supported")
            raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):

@@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
        case FileTransferMethod.REMOTE_URL:
            response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
            response.raise_for_status()
            content = response.content
            encoded_string = base64.b64encode(content).decode("utf-8")
            return encoded_string
            data = response.content
        case FileTransferMethod.LOCAL_FILE:
            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
            data = _download_file_content(upload_file.key)
            encoded_string = base64.b64encode(data).decode("utf-8")
            return encoded_string
        case FileTransferMethod.TOOL_FILE:
            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
            data = _download_file_content(tool_file.file_key)
            encoded_string = base64.b64encode(data).decode("utf-8")
            return encoded_string
        case _:
            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

    encoded_string = base64.b64encode(data).decode("utf-8")
    return encoded_string


def _to_base64_data_string(f: File, /):

@@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
    return f"data:{f.mime_type};base64,{encoded_string}"


def _file_to_encoded_string(f: File, /):
    match f.type:
        case FileType.IMAGE:
            return _to_base64_data_string(f)
        case FileType.VIDEO:
            return _to_base64_data_string(f)
        case FileType.AUDIO:
            return _get_encoded_string(f)
        case _:
            raise ValueError(f"file type {f.type} is not supported")


def _to_url(f: File, /):
    if f.transfer_method == FileTransferMethod.REMOTE_URL:
        if f.remote_url is None:
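Note: after this refactor, every transfer method in `_get_encoded_string` only produces raw bytes and the base64 encode happens once at the tail, while `_to_base64_data_string` wraps that payload in a standard data URI for providers that cannot fetch URLs. A minimal sketch of the data-URI shape, assuming a known MIME type:

    import base64

    def to_base64_data_string(mime_type: str, raw: bytes) -> str:
        # data:<mime>;base64,<payload> -- the shape _to_base64_data_string produces
        encoded = base64.b64encode(raw).decode("utf-8")
        return f"data:{mime_type};base64,{encoded}"

    print(to_base64_data_string("image/png", b"\x89PNG...")[:30])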
@@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum
from threading import Lock
from typing import Any, Optional

@@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel):
    data: Data


class CodeLanguage(str, Enum):
class CodeLanguage(StrEnum):
    PYTHON3 = "python3"
    JINJA2 = "jinja2"
    JAVASCRIPT = "javascript"
@@ -30,6 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage

@@ -85,7 +86,7 @@ class IndexingRunner:
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except ObjectDeletedError:
            logging.warning("Document deleted, document id: {}".format(dataset_document.id))

@@ -93,7 +94,7 @@ class IndexingRunner:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):

@@ -141,13 +142,13 @@ class IndexingRunner:
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):

@@ -199,13 +200,13 @@ class IndexingRunner:
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def indexing_estimate(

@@ -279,6 +280,19 @@ class IndexingRunner:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                # delete image files and related db records
                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        logging.exception(
                            "Delete image_files failed while indexing_estimate, \
                            image_upload_file_is: {}".format(upload_file_id)
                        )
                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
            if len(preview_texts) > 0:
                # qa model document

@@ -358,7 +372,7 @@ class IndexingRunner:
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )

@@ -450,7 +464,7 @@ class IndexingRunner:
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",

@@ -465,7 +479,7 @@ class IndexingRunner:
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )

@@ -666,7 +680,7 @@ class IndexingRunner:
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                DatasetDocument.error: None,
            },

@@ -691,7 +705,7 @@ class IndexingRunner:
            {
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            }
        )

@@ -724,7 +738,7 @@ class IndexingRunner:
            {
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            }
        )

@@ -835,7 +849,7 @@ class IndexingRunner:
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",

@@ -850,7 +864,7 @@ class IndexingRunner:
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )
        pass
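Note: in the image-cleanup loop added above, `db.session.query(UploadFile)...first()` can return `None` when the referenced upload record is already gone, in which case both `storage.delete(image_file.key)` and `db.session.delete(image_file)` would fail. A defensive variant of the same loop that skips missing rows (a sketch reusing the module's imports, not the committed code):

    for upload_file_id in image_upload_file_ids:
        image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
        if image_file is None:
            continue  # record already removed; nothing to delete
        try:
            storage.delete(image_file.key)
        except Exception:
            logging.exception("Delete image_files failed while indexing_estimate, "
                              "image_upload_file_id: %s", upload_file_id)
        db.session.delete(image_file)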
@@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import Optional

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.file.models import FileType
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,

@@ -27,7 +27,7 @@ class TokenBufferMemory:

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> list[PromptMessage]:
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit

@@ -102,12 +102,11 @@ class TokenBufferMemory:
        prompt_message_contents: list[PromptMessageContent] = []
        prompt_message_contents.append(TextPromptMessageContent(data=message.query))
        for file in file_objs:
            if file.type in {FileType.IMAGE, FileType.AUDIO}:
                prompt_message = file_manager.to_prompt_message_content(
                    file,
                    image_detail_config=detail,
                )
                prompt_message_contents.append(prompt_message)
            prompt_message = file_manager.to_prompt_message_content(
                file,
                image_detail_config=detail,
            )
            prompt_message_contents.append(prompt_message)

        prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
@@ -100,10 +100,10 @@ class ModelInstance:

    def invoke_llm(
        self,
        prompt_messages: list[PromptMessage],
        prompt_messages: Sequence[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
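Note: widening `stop: Optional[list[str]]` to `Optional[Sequence[str]]` (and `prompt_messages` likewise) lets callers pass tuples or other read-only sequences and signals that the callee will not mutate the argument. A small sketch of why the wider type is the friendlier contract:

    from collections.abc import Sequence

    def invoke(stop: Sequence[str] | None = None) -> list[str]:
        # Accepts a list, a tuple, or any read-only sequence of strings.
        return list(stop) if stop is not None else []

    invoke(["\n\n"])           # a list works
    invoke(("Human:", "AI:"))  # so does a tuple, which list[str] would reject in type checks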
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk

@@ -31,7 +32,7 @@ class Callback(ABC):
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:

@@ -60,7 +61,7 @@ class Callback(ABC):
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ):

@@ -90,7 +91,7 @@ class Callback(ABC):
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:

@@ -120,7 +121,7 @@ class Callback(ABC):
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
@@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
from .message_entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,

@@ -37,4 +38,5 @@ __all__ = [
    "LLMResultChunk",
    "LLMResultChunkDelta",
    "AudioPromptMessageContent",
    "DocumentPromptMessageContent",
]
@@ -1,6 +1,7 @@
from abc import ABC
from enum import Enum
from typing import Optional
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional

from pydantic import BaseModel, Field, field_validator

@@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):
    function: PromptMessageTool


class PromptMessageContentType(Enum):
class PromptMessageContentType(StrEnum):
    """
    Enum class for prompt message content type.
    """

@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    DOCUMENT = "document"


class PromptMessageContent(BaseModel):

@@ -93,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent):
    Model class for image prompt message content.
    """

    class DETAIL(str, Enum):
    class DETAIL(StrEnum):
        LOW = "low"
        HIGH = "high"

@@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
    detail: DETAIL = DETAIL.LOW


class DocumentPromptMessageContent(PromptMessageContent):
    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
    encode_format: Literal["base64"]
    mime_type: str
    data: str


class PromptMessage(ABC, BaseModel):
    """
    Model class for prompt message.
    """

    role: PromptMessageRole
    content: Optional[str | list[PromptMessageContent]] = None
    content: Optional[str | Sequence[PromptMessageContent]] = None
    name: Optional[str] = None

    def is_empty(self) -> bool:
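Note: the new `DocumentPromptMessageContent` pins `encode_format` to the literal "base64" and carries the MIME type alongside the payload, so providers can validate before building their request. A construction sketch with illustrative field values, assuming the class is imported from `core.model_runtime.entities` as in the hunks above:

    import base64

    pdf_bytes = b"%PDF-1.4 ..."  # illustrative payload, not a real document
    content = DocumentPromptMessageContent(
        encode_format="base64",  # Literal["base64"]: any other value fails pydantic validation
        mime_type="application/pdf",
        data=base64.b64encode(pdf_bytes).decode("utf-8"),
    )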
@@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum
from enum import Enum, StrEnum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict

@@ -87,9 +87,12 @@ class ModelFeature(Enum):
    AGENT_THOUGHT = "agent-thought"
    VISION = "vision"
    STREAM_TOOL_CALL = "stream-tool-call"
    DOCUMENT = "document"
    VIDEO = "video"
    AUDIO = "audio"


class DefaultParameterName(str, Enum):
class DefaultParameterName(StrEnum):
    """
    Enum class for parameter template variable.
    """
@@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union

from pydantic import ConfigDict

@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
        prompt_messages: list[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -212,7 +212,7 @@ if you are not sure about the structure.
        )

        model_parameters.pop("response_format")
        stop = stop or []
        stop = list(stop) if stop is not None else []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

@@ -408,7 +408,7 @@ if you are not sure about the structure.
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -479,7 +479,7 @@ if you are not sure about the structure.
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:

@@ -601,7 +601,7 @@ if you are not sure about the structure.
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -647,7 +647,7 @@ if you are not sure about the structure.
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -694,7 +694,7 @@ if you are not sure about the structure.
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,

@@ -742,7 +742,7 @@ if you are not sure about the structure.
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
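Note: the one behavioral change in this file is `stop = list(stop) if stop is not None else []`. With `stop` now typed as a `Sequence`, the old `stop = stop or []` could leave a tuple that the following `stop.extend(...)` would crash on; copying into a fresh list also avoids mutating the caller's object. A sketch of the difference:

    def normalize_stop(stop):
        # A tuple has no .extend(); always copy into a mutable list first.
        stop = list(stop) if stop is not None else []
        stop.extend(["\n```", "```\n"])
        return stop

    print(normalize_stop(("Human:",)))  # works for tuples and for None alike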
@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 200000

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 200000
@@ -1,7 +1,7 @@
import base64
import io
import json
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast

import anthropic

@@ -21,9 +21,9 @@ from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
    AssistantPromptMessage,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,

@@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,

@@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:

@@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        # Add the new header for claude-3-5-sonnet-20240620 model
        extra_headers = {}
        if model == "claude-3-5-sonnet-20240620":
            if model_parameters.get("max_tokens") > 4096:
            if model_parameters.get("max_tokens", 0) > 4096:
                extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"

        if any(
            isinstance(content, DocumentPromptMessageContent)
            for prompt_message in prompt_messages
            if isinstance(prompt_message.content, list)
            for content in prompt_message.content
        ):
            extra_headers["anthropic-beta"] = "pdfs-2024-09-25"

        if tools:
            extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
            response = client.beta.tools.messages.create(

@@ -444,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

        return credentials_kwargs

    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
    def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
        """
        Convert prompt messages to dict list and system
        """

@@ -452,7 +461,15 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        first_loop = True
        for message in prompt_messages:
            if isinstance(message, SystemPromptMessage):
                message.content = message.content.strip()
                if isinstance(message.content, str):
                    message.content = message.content.strip()
                elif isinstance(message.content, list):
                    # System prompt only support text
                    message.content = "".join(
                        c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent)
                    )
                else:
                    raise ValueError(f"Unknown system prompt message content type {type(message.content)}")
                if first_loop:
                    system = message.content
                    first_loop = False

@@ -504,6 +521,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                        "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                    }
                    sub_messages.append(sub_message_dict)
                elif isinstance(message_content, DocumentPromptMessageContent):
                    if message_content.mime_type != "application/pdf":
                        raise ValueError(
                            f"Unsupported document type {message_content.mime_type}, "
                            "only support application/pdf"
                        )
                    sub_message_dict = {
                        "type": "document",
                        "source": {
                            "type": message_content.encode_format,
                            "media_type": message_content.mime_type,
                            "data": message_content.data,
                        },
                    }
                    sub_messages.append(sub_message_dict)
                prompt_message_dicts.append({"role": "user", "content": sub_messages})
            elif isinstance(message, AssistantPromptMessage):
                message = cast(AssistantPromptMessage, message)
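Note: the nested generator in the Anthropic hunk walks every content part of every message and enables the `pdfs-2024-09-25` beta header as soon as one document part appears. The same filter in isolation, with stand-in classes so the snippet runs on its own:

    class DocPart:  # stand-in for DocumentPromptMessageContent
        pass

    class Msg:  # stand-in for PromptMessage
        def __init__(self, content):
            self.content = content

    def needs_pdf_beta(prompt_messages) -> bool:
        return any(
            isinstance(content, DocPart)
            for prompt_message in prompt_messages
            if isinstance(prompt_message.content, list)  # skip plain-string contents
            for content in prompt_message.content
        )

    print(needs_pdf_beta([Msg("hi"), Msg([DocPart()])]))  # True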
@@ -779,7 +779,7 @@ LLM_BASE_MODELS = [
                    name="frequency_penalty",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
                ),
                _get_max_tokens(default=512, min_val=1, max_val=4096),
                _get_max_tokens(default=512, min_val=1, max_val=16384),
                ParameterRule(
                    name="seed",
                    label=I18nObject(zh_Hans="种子", en_US="Seed"),
@@ -2,13 +2,11 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast

# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
    ClientError,

@@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                sub_messages.append(sub_message_dict)
            elif message_content.type == PromptMessageContentType.IMAGE:
                message_content = cast(ImagePromptMessageContent, message_content)
                if not message_content.data.startswith("data:"):
                    # fetch image data from url
                    try:
                        url = message_content.data
                        image_content = requests.get(url).content
                        if "?" in url:
                            url = url.split("?")[0]
                        mime_type, _ = mimetypes.guess_type(url)
                        base64_data = base64.b64encode(image_content).decode("utf-8")
                    except Exception as ex:
                        raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                else:
                    data_split = message_content.data.split(";base64,")
                    mime_type = data_split[0].replace("data:", "")
                    base64_data = data_split[1]
                    image_content = base64.b64decode(base64_data)
                data_split = message_content.data.split(";base64,")
                mime_type = data_split[0].replace("data:", "")
                base64_data = data_split[1]
                image_content = base64.b64decode(base64_data)

                if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                    raise ValueError(
@@ -15,9 +15,9 @@ parameter_rules:
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    default: 8192
    min: 1
    max: 4096
    max: 8192
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.

@@ -16,9 +16,9 @@ parameter_rules:
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    default: 8192
    min: 1
    max: 4096
    max: 8192
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
@@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        base_model_schema = cast(AIModelEntity, base_model_schema)

        base_model_schema_features = base_model_schema.features or []
        base_model_schema_model_properties = base_model_schema.model_properties or {}
        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
        base_model_schema_model_properties = base_model_schema.model_properties
        base_model_schema_parameters_rules = base_model_schema.parameter_rules

        entity = AIModelEntity(
            model=model,
@@ -5,6 +5,7 @@ label:
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:

@@ -72,7 +73,7 @@ parameter_rules:
      - text
      - json_object
pricing:
  input: '1'
  output: '2'
  unit: '0.000001'
  input: "1"
  output: "2"
  unit: "0.000001"
  currency: RMB

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
@@ -1,18 +1,17 @@
from collections.abc import Generator
from typing import Optional, Union
from urllib.parse import urlparse

import tiktoken
from yarl import URL

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,

@@ -25,92 +24,15 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)

        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # refactored from openai model runtime, use cl100k_base for calculate token number
    def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for text completion model with tiktoken package.

        :param model: model name
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # refactored from openai model runtime, use cl100k_base for calculate token number
    def _num_tokens_from_messages(
        self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
    ) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                # which need to download the image and then get the resolution for calculation,
                # and will increase the request delay
                if isinstance(value, list):
                    text = ""
                    for item in value:
                        if isinstance(item, dict) and item["type"] == "text":
                            text += item["text"]

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(f_value))
                            else:
                                num_tokens += len(encoding.encode(t_key))
                                num_tokens += len(encoding.encode(t_value))
                else:
                    num_tokens += len(encoding.encode(str(value)))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials["mode"] = "chat"
        credentials["openai_api_key"] = credentials["api_key"]
        if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
            credentials["openai_api_base"] = "https://api.deepseek.com"
        else:
            parsed_url = urlparse(credentials["endpoint_url"])
            credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
    def _add_custom_parameters(credentials) -> None:
        credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com")))
        credentials["mode"] = LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"
        credentials["stream_function_calling"] = "support"
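Note: after this rewrite the DeepSeek runtime is a thin shim over the OpenAI-API-compatible base class: `_add_custom_parameters` just rewrites the credentials dict before delegating, which is why the tiktoken-based counting methods could be deleted. A sketch of the transformation it performs (the input key here is illustrative):

    from yarl import URL

    def add_custom_parameters(credentials: dict) -> None:
        # Default the endpoint, then mark the model as chat-mode with
        # OpenAI-style tool calling, mirroring _add_custom_parameters above.
        credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com")))
        credentials["mode"] = "chat"  # LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"
        credentials["stream_function_calling"] = "support"

    creds = {"api_key": "sk-..."}  # illustrative
    add_custom_parameters(creds)
    print(creds["endpoint_url"])  # https://api.deepseek.com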
@@ -18,7 +18,8 @@ class FishAudioProvider(ModelProvider):
        """
        try:
            model_instance = self.get_model_instance(ModelType.TTS)
            model_instance.validate_credentials(credentials=credentials)
            # FIXME fish tts do not have model for now, so set it to empty string instead
            model_instance.validate_credentials(model="", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:

@@ -66,7 +66,7 @@ class FishAudioText2SpeechModel(TTSModel):
            voice=voice,
        )

    def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None:
    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        Validate credentials for text2speech model

@@ -76,7 +76,7 @@ class FishAudioText2SpeechModel(TTSModel):

        try:
            self.get_tts_model_voices(
                None,
                "",
                credentials={
                    "api_key": credentials["api_key"],
                    "api_base": credentials["api_base"],
@@ -122,7 +122,7 @@ class GiteeAIRerankModel(RerankModel):
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
        )

        return entity
@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 1048576

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152

@@ -7,6 +7,7 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152

@@ -7,9 +7,10 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 2097152
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature

@@ -0,0 +1,38 @@
model: gemini-exp-1121
label:
  en_US: Gemini exp 1121
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,38 @@
model: learnlm-1.5-pro-experimental
label:
  en_US: LearnLM 1.5 Pro Experimental
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@@ -16,6 +16,7 @@ from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,

@@ -35,6 +36,21 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

GOOGLE_AVAILABLE_MIMETYPE = [
    "application/pdf",
    "application/x-javascript",
    "text/javascript",
    "application/x-python",
    "text/x-python",
    "text/plain",
    "text/html",
    "text/css",
    "text/md",
    "text/csv",
    "text/xml",
    "text/rtf",
]


class GoogleLargeLanguageModel(LargeLanguageModel):
    def _invoke(

@@ -370,6 +386,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                        raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                    blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
                    glm_content["parts"].append(blob)
                elif c.type == PromptMessageContentType.DOCUMENT:
                    message_content = cast(DocumentPromptMessageContent, c)
                    if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
                        raise ValueError(f"Unsupported mime type {message_content.mime_type}")
                    blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
                    glm_content["parts"].append(blob)

            return glm_content
        elif isinstance(message, AssistantPromptMessage):
@@ -140,7 +140,7 @@ class GPUStackRerankModel(RerankModel):
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
        )

        return entity
@@ -34,3 +34,11 @@ model_credential_schema:
      placeholder:
        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
        en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: false
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):

        server_url = server_url.removesuffix("/")

        headers = {"Content-Type": "application/json"}
        api_key = credentials.get("api_key")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)
            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)

            rerank_documents = []
            for result in results:

@@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
        """
        try:
            server_url = credentials["server_url"]
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            headers = {"Content-Type": "application/json"}
            api_key = credentials.get("api_key")
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
            if extra_args.model_type != "reranker":
                raise CredentialsValidateFailedError("Current model is not a rerank model")
@ -26,13 +26,15 @@ cache_lock = Lock()
|
|||
|
||||
class TeiHelper:
|
||||
@staticmethod
|
||||
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
|
||||
def get_tei_extra_parameter(
|
||||
server_url: str, model_name: str, headers: Optional[dict] = None
|
||||
) -> TeiModelExtraParameter:
|
||||
TeiHelper._clean_cache()
|
||||
with cache_lock:
|
||||
if model_name not in cache:
|
||||
cache[model_name] = {
|
||||
"expires": time() + 300,
|
||||
"value": TeiHelper._get_tei_extra_parameter(server_url),
|
||||
"value": TeiHelper._get_tei_extra_parameter(server_url, headers),
|
||||
}
|
||||
return cache[model_name]["value"]
|
||||
|
||||
|
|
@ -47,7 +49,7 @@ class TeiHelper:
|
|||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
|
||||
def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
|
||||
"""
|
||||
get tei model extra parameter like model_type, max_input_length, max_batch_requests
|
||||
"""
|
||||
|
|
@ -61,7 +63,7 @@ class TeiHelper:
|
|||
session.mount("https://", HTTPAdapter(max_retries=3))
|
||||
|
||||
try:
|
||||
response = session.get(url, timeout=10)
|
||||
response = session.get(url, headers=headers, timeout=10)
|
||||
except (MissingSchema, ConnectionError, Timeout) as e:
|
||||
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
|
||||
if response.status_code != 200:
|
||||
|
|
@ -86,7 +88,7 @@ class TeiHelper:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
||||
def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
|
||||
"""
|
||||
Invoke tokenize endpoint
|
||||
|
||||
|
|
@@ -114,15 +116,15 @@ class TeiHelper:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)

         resp.raise_for_status()
         return resp.json()

     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint

@@ -147,15 +149,14 @@ class TeiHelper:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()

     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint

@@ -173,10 +174,7 @@ class TeiHelper:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
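After this change every TeiHelper endpoint follows the same shape: build the URL, build the JSON body, pass the optional headers straight to httpx. An authenticated embeddings call, for instance, reduces to something like this sketch (the function name and token handling are illustrative):

    import httpx

    def invoke_embeddings_sketch(server_url: str, texts: list[str], token: str | None = None) -> dict:
        # Mirrors TeiHelper.invoke_embeddings: the OpenAI-compatible route also reports usage.
        headers = {"Content-Type": "application/json"}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        resp = httpx.post(f"{server_url}/v1/embeddings", json={"input": texts}, headers=headers)
        resp.raise_for_status()
        return resp.json()  # expected shape: {"data": [{"embedding": [...]}, ...], "usage": {...}}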
@@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):

         server_url = server_url.removesuffix("/")

+        headers = {"Content-Type": "application/json"}
+        api_key = credentials["api_key"]
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0

         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)

         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0
         for i in _iter:
             iter_texts = inputs[i : i + max_chunks]
-            results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+            results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
             embeddings = results["data"]
             embeddings = [embedding["embedding"] for embedding in embeddings]
             batched_embeddings.extend(embeddings)
@@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):

         server_url = server_url.removesuffix("/")

-        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+        }
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
         num_tokens = sum(len(tokens) for tokens in batch_tokens)
         return num_tokens

@@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+
+            api_key = credentials.get("api_key")
+
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             print(extra_args)
             if extra_args.model_type != "embedding":
                 raise CredentialsValidateFailedError("Current model is not a embedding model")
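One wrinkle worth noting in this file: the first hunk reads credentials["api_key"] with a subscript, while the rerank model and the validator use credentials.get("api_key"); the subscript form raises KeyError when no key is configured. A defensive variant of the header builder would look like this (build_tei_headers is a hypothetical helper, not part of the diff):

    def build_tei_headers(credentials: dict) -> dict:
        # Hypothetical helper: tolerate a missing or empty api_key.
        headers = {"Content-Type": "application/json"}
        api_key = credentials.get("api_key")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers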
@@ -128,7 +128,7 @@ class JinaRerankModel(RerankModel):
             label=I18nObject(en_US=model),
             model_type=ModelType.RERANK,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))},
         )

         return entity
@@ -193,7 +193,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
             label=I18nObject(en_US=model),
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))},
         )

         return entity
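Both Jina hunks fix the same crash: dict.get with a single argument returns None when context_size is absent from the credentials, and int(None) raises TypeError, so schema construction failed for setups that never set the field. The two-argument form falls back cleanly:

    credentials = {}  # context_size never configured
    # int(credentials.get("context_size"))      -> TypeError (int() cannot take None)
    int(credentials.get("context_size", 8000))  # -> 8000
    # A user-supplied value still wins:
    int({"context_size": "4096"}.get("context_size", 8000))  # -> 4096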
@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user,
@@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
@@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, "api/chat")
             data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            if tools:
+                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
         else:
             endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
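With tools wired into the request body, the JSON POSTed to Ollama's api/chat route gains a tools array in the shape _convert_prompt_message_tool_to_dict produces (defined in a later hunk). A sketch of one such payload, where the model name and the get_weather tool are invented for illustration:

    data = {
        "model": "llama3.1",
        "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Look up current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
    }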
@@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if stream:
             return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)

-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)

     def _handle_generate_response(
         self,
@@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         completion_type: LLMMode,
         response: requests.Response,
         prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]],
     ) -> LLMResult:
         """
         Handle llm completion response
@@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :return: llm result
         """
         response_json = response.json()

+        tool_calls = []
         if completion_type is LLMMode.CHAT:
             message = response_json.get("message", {})
             response_content = message.get("content", "")
+            response_tool_calls = message.get("tool_calls", [])
+            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
         else:
             response_content = response_json["response"]

-        assistant_message = AssistantPromptMessage(content=response_content)
+        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)

         if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage
@@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

             chunk_index += 1

+    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
+        """
+        Convert PromptMessageTool to dict for Ollama API
+
+        :param tool: tool
+        :return: tool dict
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            },
+        }
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for Ollama API

         :param message: prompt message
         :return: message dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content}
         else:
             raise ValueError(f"Got unknown type {message}")

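The new ToolPromptMessage branch closes the loop: once the model answers with a tool call, the tool's output goes back to api/chat as a "tool"-role message on the next turn. A hypothetical message list for that follow-up request (the assistant entry's exact shape is assumed, since its conversion branch is not shown in this diff):

    messages = [
        {"role": "user", "content": "What is the weather in Berlin?"},
        # assistant turn in which the model requested the tool (shape assumed)
        {"role": "assistant", "content": "", "tool_calls": [
            {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
        ]},
        # the tool's result, as _convert_prompt_message_to_dict now emits it
        {"role": "tool", "content": '{"temp_c": 6, "condition": "overcast"}'},
    ]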
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

         return num_tokens

+    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract response tool call
+        """
+        tool_call = None
+        if response_tool_call and "function" in response_tool_call:
+            # Convert arguments to JSON string if it's a dict
+            arguments = response_tool_call.get("function").get("arguments")
+            if isinstance(arguments, dict):
+                arguments = json.dumps(arguments)
+
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_tool_call.get("function").get("name"),
+                arguments=arguments,
+            )
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_tool_call.get("function").get("name"),
+                type="function",
+                function=function,
+            )
+
+        return tool_call
+
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         """
         Get customizable model schema.
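_extract_response_tool_call normalizes a raw entry from Ollama's message.tool_calls: dict-typed arguments are serialized to a JSON string, and the function name doubles as the call id, since Ollama returns no separate id. Traced on a sample entry (values invented):

    import json

    raw = {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}

    arguments = raw["function"]["arguments"]
    if isinstance(arguments, dict):
        arguments = json.dumps(arguments)  # -> '{"city": "Berlin"}'
    name = raw["function"]["name"]         # becomes both ToolCall.id and function.name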
@@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

         :return: model schema
         """
-        extras = {}
+        extras = {
+            "features": [],
+        }

         if "vision_support" in credentials and credentials["vision_support"] == "true":
-            extras["features"] = [ModelFeature.VISION]
+            extras["features"].append(ModelFeature.VISION)
+        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
+            extras["features"].append(ModelFeature.TOOL_CALL)
+            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)

         entity = AIModelEntity(
             model=model,
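get_customizable_model_schema now initializes the features list up front, so the vision flag and the new tool-call flags append to the same list instead of assigning it. Reduced to its effect (the sample credential values are invented; the .get comparison is equivalent to the diff's `in` + subscript test):

    from core.model_runtime.entities.model_entities import ModelFeature  # as imported in this module

    extras = {"features": []}
    credentials = {"vision_support": "false", "function_call_support": "true"}

    if credentials.get("vision_support") == "true":
        extras["features"].append(ModelFeature.VISION)
    if credentials.get("function_call_support") == "true":
        extras["features"].append(ModelFeature.TOOL_CALL)
        extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
    # -> extras["features"] == [TOOL_CALL, MULTI_TOOL_CALL]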
@@ -96,3 +96,22 @@ model_credential_schema:
         label:
           en_US: 'No'
           zh_Hans: 否
+  - variable: function_call_support
+    label:
+      zh_Hans: 是否支持函数调用
+      en_US: Function call support
+    show_on:
+      - variable: __model_type
+        value: llm
+    default: 'false'
+    type: radio
+    required: false
+    options:
+      - value: 'true'
+        label:
+          en_US: 'Yes'
+          zh_Hans: 是
+      - value: 'false'
+        label:
+          en_US: 'No'
+          zh_Hans: 否
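The new radio control stores its choice as the literal strings 'true' and 'false', which is why the Python side compares against "true" rather than testing truthiness:

    function_call_support = "false"     # what the credential form actually stores
    bool(function_call_support)         # True  -- any non-empty string is truthy
    function_call_support == "true"     # False -- the comparison the code uses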
@@ -139,7 +139,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
                 ModelPropertyKey.MAX_CHUNKS: 1,
             },
             parameter_rules=[],
@@ -3,6 +3,7 @@
 - gpt-4o
 - gpt-4o-2024-05-13
 - gpt-4o-2024-08-06
+- gpt-4o-2024-11-20
 - chatgpt-4o-latest
 - gpt-4o-mini
 - gpt-4o-mini-2024-07-18
Some files were not shown because too many files have changed in this diff.