diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index d681dc6627..719e6cfe90 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -20,7 +20,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@v45 with: files: api/** @@ -66,7 +66,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@v45 with: files: web/** @@ -97,7 +97,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@v45 with: files: | **.sh @@ -107,7 +107,7 @@ jobs: dev/** - name: Super-linter - uses: super-linter/super-linter/slim@v6 + uses: super-linter/super-linter/slim@v7 if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml new file mode 100644 index 0000000000..3f51b3b2c7 --- /dev/null +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -0,0 +1,54 @@ +name: Check i18n Files and Create PR + +on: + pull_request: + types: [closed] + branches: [main] + +jobs: + check-and-update: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + defaults: + run: + working-directory: web + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 # last 2 commits + + - name: Check for file changes in i18n/en-US + id: check_files + run: | + recent_commit_sha=$(git rev-parse HEAD) + second_recent_commit_sha=$(git rev-parse HEAD~1) + changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts') + echo "Changed files: $changed_files" + if [ -n "$changed_files" ]; then + echo "FILES_CHANGED=true" >> $GITHUB_ENV + else + echo "FILES_CHANGED=false" >> $GITHUB_ENV + fi + + - name: Set up Node.js + if: env.FILES_CHANGED == 'true' + uses: actions/setup-node@v2 + with: + node-version: 'lts/*' + + - name: Install dependencies + if: env.FILES_CHANGED == 'true' + run: yarn install --frozen-lockfile + + - name: Run npm script + if: env.FILES_CHANGED == 'true' + run: npm run auto-gen-i18n + + - name: Create Pull Request + if: env.FILES_CHANGED == 'true' + uses: peter-evans/create-pull-request@v6 + with: + commit-message: Update i18n files based on en-US changes + title: 'chore: translate i18n files' + body: This PR was automatically created to update i18n files based on changes in en-US locale. + branch: chore/automated-i18n-updates diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f810584f24..8f57cd545e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr ## Before you jump in -[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: +[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. 
We categorize issues into 2 types: ### Feature requests: diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 303c2513f5..7cd2bb60eb 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -8,7 +8,7 @@ ## 在开始之前 -[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: +[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: ### 功能请求: diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index 6d5bfb205c..a68bdeddbc 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは ## 飛び込む前に -[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 +[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 ### 機能リクエスト diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md index 2521256d24..80e68a046e 100644 --- a/CONTRIBUTING_VI.md +++ b/CONTRIBUTING_VI.md @@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [ ## Trước khi bắt đầu -[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại: +[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại: ### Yêu cầu tính năng: diff --git a/LICENSE b/LICENSE index 071ef42bda..06b0fa1d12 100644 --- a/LICENSE +++ b/LICENSE @@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con 1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer: -a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. +a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations. b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components. diff --git a/api/.env.example b/api/.env.example index 502f641d0c..2de37af1ca 100644 --- a/api/.env.example +++ b/api/.env.example @@ -39,7 +39,7 @@ DB_DATABASE=dify # Storage configuration # use for store upload files, private keys... 
-# storage type: local, s3, azure-blob, google-storage +# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos STORAGE_TYPE=local STORAGE_LOCAL_PATH=storage S3_USE_AWS_MANAGED_IAM=false @@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key ALIYUN_OSS_ENDPOINT=your-endpoint ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_REGION=your-region - +# Don't start with '/'. OSS doesn't support leading slash in object names. +ALIYUN_OSS_PATH=your-path # Google Storage configuration GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string @@ -72,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id TENCENT_COS_REGION=your-region TENCENT_COS_SCHEME=your-scheme +# Huawei OBS Storage Configuration +HUAWEI_OBS_BUCKET_NAME=your-bucket-name +HUAWEI_OBS_SECRET_KEY=your-secret-key +HUAWEI_OBS_ACCESS_KEY=your-access-key +HUAWEI_OBS_SERVER=your-server-url + # OCI Storage configuration OCI_ENDPOINT=your-endpoint OCI_BUCKET_NAME=your-bucket-name @@ -79,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key OCI_SECRET_KEY=your-secret-key OCI_REGION=your-region +# Volcengine tos Storage configuration +VOLCENGINE_TOS_ENDPOINT=your-endpoint +VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name +VOLCENGINE_TOS_ACCESS_KEY=your-access-key +VOLCENGINE_TOS_SECRET_KEY=your-secret-key +VOLCENGINE_TOS_REGION=your-region + # CORS configuration WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* @@ -100,11 +114,10 @@ QDRANT_GRPC_ENABLED=false QDRANT_GRPC_PORT=6334 # Milvus configuration -MILVUS_HOST=127.0.0.1 -MILVUS_PORT=19530 +MILVUS_URI=http://127.0.0.1:19530 +MILVUS_TOKEN= MILVUS_USER=root MILVUS_PASSWORD=Milvus -MILVUS_SECURE=false # MyScale configuration MYSCALE_HOST=127.0.0.1 diff --git a/api/Dockerfile b/api/Dockerfile index cca6488679..6483f8281b 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,7 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ + && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* diff --git a/api/commands.py b/api/commands.py index 41f1a6444c..3bf8bc0ecc 100644 --- a/api/commands.py +++ b/api/commands.py @@ -559,8 +559,9 @@ def add_qdrant_doc_id_index(field: str): @click.command("create-tenant", help="Create account and tenant.") @click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--name", prompt=True, help="The workspace name of the tenant account.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: Optional[str] = None): +def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): """ Create tenant account """ @@ -580,13 +581,15 @@ def create_tenant(email: str, language: Optional[str] = None): if language not in languages: language = "en-US" + name = name.strip() + # generate random password new_password = secrets.token_urlsafe(16) # register account account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - 
TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( click.style( diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d9e7038091..5d5411555a 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Annotated, Optional from pydantic import ( AliasChoices, @@ -55,7 +55,7 @@ class CodeExecutionSandboxConfig(BaseSettings): """ CODE_EXECUTION_ENDPOINT: HttpUrl = Field( - description="endpoint URL of code execution servcie", + description="endpoint URL of code execution service", default="http://sandbox:8194", ) @@ -226,20 +226,17 @@ class HttpConfig(BaseSettings): def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field( - description="", - default=300, - ) + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request") + ] = 10 - HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field( - description="", - default=600, - ) + HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ + PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request") + ] = 60 - HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field( - description="", - default=600, - ) + HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="write timeout in seconds for HTTP request") + ] = 20 HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( description="", @@ -427,7 +424,7 @@ class MailConfig(BaseSettings): """ MAIL_TYPE: Optional[str] = Field( - description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.", + description="Mail provider type name, default to None, available values are `smtp` and `resend`.", default=None, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index f25979e5d8..e017c2c5b8 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,7 +1,7 @@ from typing import Any, Optional from urllib.parse import quote_plus -from pydantic import Field, NonNegativeInt, PositiveInt, computed_field +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings from configs.middleware.cache.redis_config import RedisConfig @@ -9,8 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig +from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig from configs.middleware.storage.oci_storage_config import OCIStorageConfig from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig +from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig from configs.middleware.vdb.chroma_config import ChromaConfig from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig @@ -157,6 +159,21 @@ class CeleryConfig(DatabaseConfig): default=None, ) + CELERY_USE_SENTINEL: Optional[bool] =
Field( + description="Whether to use Redis Sentinel mode", + default=False, + ) + + CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field( + description="Redis Sentinel master name", + default=None, + ) + + CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + description="Redis Sentinel socket timeout", + default=0.1, + ) + @computed_field @property def CELERY_RESULT_BACKEND(self) -> str | None: @@ -184,6 +201,8 @@ class MiddlewareConfig( AzureBlobStorageConfig, GoogleCloudStorageConfig, TencentCloudCOSStorageConfig, + HuaweiCloudOBSStorageConfig, + VolcengineTOSStorageConfig, S3StorageConfig, OCIStorageConfig, # configs of vdb and vdb providers diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index cacdaf6fb6..4fcd52ddc9 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt from pydantic_settings import BaseSettings @@ -38,3 +38,33 @@ class RedisConfig(BaseSettings): description="whether to use SSL for Redis connection", default=False, ) + + REDIS_USE_SENTINEL: Optional[bool] = Field( + description="Whether to use Redis Sentinel mode", + default=False, + ) + + REDIS_SENTINELS: Optional[str] = Field( + description="Redis Sentinel nodes", + default=None, + ) + + REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field( + description="Redis Sentinel service name", + default=None, + ) + + REDIS_SENTINEL_USERNAME: Optional[str] = Field( + description="Redis Sentinel username", + default=None, + ) + + REDIS_SENTINEL_PASSWORD: Optional[str] = Field( + description="Redis Sentinel password", + default=None, + ) + + REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + description="Redis Sentinel socket timeout", + default=0.1, + ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 78f70b7ad3..c1843dc26c 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -38,3 +38,8 @@ class AliyunOSSStorageConfig(BaseSettings): description="Aliyun OSS authentication version", default=None, ) + + ALIYUN_OSS_PATH: Optional[str] = Field( + description="Aliyun OSS path", + default=None, + ) diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py new file mode 100644 index 0000000000..c5cb379cae --- /dev/null +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -0,0 +1,29 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class HuaweiCloudOBSStorageConfig(BaseModel): + """ + Huawei Cloud OBS storage configs + """ + + HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field( + description="Huawei Cloud OBS bucket name", + default=None, + ) + + HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field( + description="Huawei Cloud OBS Access key", + default=None, + ) + + HUAWEI_OBS_SECRET_KEY: Optional[str] = Field( + description="Huawei Cloud OBS Secret key", + default=None, + ) + + HUAWEI_OBS_SERVER: Optional[str] = Field( + description="Huawei Cloud OBS server URL", + default=None, + ) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py new file mode 100644 
index 0000000000..a0e09a3cc7 --- /dev/null +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -0,0 +1,34 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VolcengineTOSStorageConfig(BaseModel): + """ + Volcengine tos storage configs + """ + + VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field( + description="Volcengine TOS Bucket Name", + default=None, + ) + + VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field( + description="Volcengine TOS Access Key", + default=None, + ) + + VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field( + description="Volcengine TOS Secret Key", + default=None, + ) + + VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field( + description="Volcengine TOS Endpoint URL", + default=None, + ) + + VOLCENGINE_TOS_REGION: Optional[str] = Field( + description="Volcengine TOS Region", + default=None, + ) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 85466cd5cc..98d375966a 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import Field, PositiveInt +from pydantic import Field from pydantic_settings import BaseSettings @@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings): Milvus configs """ - MILVUS_HOST: Optional[str] = Field( - description="Milvus host", - default=None, + MILVUS_URI: Optional[str] = Field( + description="Milvus uri", + default="http://127.0.0.1:19530", ) - MILVUS_PORT: PositiveInt = Field( - description="Milvus RestFul API port", - default=9091, + MILVUS_TOKEN: Optional[str] = Field( + description="Milvus token", + default=None, ) MILVUS_USER: Optional[str] = Field( @@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings): default=None, ) - MILVUS_SECURE: bool = Field( - description="whether to use SSL connection for Milvus", - default=False, - ) - MILVUS_DATABASE: str = Field( description="Milvus database, default to `default`", default="default", diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index dd09671612..e03dfeb27c 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.7.2", + default="0.8.0", ) COMMIT_SHA: str = Field( diff --git a/api/constants/recommended_apps.json b/api/constants/recommended_apps.json index df4adc4a1f..3779fb0180 100644 --- a/api/constants/recommended_apps.json +++ b/api/constants/recommended_apps.json @@ -320,7 +320,7 @@ "icon_background": "#FFEAD5", "id": "e9870913-dd01-4710-9f06-15d4180ca1ce", "mode": "advanced-chat", - "name": "Knowledge Retreival + Chatbot " + "name": "Knowledge Retrieval + Chatbot " }, "app_id": "e9870913-dd01-4710-9f06-15d4180ca1ce", "category": "Workflow", @@ -423,7 +423,7 @@ "name": "Website Generator" }, "a23b57fa-85da-49c0-a571-3aff375976c1": { - "export_data": "app:\n icon: \"\\U0001F911\"\n icon_background: '#E4FBCC'\n mode: agent-chat\n name: Investment Analysis Report Copilot\nmodel_config:\n agent_mode:\n enabled: true\n max_iteration: 5\n strategy: function_call\n tools:\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Analytics\n tool_name: yahoo_finance_analytics\n tool_parameters:\n end_date: ''\n start_date: ''\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n 
provider_name: yahoo\n provider_type: builtin\n tool_label: News\n tool_name: yahoo_finance_news\n tool_parameters:\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Ticker\n tool_name: yahoo_finance_ticker\n tool_parameters:\n symbol: ''\n annotation_reply:\n enabled: false\n chat_prompt_config: {}\n completion_prompt_config: {}\n dataset_configs:\n datasets:\n datasets: []\n retrieval_model: single\n dataset_query_variable: ''\n external_data_tools: []\n file_upload:\n image:\n detail: high\n enabled: false\n number_limits: 3\n transfer_methods:\n - remote_url\n - local_file\n model:\n completion_params:\n frequency_penalty: 0.5\n max_tokens: 4096\n presence_penalty: 0.5\n stop: []\n temperature: 0.2\n top_p: 0.75\n mode: chat\n name: gpt-4-1106-preview\n provider: openai\n more_like_this:\n enabled: false\n opening_statement: 'Welcome to your personalized Investment Analysis Copilot service,\n where we delve into the depths of stock analysis to provide you with comprehensive\n insights. To begin our journey into the financial world, try to ask:\n\n '\n pre_prompt: \"# Job Description: Data Analysis Copilot\\n## Character\\nMy primary\\\n \\ goal is to provide user with expert data analysis advice. Using extensive and\\\n \\ detailed data. Tell me the stock (with ticket symbol) you want to analyze. I\\\n \\ will do all fundemental, technical, market sentiment, and Marcoeconomical analysis\\\n \\ for the stock as an expert. \\n\\n## Skills \\n### Skill 1: Search for stock information\\\n \\ using 'Ticker' from Yahoo Finance \\n### Skill 2: Search for recent news using\\\n \\ 'News' for the target company. \\n### Skill 3: Search for financial figures and\\\n \\ analytics using 'Analytics' for the target company\\n\\n## Workflow\\nAsks the\\\n \\ user which stocks with ticker name need to be analyzed and then performs the\\\n \\ following analysis in sequence. \\n**Part I: Fundamental analysis: financial\\\n \\ reporting analysis\\n*Objective 1: In-depth analysis of the financial situation\\\n \\ of the target company.\\n*Steps:\\n1. Identify the object of analysis:\\n\\n\\n\\n2. Access to financial\\\n \\ reports \\n\\n- Obtain the key data\\\n \\ of the latest financial report of the target company {{company}} organized by\\\n \\ Yahoo Finance. \\n\\n\\n\\n3. Vertical Analysis:\\n- Get the insight of the company's\\\n \\ balance sheet Income Statement and cash flow. \\n- Analyze Income Statement:\\\n \\ Analyze the proportion of each type of income and expense to total income. /Analyze\\\n \\ Balance Sheet: Analyze the proportion of each asset and liability to total assets\\\n \\ or total liabilities./ Analyze Cash Flow \\n-\\n4. Ratio Analysis:\\n\\\n - analyze the Profitability Ratios Solvency Ratios Operational Efficiency Ratios\\\n \\ and Market Performance Ratios of the company. \\n(Profitability Ratios: Such\\\n \\ as net profit margin gross profit margin operating profit margin to assess the\\\n \\ company's profitability.)\\n(Solvency Ratios: Such as debt-to-asset ratio interest\\\n \\ coverage ratio to assess the company's ability to pay its debts.)\\n(Operational\\\n \\ Efficiency Ratios: Such as inventory turnover accounts receivable turnover to\\\n \\ assess the company's operational efficiency.)\\n(Market Performance Ratios: Such\\\n \\ as price-to-earnings ratio price-to-book ratio to assess the company's market\\\n \\ performance.)>\\n-\\n5. 
Comprehensive Analysis and Conclusion:\\n- Combine the above analyses to\\\n \\ evaluate the company's financial health profitability solvency and operational\\\n \\ efficiency comprehensively. Identify the main financial risks and potential\\\n \\ opportunities facing the company.\\n-\\nOrganize and output [Record 1.1] [Record 1.2] [Record\\\n \\ 1.3] [Record 1.4] [Record 1.5] \\nPart II: Foundamental Analysis: Industry\\n\\\n *Objective 2: To analyze the position and competitiveness of the target company\\\n \\ {{company}} in the industry. \\n\\n\\n* Steps:\\n1. Determine the industry classification:\\n\\\n - Define the industry to which the target company belongs.\\n- Search for company\\\n \\ information to determine its main business and industry.\\n-\\n2. Market Positioning and Segmentation\\\n \\ analysis:\\n- To assess the company's market positioning and segmentation. \\n\\\n - Understand the company's market share growth rate and competitors in the industry\\\n \\ to analyze them. \\n-\\n3. Analysis \\n- Analyze the development\\\n \\ trend of the industry. \\n- \\n4. Competitors\\n- Analyze the competition around the target company \\n-\\\n \\ \\nOrganize\\\n \\ and output [Record 2.1] [Record 2.2] [Record 2.3] [Record 2.4]\\nCombine the\\\n \\ above Record and output all the analysis in the form of a investment analysis\\\n \\ report. Use markdown syntax for a structured output. \\n\\n## Constraints\\n- Your\\\n \\ responses should be strictly on analysis tasks. Use a structured language and\\\n \\ think step by step. \\n- The language you use should be identical to the user's\\\n \\ language.\\n- Avoid addressing questions regarding work tools and regulations.\\n\\\n - Give a structured response using bullet points and markdown syntax. Give an\\\n \\ introduction to the situation first then analyse the main trend in the graph.\\\n \\ \\n\"\n prompt_type: simple\n retriever_resource:\n enabled: true\n sensitive_word_avoidance:\n configs: []\n enabled: false\n type: ''\n speech_to_text:\n enabled: false\n suggested_questions:\n - 'Analyze the stock of Tesla. '\n - What are some recent development on Nvidia?\n - 'Do a fundamental analysis for Amazon. 
'\n suggested_questions_after_answer:\n enabled: true\n text_to_speech:\n enabled: false\n user_input_form:\n - text-input:\n default: ''\n label: company\n required: false\n variable: company\n", + "export_data": "app:\n icon: \"\\U0001F911\"\n icon_background: '#E4FBCC'\n mode: agent-chat\n name: Investment Analysis Report Copilot\nmodel_config:\n agent_mode:\n enabled: true\n max_iteration: 5\n strategy: function_call\n tools:\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Analytics\n tool_name: yahoo_finance_analytics\n tool_parameters:\n end_date: ''\n start_date: ''\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: News\n tool_name: yahoo_finance_news\n tool_parameters:\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Ticker\n tool_name: yahoo_finance_ticker\n tool_parameters:\n symbol: ''\n annotation_reply:\n enabled: false\n chat_prompt_config: {}\n completion_prompt_config: {}\n dataset_configs:\n datasets:\n datasets: []\n retrieval_model: single\n dataset_query_variable: ''\n external_data_tools: []\n file_upload:\n image:\n detail: high\n enabled: false\n number_limits: 3\n transfer_methods:\n - remote_url\n - local_file\n model:\n completion_params:\n frequency_penalty: 0.5\n max_tokens: 4096\n presence_penalty: 0.5\n stop: []\n temperature: 0.2\n top_p: 0.75\n mode: chat\n name: gpt-4-1106-preview\n provider: openai\n more_like_this:\n enabled: false\n opening_statement: 'Welcome to your personalized Investment Analysis Copilot service,\n where we delve into the depths of stock analysis to provide you with comprehensive\n insights. To begin our journey into the financial world, try to ask:\n\n '\n pre_prompt: \"# Job Description: Data Analysis Copilot\\n## Character\\nMy primary\\\n \\ goal is to provide user with expert data analysis advice. Using extensive and\\\n \\ detailed data. Tell me the stock (with ticket symbol) you want to analyze. I\\\n \\ will do all fundamental, technical, market sentiment, and Marco economical analysis\\\n \\ for the stock as an expert. \\n\\n## Skills \\n### Skill 1: Search for stock information\\\n \\ using 'Ticker' from Yahoo Finance \\n### Skill 2: Search for recent news using\\\n \\ 'News' for the target company. \\n### Skill 3: Search for financial figures and\\\n \\ analytics using 'Analytics' for the target company\\n\\n## Workflow\\nAsks the\\\n \\ user which stocks with ticker name need to be analyzed and then performs the\\\n \\ following analysis in sequence. \\n**Part I: Fundamental analysis: financial\\\n \\ reporting analysis\\n*Objective 1: In-depth analysis of the financial situation\\\n \\ of the target company.\\n*Steps:\\n1. Identify the object of analysis:\\n\\n\\n\\n2. Access to financial\\\n \\ reports \\n\\n- Obtain the key data\\\n \\ of the latest financial report of the target company {{company}} organized by\\\n \\ Yahoo Finance. \\n\\n\\n\\n3. Vertical Analysis:\\n- Get the insight of the company's\\\n \\ balance sheet Income Statement and cash flow. \\n- Analyze Income Statement:\\\n \\ Analyze the proportion of each type of income and expense to total income. /Analyze\\\n \\ Balance Sheet: Analyze the proportion of each asset and liability to total assets\\\n \\ or total liabilities./ Analyze Cash Flow \\n-\\n4. 
Ratio Analysis:\\n\\\n - analyze the Profitability Ratios Solvency Ratios Operational Efficiency Ratios\\\n \\ and Market Performance Ratios of the company. \\n(Profitability Ratios: Such\\\n \\ as net profit margin gross profit margin operating profit margin to assess the\\\n \\ company's profitability.)\\n(Solvency Ratios: Such as debt-to-asset ratio interest\\\n \\ coverage ratio to assess the company's ability to pay its debts.)\\n(Operational\\\n \\ Efficiency Ratios: Such as inventory turnover accounts receivable turnover to\\\n \\ assess the company's operational efficiency.)\\n(Market Performance Ratios: Such\\\n \\ as price-to-earnings ratio price-to-book ratio to assess the company's market\\\n \\ performance.)>\\n-\\n5. Comprehensive Analysis and Conclusion:\\n- Combine the above analyses to\\\n \\ evaluate the company's financial health profitability solvency and operational\\\n \\ efficiency comprehensively. Identify the main financial risks and potential\\\n \\ opportunities facing the company.\\n-\\nOrganize and output [Record 1.1] [Record 1.2] [Record\\\n \\ 1.3] [Record 1.4] [Record 1.5] \\nPart II: Fundamental Analysis: Industry\\n\\\n *Objective 2: To analyze the position and competitiveness of the target company\\\n \\ {{company}} in the industry. \\n\\n\\n* Steps:\\n1. Determine the industry classification:\\n\\\n - Define the industry to which the target company belongs.\\n- Search for company\\\n \\ information to determine its main business and industry.\\n-\\n2. Market Positioning and Segmentation\\\n \\ analysis:\\n- To assess the company's market positioning and segmentation. \\n\\\n - Understand the company's market share growth rate and competitors in the industry\\\n \\ to analyze them. \\n-\\n3. Analysis \\n- Analyze the development\\\n \\ trend of the industry. \\n- \\n4. Competitors\\n- Analyze the competition around the target company \\n-\\\n \\ \\nOrganize\\\n \\ and output [Record 2.1] [Record 2.2] [Record 2.3] [Record 2.4]\\nCombine the\\\n \\ above Record and output all the analysis in the form of a investment analysis\\\n \\ report. Use markdown syntax for a structured output. \\n\\n## Constraints\\n- Your\\\n \\ responses should be strictly on analysis tasks. Use a structured language and\\\n \\ think step by step. \\n- The language you use should be identical to the user's\\\n \\ language.\\n- Avoid addressing questions regarding work tools and regulations.\\n\\\n - Give a structured response using bullet points and markdown syntax. Give an\\\n \\ introduction to the situation first then analyse the main trend in the graph.\\\n \\ \\n\"\n prompt_type: simple\n retriever_resource:\n enabled: true\n sensitive_word_avoidance:\n configs: []\n enabled: false\n type: ''\n speech_to_text:\n enabled: false\n suggested_questions:\n - 'Analyze the stock of Tesla. '\n - What are some recent development on Nvidia?\n - 'Do a fundamental analysis for Amazon. 
'\n suggested_questions_after_answer:\n enabled: true\n text_to_speech:\n enabled: false\n user_input_form:\n - text-input:\n default: ''\n label: company\n required: false\n variable: company\n", "icon": "🤑", "icon_background": "#E4FBCC", "id": "a23b57fa-85da-49c0-a571-3aff375976c1", @@ -438,8 +438,8 @@ "mode": "advanced-chat", "name": "Workflow Planning Assistant " }, - "e9d92058-7d20-4904-892f-75d90bef7587":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Automated Email Reply '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1716909112104-source-1716909114582-target\n source: '1716909112104'\n sourceHandle: source\n target: '1716909114582'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1716909114582-source-1716913435742-target\n source: '1716909114582'\n sourceHandle: source\n target: '1716913435742'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: answer\n id: 1716913435742-source-1716806267180-target\n source: '1716913435742'\n sourceHandle: source\n target: '1716806267180'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: start\n targetType: tool\n id: 1716800588219-source-1716946869294-target\n source: '1716800588219'\n sourceHandle: source\n target: '1716946869294'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: tool\n targetType: code\n id: 1716946869294-source-1716909112104-target\n source: '1716946869294'\n sourceHandle: source\n target: '1716909112104'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: tool\n targetType: code\n id: 1716946889408-source-1716909122343-target\n source: '1716946889408'\n sourceHandle: source\n target: '1716909122343'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: code\n id: 1716909122343-source-1716951357236-target\n source: '1716909122343'\n sourceHandle: source\n target: '1716951357236'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: llm\n id: 1716951357236-source-1716913272656-target\n source: '1716951357236'\n sourceHandle: source\n target: '1716913272656'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: llm\n id: 1716951236700-source-1716951159073-target\n source: '1716951236700'\n sourceHandle: source\n target: '1716951159073'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: template-transform\n id: 1716951159073-source-1716952228079-target\n source: '1716951159073'\n sourceHandle: source\n target: '1716952228079'\n targetHandle: 
target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: tool\n id: 1716952228079-source-1716952912103-target\n source: '1716952228079'\n sourceHandle: source\n target: '1716952912103'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: question-classifier\n id: 1716913272656-source-1716960721611-target\n source: '1716913272656'\n sourceHandle: source\n target: '1716960721611'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-1-1716909125498-target\n source: '1716960721611'\n sourceHandle: '1'\n target: '1716909125498'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-2-1716960728136-target\n source: '1716960721611'\n sourceHandle: '2'\n target: '1716960728136'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716909125498-source-1716960791399-target\n source: '1716909125498'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: variable-aggregator\n targetType: template-transform\n id: 1716960791399-source-1716951236700-target\n source: '1716960791399'\n sourceHandle: source\n target: '1716951236700'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: template-transform\n id: 1716960721611-1716960736883-1716960834468-target\n source: '1716960721611'\n sourceHandle: '1716960736883'\n target: '1716960834468'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716960728136-source-1716960791399-target\n source: '1716960728136'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: variable-aggregator\n id: 1716960834468-source-1716960791399-target\n source: '1716960834468'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Your Email\n max_length: 256\n options: []\n required: true\n type: text-input\n variable: email\n - label: Maximum Number of Email you want to retrieve\n max_length: 256\n options: []\n required: true\n type: number\n variable: maxResults\n height: 115\n id: '1716800588219'\n position:\n x: 30\n y: 445\n positionAbsolute:\n x: 30\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n answer: '{{#1716913435742.output#}}'\n desc: ''\n selected: false\n title: Direct Reply\n type: answer\n variables: []\n height: 106\n id: '1716806267180'\n position:\n x: 4700\n y: 445\n positionAbsolute:\n x: 4700\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n 
type: custom\n width: 244\n - data:\n code: \"def main(message: str) -> dict:\\n import json\\n \\n # Parse\\\n \\ the JSON string\\n parsed_data = json.loads(message)\\n \\n # Extract\\\n \\ all the \\\"id\\\" values\\n ids = [msg['id'] for msg in parsed_data['messages']]\\n\\\n \\ \\n return {\\n \\\"result\\\": ids\\n }\"\n code_language: python3\n desc: ''\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: 'Code: Extract Email ID'\n type: code\n variables:\n - value_selector:\n - '1716946869294'\n - text\n variable: message\n height: 53\n id: '1716909112104'\n position:\n x: 638\n y: 445\n positionAbsolute:\n x: 638\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n height: 490\n iterator_selector:\n - '1716909112104'\n - result\n output_selector:\n - '1716909125498'\n - text\n output_type: array[string]\n selected: false\n startNodeType: tool\n start_node_id: '1716946889408'\n title: 'Iteraction '\n type: iteration\n width: 3393.7520359289056\n height: 490\n id: '1716909114582'\n position:\n x: 942\n y: 445\n positionAbsolute:\n x: 942\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 3394\n zIndex: 1\n - data:\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1716909114582'\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: getMessage\n tool_configurations: {}\n tool_label: getMessage\n tool_name: getMessage\n tool_parameters:\n format:\n type: mixed\n value: full\n id:\n type: mixed\n value: '{{#1716909114582.item#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716946889408'\n parentId: '1716909114582'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 1059\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n code: \"\\ndef main(email_json: dict) -> dict:\\n import json \\n email_dict\\\n \\ = json.loads(email_json)\\n base64_data = email_dict['payload']['parts'][0]['body']['data']\\n\\\n \\n return {\\n \\\"result\\\": base64_data, \\n }\\n\"\n code_language: python3\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: string\n selected: false\n title: 'Code: Extract Email Body'\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 53\n id: '1716909122343'\n parentId: '1716909114582'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1363\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Generate reply. '\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 982014aa-702b-4d7c-ae1f-08dbceb6e930\n role: system\n text: \" \\nRespond to the emails. 
\\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 127\n id: '1716909125498'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 85\n positionAbsolute:\n x: 2567\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: fd8de569-c099-4320-955b-61aa4b054789\n role: system\n text: \"\\nYou need to transform the input data (in base64 encoding)\\\n \\ to text. Input base64. Output text. \\n\\n{{#1716909122343.result#}}\\n\\\n \"\n selected: false\n title: 'Base64 Decoder '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: false\n extent: parent\n height: 97\n id: '1716913272656'\n parentId: '1716909114582'\n position:\n x: 1025\n y: 85\n positionAbsolute:\n x: 1967\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 | join(\"\\n\\n -------------------------\\n\\n\") }}'\n title: 'Template '\n type: template-transform\n variables:\n - value_selector:\n - '1716909114582'\n - output\n variable: arg1\n height: 53\n id: '1716913435742'\n position:\n x: 4396\n y: 445\n positionAbsolute:\n x: 4396\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: listMessages\n tool_configurations: {}\n tool_label: listMessages\n tool_name: listMessages\n tool_parameters:\n maxResults:\n type: variable\n value:\n - '1716800588219'\n - maxResults\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n height: 53\n id: '1716946869294'\n position:\n x: 334\n y: 445\n positionAbsolute:\n x: 334\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: b7fd0ec5-864a-42c6-9d04-a1958bd4fc0d\n role: system\n text: \"\\nYou need to encode the input data from text to base64. Input\\\n \\ text. Output base64 encoding. 
Output nothing other than base64 encoding.\\\n \\ \\n\\n{{#1716951236700.output#}}\\n \"\n selected: false\n title: Base64 Encoder\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1716951159073'\n parentId: '1716909114582'\n position:\n x: 2525.7520359289056\n y: 85\n positionAbsolute:\n x: 3467.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Generaate MIME email template\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: \"Content-Type: text/plain; charset=\\\"utf-8\\\"\\r\\nContent-Transfer-Encoding:\\\n \\ 7bit\\r\\nMIME-Version: 1.0\\r\\nTo: {{ emailMetadata.recipientEmail }} #\\\n \\ xiaoyi@dify.ai\\r\\nFrom: {{ emailMetadata.senderEmail }} # sxy.hj156@gmail.com\\r\\\n \\nSubject: Re: {{ emailMetadata.subject }} \\r\\n\\r\\n{{ text }}\\r\\n\"\n title: 'Template: Reply Email'\n type: template-transform\n variables:\n - value_selector:\n - '1716951357236'\n - result\n variable: emailMetadata\n - value_selector:\n - '1716960791399'\n - output\n variable: text\n extent: parent\n height: 83\n id: '1716951236700'\n parentId: '1716909114582'\n position:\n x: 2231.269960149744\n y: 85\n positionAbsolute:\n x: 3173.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n code: \"def main(email_json: dict) -> dict:\\n import json\\n if isinstance(email_json,\\\n \\ str): \\n email_json = json.loads(email_json)\\n\\n subject = None\\n\\\n \\ recipient_email = None \\n sender_email = None\\n \\n headers\\\n \\ = email_json['payload']['headers']\\n for header in headers:\\n \\\n \\ if header['name'] == 'Subject':\\n subject = header['value']\\n\\\n \\ elif header['name'] == 'To':\\n recipient_email = header['value']\\n\\\n \\ elif header['name'] == 'From':\\n sender_email = header['value']\\n\\\n \\n return {\\n \\\"result\\\": [subject, recipient_email, sender_email]\\n\\\n \\ }\\n\"\n code_language: python3\n desc: \"Recipient, Sender, Subject\\uFF0COutput Array[String]\"\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: Extract Email Metadata\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 101\n id: '1716951357236'\n parentId: '1716909114582'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1667\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: '{\"raw\": \"{{ encoded_message }}\"}'\n title: \"Template\\uFF1AEmail Request Body\"\n type: template-transform\n variables:\n - value_selector:\n - '1716951159073'\n - text\n variable: encoded_message\n extent: parent\n height: 53\n id: '1716952228079'\n parentId: '1716909114582'\n position:\n x: 2828.4325280181324\n y: 86.31950791077293\n positionAbsolute:\n x: 3770.4325280181324\n y: 531.3195079107729\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n provider_id: 038963aa-43c8-47fc-be4b-0255c19959c1\n provider_name: Draft Gmail\n provider_type: api\n selected: false\n title: createDraft\n 
tool_configurations: {}\n tool_label: createDraft\n tool_name: createDraft\n tool_parameters:\n message:\n type: mixed\n value: '{{#1716952228079.output#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716952912103'\n parentId: '1716909114582'\n position:\n x: 3133.7520359289056\n y: 85\n positionAbsolute:\n x: 4075.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n classes:\n - id: '1'\n name: 'Technical questions, related to product '\n - id: '2'\n name: Unrelated to technicals, non technical\n - id: '1716960736883'\n name: Other questions\n desc: ''\n instructions: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1716800588219'\n - sys.query\n selected: false\n title: Question Classifier\n topics: []\n type: question-classifier\n extent: parent\n height: 255\n id: '1716960721611'\n parentId: '1716909114582'\n position:\n x: 1325\n y: 85\n positionAbsolute:\n x: 2267\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - id: a639bbf8-bc58-42a2-b477-6748e80ecda2\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: 'LLM - Non technical '\n type: llm\n variables: []\n vision:\n enabled: false\n extent: parent\n height: 97\n id: '1716960728136'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 251\n positionAbsolute:\n x: 2567\n y: 696\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n output_type: string\n selected: false\n title: Variable Aggregator\n type: variable-aggregator\n variables:\n - - '1716909125498'\n - text\n - - '1716960728136'\n - text\n - - '1716960834468'\n - output\n extent: parent\n height: 164\n id: '1716960791399'\n parentId: '1716909114582'\n position:\n x: 1931.2699601497438\n y: 85\n positionAbsolute:\n x: 2873.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Other questions\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: 'Sorry, I cannot answer that. This is outside my capabilities. 
'\n title: 'Direct Reply '\n type: template-transform\n variables: []\n extent: parent\n height: 83\n id: '1716960834468'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 385.57142857142856\n positionAbsolute:\n x: 2567\n y: 830.5714285714286\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n author: Dify\n desc: ''\n height: 153\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 14px;\",\"text\":\"OpenAPI-Swagger for all custom tools: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"openapi:\n 3.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"info:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" title:\n Gmail API\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n OpenAPI schema for Gmail API methods `users.messages.get`, `users.messages.list`,\n and `users.drafts.create`.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" version:\n 1.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"servers:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n url: https://gmail.googleapis.com\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Gmail API Server\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"paths:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
/gmail/v1/users/{userId}/messages/{id}:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Get a message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Retrieves a specific message by ID.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n getMessage\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. 
The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: id\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the message to retrieve.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: format\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n false\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" enum:\n [full, metadata, minimal, raw]\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n full\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
description:\n The format to return the message in.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" internalDate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" payload:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" sizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n List messages.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Lists the messages in the user''s mailbox.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n listMessages\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: 
userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: maxResults\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" format:\n int32\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n 100\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Maximum number of messages to return.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" nextPageToken:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" resultSizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/drafts:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" post:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Creates a new draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n createDraft\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
tags:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n Drafts\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value \\\"me\\\" can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" requestBody:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The entire email message in an RFC 2822 formatted and base64url encoded\n string.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response with the created draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the thread the message belongs to.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n A short part of the message text.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the last history record that modified this message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''400'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Bad Request - The request is invalid.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized - Authentication is 
required.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden - The user does not have permission to create drafts.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found - The specified user does not exist.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''500'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Internal Server Error - An error occurred on the server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"components:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" securitySchemes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n oauth2\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" flows:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationCode:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationUrl:\n 
https://accounts.google.com/o/oauth2/auth\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" tokenUrl:\n https://oauth2.googleapis.com/token\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" scopes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://mail.google.com/:\n All access to Gmail.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.compose:\n Send email on your behalf.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.modify:\n Modify your email.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"security:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://mail.google.com/\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.compose\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.modify\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 367\n height: 153\n id: '1718992681576'\n position:\n x: 321.9646831030669\n y: 538.1642616264143\n positionAbsolute:\n x: 321.9646831030669\n y: 538.1642616264143\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 367\n - data:\n author: Dify\n desc: ''\n height: 158\n selected: false\n showAuthor: true\n text: 
'{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Replace\n custom tools after added this template to your own workspace. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Fill\n in \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"your\n email \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"and\n the \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"maximum\n number of results you want to retrieve from your inbox \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"to\n get started. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 287\n height: 158\n id: '1718992805687'\n position:\n x: 18.571428571428356\n y: 237.80887395992687\n positionAbsolute:\n x: 18.571428571428356\n y: 237.80887395992687\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 287\n - data:\n author: Dify\n desc: ''\n height: 375\n selected: true\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Steps within Iteraction node: \",\"type\":\"text\",\"version\":1},{\"type\":\"linebreak\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"1.\n getMessage: This step retrieves the incoming email message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"2.\n Code: Extract Email Body: Custom code is executed to extract the body of\n the email from the retrieved message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"3.\n Extract Email Metadata: Extracts metadata from the email, such as the recipient,\n sender, subject, and other relevant information.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"4.\n Base64 Decoder: Decodes the email content from Base64 encoding.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"5.\n Question Classifier (gpt-3.5-turbo): Uses a GPT-3.5-turbo model to classify\n the 
email content into different categories. For each classified question,\n the workflow uses a GPT-4.0 model to generate an appropriate reply:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"6.\n Template: Reply Email: Uses a template to generate a MIME email format for\n the reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"7.\n Base64 Encoder: Encodes the generated reply email content back to Base64.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"8.\n Template: Email Request: Prepares the email request using a template.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"9.\n createDraft: Creates a draft of the email reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"This\n workflow automates the process of reading, classifying, responding to, and\n drafting replies to incoming emails, leveraging advanced language models\n to generate contextually appropriate responses.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 640\n height: 375\n id: '1718993366836'\n position:\n x: 966.7525290975368\n y: 971.80362905854\n positionAbsolute:\n x: 966.7525290975368\n y: 971.80362905854\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 640\n - data:\n author: Dify\n desc: ''\n height: 400\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Preparation\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Enable\n Gmail API in Google Cloud Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Configure\n OAuth Client ID, OAuth Client Secrets, and OAuth Consent Screen for the\n Web Application in Google Cloud
Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":2},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Use\n Postman to authorize and obtain the OAuth Access Token (Google''s Access\n Token will expire after 1 hour and cannot be used for a long time)\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":3}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"list\",\"version\":1,\"listType\":\"bullet\",\"start\":1,\"tag\":\"ul\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Users\n who want to try building an AI auto-reply email can refer to this document\n to use Postman (Postman.com) to obtain all the above keys: https://blog.postman.com/how-to-access-google-apis-using-oauth-in-postman/.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Developers\n who want to use Google OAuth to call the Gmail API to develop corresponding\n plugins can refer to this official document: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"https://developers.google.com/identity/protocols/oauth2/web-server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"At\n this stage, it is still a bit difficult to reproduce this example within\n the Dify platform. 
If you have development capabilities, developing the\n corresponding plugin externally and using an external database to automatically\n read and write the user''s Access Token and write the Refresh Token would\n be a better choice.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 608\n height: 400\n id: '1718993557447'\n position:\n x: 354.0157230378119\n y: -1.2732157979666\n positionAbsolute:\n x: 354.0157230378119\n y: -1.2732157979666\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 608\n viewport:\n x: 147.09446825757777\n y: 101.03530130020579\n zoom: 0.9548416039104178\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"e9d92058-7d20-4904-892f-75d90bef7587","mode":"advanced-chat","name":"Automated Email Reply "}, - "98b87f88-bd22-4d86-8b74-86beba5e0ed4":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Book Translation '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: start\n targetType: code\n id: 1711067409646-source-1717916867969-target\n source: '1711067409646'\n sourceHandle: source\n target: '1717916867969'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1717916867969-source-1717916955547-target\n source: '1717916867969'\n sourceHandle: source\n target: '1717916955547'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916961837-source-1717916977413-target\n source: '1717916961837'\n sourceHandle: source\n target: '1717916977413'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916977413-source-1717916984996-target\n source: '1717916977413'\n sourceHandle: source\n target: '1717916984996'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916984996-source-1717916991709-target\n source: '1717916984996'\n sourceHandle: source\n target: '1717916991709'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1717916955547-source-1717917057450-target\n source: '1717916955547'\n sourceHandle: source\n target: '1717917057450'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: end\n id: 1717917057450-source-1711068257370-target\n source: '1717917057450'\n sourceHandle: source\n target: '1711068257370'\n targetHandle: target\n type: custom\n zIndex: 0\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Input 
Text\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: input_text\n dragging: false\n height: 89\n id: '1711067409646'\n position:\n x: 30\n y: 301.5\n positionAbsolute:\n x: 30\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1717917057450'\n - output\n variable: final\n selected: false\n title: End\n type: end\n height: 89\n id: '1711068257370'\n position:\n x: 2291\n y: 301.5\n positionAbsolute:\n x: 2291\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"\\ndef main(input_text: str) -> str:\\n token_limit = 1000\\n overlap\\\n \\ = 100\\n chunk_size = int(token_limit * 6 * (4/3))\\n\\n # Initialize\\\n \\ variables\\n chunks = []\\n start_index = 0\\n text_length = len(input_text)\\n\\\n \\n # Loop until the end of the text is reached\\n while start_index\\\n \\ < text_length:\\n # If we are not at the beginning, adjust the start_index\\\n \\ to ensure overlap\\n if start_index > 0:\\n start_index\\\n \\ -= overlap\\n\\n # Calculate end index for the current chunk\\n \\\n \\ end_index = start_index + chunk_size\\n if end_index > text_length:\\n\\\n \\ end_index = text_length\\n\\n # Add the current chunk\\\n \\ to the list\\n chunks.append(input_text[start_index:end_index])\\n\\\n \\n # Update the start_index for the next chunk\\n start_index\\\n \\ += chunk_size\\n\\n return {\\n \\\"chunks\\\": chunks,\\n }\\n\"\n code_language: python3\n dependencies: []\n desc: 'token_limit = 1000\n\n overlap = 100'\n outputs:\n chunks:\n children: null\n type: array[string]\n selected: false\n title: Code\n type: code\n variables:\n - value_selector:\n - '1711067409646'\n - input_text\n variable: input_text\n height: 101\n id: '1717916867969'\n position:\n x: 336\n y: 301.5\n positionAbsolute:\n x: 336\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: 'Take good care on maximum number of iterations. '\n height: 203\n iterator_selector:\n - '1717916867969'\n - chunks\n output_selector:\n - '1717916991709'\n - text\n output_type: array[string]\n selected: false\n startNodeType: llm\n start_node_id: '1717916961837'\n title: Iteration\n type: iteration\n width: 1289\n height: 203\n id: '1717916955547'\n position:\n x: 638\n y: 301.5\n positionAbsolute:\n x: 638\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 1289\n zIndex: 1\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 7261280b-cb27-4f84-8363-b93e09246d16\n role: system\n text: \" Identify the technical terms in the users input. Use the following\\\n \\ format {XXX} -> {XXX} to show the corresponding technical terms before\\\n \\ and after translation. 
\\n\\n \\n{{#1717916955547.item#}}\\n\\\n \\n\\n| \\u82F1\\u6587 | \\u4E2D\\u6587 |\\n| --- | --- |\\n| Prompt\\\n \\ Engineering | \\u63D0\\u793A\\u8BCD\\u5DE5\\u7A0B |\\n| Text Generation \\_\\\n | \\u6587\\u672C\\u751F\\u6210 |\\n| Token \\_| Token |\\n| Prompt \\_| \\u63D0\\\n \\u793A\\u8BCD |\\n| Meta Prompting \\_| \\u5143\\u63D0\\u793A |\\n| diffusion\\\n \\ models \\_| \\u6269\\u6563\\u6A21\\u578B |\\n| Agent \\_| \\u667A\\u80FD\\u4F53\\\n \\ |\\n| Transformer \\_| Transformer |\\n| Zero Shot \\_| \\u96F6\\u6837\\u672C\\\n \\ |\\n| Few Shot \\_| \\u5C11\\u6837\\u672C |\\n| chat window \\_| \\u804A\\u5929\\\n \\ |\\n| context | \\u4E0A\\u4E0B\\u6587 |\\n| stock photo \\_| \\u56FE\\u5E93\\u7167\\\n \\u7247 |\\n\\n\\n \"\n selected: false\n title: 'Identify Terms '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916961837'\n parentId: '1717916955547'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 755\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 05e03f0d-c1a9-43ab-b4c0-44b55049434d\n role: system\n text: \" You are a professional translator proficient in Simplified\\\n \\ Chinese especially skilled in translating professional academic papers\\\n \\ into easy-to-understand popular science articles. Please help me translate\\\n \\ the following english paragraph into Chinese, in a style similar to\\\n \\ Chinese popular science articles .\\n \\nTranslate directly\\\n \\ based on the English content, maintain the original format and do not\\\n \\ omit any information. \\n \\n{{#1717916955547.item#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: 1st Translation\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916977413'\n parentId: '1717916955547'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1059\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 9e6cc050-465e-4632-abc9-411acb255a95\n role: system\n text: \"\\nBased on the results of the direct translation, point out\\\n \\ specific issues it have. 
Accurate descriptions are required, avoiding\\\n \\ vague statements, and there's no need to add content or formats that\\\n \\ were not present in the original text, including but not liimited to:\\\n \\ \\n- inconsistent with chinese expression habits, clearly indicate where\\\n \\ it does not conform\\n- Clumsy sentences, specify the location, no need\\\n \\ to offer suggestions for modification, which will be fixed during free\\\n \\ translation\\n- Obscure and difficult to understand, attempts to explain\\\n \\ may be made\\n- \\u65E0\\u6F0F\\u8BD1\\uFF08\\u539F\\u2F42\\u4E2D\\u7684\\u5173\\\n \\u952E\\u8BCD\\u3001\\u53E5\\u2F26\\u3001\\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\\n \\u5728\\u8BD1\\u2F42\\u4E2D\\uFF09\\u3002\\n- \\u2F46\\u9519\\u8BD1\\uFF08\\u770B\\\n \\u9519\\u539F\\u2F42\\u3001\\u8BEF\\u89E3\\u539F\\u2F42\\u610F\\u601D\\u5747\\u7B97\\\n \\u9519\\u8BD1\\uFF09\\u3002\\n- \\u2F46\\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\\n \\u5220\\u51CF\\u7684\\u539F\\u2F42\\u5185\\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u2FAE\\\n \\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\u4F5C\\u8005\\u89C2 \\u70B9\\uFF1B\\u53EF\\\n \\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n-\\\n \\ \\u8BD1\\u2F42\\u6D41\\u7545\\uFF0C\\u7B26\\u5408\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\\n \\u60EF\\u3002\\n- \\u5173\\u4E8E\\u2F08\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\\n \\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u2F08\\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\\n \\uFF0C\\u4F46\\u662F\\u2F00\\u4E9B\\u4F17\\u6240 \\u5468\\u77E5\\u7684\\u2F08\\u540D\\\n \\u9700\\u2F64\\u4E2D\\u2F42\\uFF08\\u5982\\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n-\\\n \\ \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6709\\u4E2D\\u2F42\\u7248\\\n \\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u2F64\\u4E2D\\u2F42\\u7248\\u4E66\\u540D\\uFF1B\\\n \\u2F46\\u4E2D\\u2F42\\u7248 \\u7684\\u56FE\\u4E66\\uFF0C\\u76F4\\u63A5\\u2F64\\u82F1\\\n \\u2F42\\u4E66\\u540D\\u3002\\n- \\u5173\\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\\n \\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\\n \\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\u3002\\u56FE\\u9898 \\u9700\\u8981\\u7FFB\\\n \\u8BD1\\u3002\\u754C\\u2FAF\\u622A\\u56FE\\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\\n \\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\\n \\u2F42 \\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\\n \\n- \\u5173\\u4E8E\\u82F1\\u2F42\\u672F\\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\\n \\u2F42\\u672F\\u8BED\\u2FB8\\u6B21\\u51FA\\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\\n \\u636E\\u8BE5\\u672F\\u8BED\\u7684 \\u6D41\\u2F8F\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\\n \\u4F7F\\u2F64\\u7B80\\u5199\\u5F62\\u5F0F\\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\\n \\u2F64\\u62EC\\u53F7\\u52A0\\u82F1\\u2F42\\u3001\\u4E2D\\u2F42 \\u5168\\u79F0\\u6CE8\\\n \\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08\\\n Hypertext Markup Language\\uFF0C\\u8D85\\u2F42\\u672C\\u6807\\u8BC6\\u8BED\\u2F94\\\n \\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\u2F42\\u4E2D\\u76F4\\u63A5\\u4F7F\\u2F64\\\n \\u7B80\\u5199\\u5F62 \\u5F0F\\u3002\\u5F53\\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\\n \\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\u4F7F\\u2F64\\u4E2D\\u3001\\u82F1\\u2F42\\\n \\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\\n \\u7801\\u2F5A\\u6BB5\\u3002\\u539F\\u4E66\\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\\n 
\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\u8005\\u5F55 \\u2F0A\\uFF0C\\u4F46\\u5E94\\\n \\u8BE5\\u4F7F\\u2F64\\u201C\\u539F\\u4E66P99\\u2EDA\\u4EE3\\u78011\\u201D\\uFF08\\\n \\u5373\\u539F\\u4E66\\u7B2C99\\u2EDA\\u4E2D\\u7684\\u7B2C\\u2F00\\u6BB5\\u4EE3 \\u7801\\\n \\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\\n \\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\\n \\u68C0\\u6838\\u4EE3 \\u7801\\u7684\\u6B63\\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\\n \\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\\n \\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\u6CE8 \\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\\n \\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\\n \\u5E94\\u8BE5\\u4EE5\\u2F00\\u53E5\\u82F1\\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09 \\u2F00\\\n \\u53E5\\u4E2D\\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\\n \\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\\n \\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u2F42\\\n \\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\u2F42\\u6807 \\u70B9\\u7B26\\u53F7\\u7684\\\n \\u4F7F\\u2F64\\u4E60\\u60EF\\uFF0C\\u4E0D\\u80FD\\u7167\\u642C\\u539F\\u2F42\\u7684\\\n \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\n{{#1717916977413.text#}}\\n\\\n \\n{{#1717916955547.item#}}\\n\\n{{#1717916961837.text#}}\\n\\\n \"\n selected: false\n title: 'Problems '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916984996'\n parentId: '1717916955547'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1363\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 4d7ae758-2d7b-4404-ad9f-d6748ee64439\n role: system\n text: \"\\nBased on the results of the direct translation in the first\\\n \\ step and the problems identified in the second step, re-translate to\\\n \\ achieve a meaning-based interpretation. Ensure the original intent of\\\n \\ the content is preserved while making it easier to understand and more\\\n \\ in line with Chinese expression habits. All the while maintaining the\\\n \\ original format unchanged. 
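The iteration body above chains four prompts per chunk: identify technical terms, translate literally, list the draft's problems, then re-translate freely. A minimal sketch of the same chain outside Dify, assuming the openai package with an OPENAI_API_KEY in the environment; the prompts are abbreviated stand-ins for the full node templates:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def call_llm(system_prompt: str) -> str:
    # Mirrors the workflow nodes: one system message, gpt-4o, temperature 0.7.
    resp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "system", "content": system_prompt}],
        temperature=0.7,
    )
    return resp.choices[0].message.content

def translate_chunk(chunk: str) -> str:
    terms = call_llm(f"Identify the technical terms in this text:\n{chunk}")
    direct = call_llm(f"Translate into Chinese, keeping the original format.\nTerms:\n{terms}\n\n{chunk}")
    problems = call_llm(f"List concrete problems with this draft translation.\nSource:\n{chunk}\n\nDraft:\n{direct}")
    return call_llm(f"Re-translate into natural Chinese, fixing these problems:\n{problems}\n\nSource:\n{chunk}\n\nDraft:\n{direct}")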
\\n\\n\\n- inconsistent with chinese\\\n \\ expression habits, clearly indicate where it does not conform\\n- Clumsy\\\n \\ sentences, specify the location, no need to offer suggestions for modification,\\\n \\ which will be fixed during free translation\\n- Obscure and difficult\\\n \\ to understand, attempts to explain may be made\\n- \\u65E0\\u6F0F\\u8BD1\\\n \\uFF08\\u539F\\u2F42\\u4E2D\\u7684\\u5173\\u952E\\u8BCD\\u3001\\u53E5\\u2F26\\u3001\\\n \\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\u5728\\u8BD1\\u2F42\\u4E2D\\uFF09\\u3002\\\n \\n- \\u2F46\\u9519\\u8BD1\\uFF08\\u770B\\u9519\\u539F\\u2F42\\u3001\\u8BEF\\u89E3\\\n \\u539F\\u2F42\\u610F\\u601D\\u5747\\u7B97\\u9519\\u8BD1\\uFF09\\u3002\\n- \\u2F46\\\n \\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\u5220\\u51CF\\u7684\\u539F\\u2F42\\u5185\\\n \\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u2FAE\\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\\n \\u4F5C\\u8005\\u89C2 \\u70B9\\uFF1B\\u53EF\\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\\n \\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n- \\u8BD1\\u2F42\\u6D41\\u7545\\uFF0C\\u7B26\\\n \\u5408\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\u60EF\\u3002\\n- \\u5173\\u4E8E\\u2F08\\\n \\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u2F08\\\n \\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\uFF0C\\u4F46\\u662F\\u2F00\\u4E9B\\u4F17\\\n \\u6240 \\u5468\\u77E5\\u7684\\u2F08\\u540D\\u9700\\u2F64\\u4E2D\\u2F42\\uFF08\\u5982\\\n \\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n- \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\\n \\u8BD1\\u3002\\u6709\\u4E2D\\u2F42\\u7248\\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u2F64\\\n \\u4E2D\\u2F42\\u7248\\u4E66\\u540D\\uFF1B\\u2F46\\u4E2D\\u2F42\\u7248 \\u7684\\u56FE\\\n \\u4E66\\uFF0C\\u76F4\\u63A5\\u2F64\\u82F1\\u2F42\\u4E66\\u540D\\u3002\\n- \\u5173\\\n \\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\\n \\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\\n \\u3002\\u56FE\\u9898 \\u9700\\u8981\\u7FFB\\u8BD1\\u3002\\u754C\\u2FAF\\u622A\\u56FE\\\n \\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\\n \\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\u2F42 \\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\\n \\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\n- \\u5173\\u4E8E\\u82F1\\u2F42\\u672F\\\n \\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\u2F42\\u672F\\u8BED\\u2FB8\\u6B21\\u51FA\\\n \\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\u636E\\u8BE5\\u672F\\u8BED\\u7684 \\u6D41\\\n \\u2F8F\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\u4F7F\\u2F64\\u7B80\\u5199\\u5F62\\u5F0F\\\n \\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\u2F64\\u62EC\\u53F7\\u52A0\\u82F1\\u2F42\\\n \\u3001\\u4E2D\\u2F42 \\u5168\\u79F0\\u6CE8\\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\\n \\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08Hypertext Markup Language\\uFF0C\\u8D85\\\n \\u2F42\\u672C\\u6807\\u8BC6\\u8BED\\u2F94\\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\\n \\u2F42\\u4E2D\\u76F4\\u63A5\\u4F7F\\u2F64\\u7B80\\u5199\\u5F62 \\u5F0F\\u3002\\u5F53\\\n \\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\\n \\u4F7F\\u2F64\\u4E2D\\u3001\\u82F1\\u2F42\\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\\n \\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\u7801\\u2F5A\\u6BB5\\u3002\\u539F\\u4E66\\\n \\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\\n \\u8005\\u5F55 \\u2F0A\\uFF0C\\u4F46\\u5E94\\u8BE5\\u4F7F\\u2F64\\u201C\\u539F\\u4E66\\\n P99\\u2EDA\\u4EE3\\u78011\\u201D\\uFF08\\u5373\\u539F\\u4E66\\u7B2C99\\u2EDA\\u4E2D\\\n 
\\u7684\\u7B2C\\u2F00\\u6BB5\\u4EE3 \\u7801\\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\\n \\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\\n \\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\u68C0\\u6838\\u4EE3 \\u7801\\u7684\\u6B63\\\n \\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\\n \\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\\n \\u6CE8 \\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\\n \\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\u5E94\\u8BE5\\u4EE5\\u2F00\\u53E5\\u82F1\\\n \\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09 \\u2F00\\u53E5\\u4E2D\\u2F42\\uFF08\\u6CE8\\u91CA\\\n \\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\\n \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\\n \\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\\n \\u2F42\\u6807 \\u70B9\\u7B26\\u53F7\\u7684\\u4F7F\\u2F64\\u4E60\\u60EF\\uFF0C\\u4E0D\\\n \\u80FD\\u7167\\u642C\\u539F\\u2F42\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\\n \\n{{#1717916977413.text#}}\\n\\n{{#1717916984996.text#}}\\n\\n{{#1711067409646.input_text#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: '2nd Translation '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916991709'\n parentId: '1717916955547'\n position:\n x: 1029\n y: 85\n positionAbsolute:\n x: 1667\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: 'Combine all chunks of translation. '\n selected: false\n template: '{{ translated_text | join('' '') }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1717916955547'\n - output\n variable: translated_text\n height: 83\n id: '1717917057450'\n position:\n x: 1987\n y: 301.5\n positionAbsolute:\n x: 1987\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n author: Dify\n desc: ''\n height: 186\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Code\n node separates the input_text into chunks with length of token_limit. Each\n chunk overlap with each other to make sure the texts are consistent. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n code node outputs an array of segmented texts of input_texts. 
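For readability, the chunking routine embedded in the Code node above is reproduced here in plain form; the chars-per-token heuristic (token_limit * 6 * 4/3, roughly 8 characters per token) and the fixed 100-character overlap are taken directly from the node:

def main(input_text: str) -> dict:
    token_limit = 1000
    overlap = 100
    chunk_size = int(token_limit * 6 * (4 / 3))

    chunks = []
    start_index = 0
    text_length = len(input_text)
    while start_index < text_length:
        # After the first chunk, step back so neighbouring chunks overlap
        # and no sentence is lost at a boundary.
        if start_index > 0:
            start_index -= overlap
        end_index = min(start_index + chunk_size, text_length)
        chunks.append(input_text[start_index:end_index])
        start_index += chunk_size
    return {"chunks": chunks}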
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 340\n height: 186\n id: '1718990593686'\n position:\n x: 259.3026056936437\n y: 451.6924912936374\n positionAbsolute:\n x: 259.3026056936437\n y: 451.6924912936374\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 340\n - data:\n author: Dify\n desc: ''\n height: 128\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Iterate\n through all the elements in output of the code node and translate each chunk\n using a three steps translation workflow. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 355\n height: 128\n id: '1718991836605'\n position:\n x: 764.3891977435923\n y: 530.8917807505335\n positionAbsolute:\n x: 764.3891977435923\n y: 530.8917807505335\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 355\n - data:\n author: Dify\n desc: ''\n height: 126\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Avoid\n using a high token_limit, LLM''s performance decreases with longer context\n length for gpt-4o. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Recommend\n to use less than or equal to 1000 tokens. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 351\n height: 126\n id: '1718991882984'\n position:\n x: 304.49115824454367\n y: 148.4042994607805\n positionAbsolute:\n x: 304.49115824454367\n y: 148.4042994607805\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 351\n viewport:\n x: 335.92505067152274\n y: 18.806553508850584\n zoom: 0.8705505632961259\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"98b87f88-bd22-4d86-8b74-86beba5e0ed4","mode":"workflow","name":"Book Translation "}, + "e9d92058-7d20-4904-892f-75d90bef7587":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Automated Email Reply '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1716909112104-source-1716909114582-target\n source: '1716909112104'\n sourceHandle: source\n target: '1716909114582'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1716909114582-source-1716913435742-target\n source: '1716909114582'\n sourceHandle: source\n target: '1716913435742'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: answer\n id: 1716913435742-source-1716806267180-target\n source: '1716913435742'\n sourceHandle: source\n target: '1716806267180'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: start\n targetType: tool\n id: 1716800588219-source-1716946869294-target\n source: '1716800588219'\n sourceHandle: source\n target: '1716946869294'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: tool\n targetType: code\n id: 1716946869294-source-1716909112104-target\n source: '1716946869294'\n sourceHandle: source\n target: '1716909112104'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: tool\n targetType: code\n id: 1716946889408-source-1716909122343-target\n source: '1716946889408'\n sourceHandle: source\n target: '1716909122343'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: code\n id: 1716909122343-source-1716951357236-target\n source: '1716909122343'\n sourceHandle: source\n target: '1716951357236'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: llm\n id: 1716951357236-source-1716913272656-target\n source: '1716951357236'\n sourceHandle: source\n target: '1716913272656'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: llm\n id: 
1716951236700-source-1716951159073-target\n source: '1716951236700'\n sourceHandle: source\n target: '1716951159073'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: template-transform\n id: 1716951159073-source-1716952228079-target\n source: '1716951159073'\n sourceHandle: source\n target: '1716952228079'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: tool\n id: 1716952228079-source-1716952912103-target\n source: '1716952228079'\n sourceHandle: source\n target: '1716952912103'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: question-classifier\n id: 1716913272656-source-1716960721611-target\n source: '1716913272656'\n sourceHandle: source\n target: '1716960721611'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-1-1716909125498-target\n source: '1716960721611'\n sourceHandle: '1'\n target: '1716909125498'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-2-1716960728136-target\n source: '1716960721611'\n sourceHandle: '2'\n target: '1716960728136'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716909125498-source-1716960791399-target\n source: '1716909125498'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: variable-aggregator\n targetType: template-transform\n id: 1716960791399-source-1716951236700-target\n source: '1716960791399'\n sourceHandle: source\n target: '1716951236700'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: template-transform\n id: 1716960721611-1716960736883-1716960834468-target\n source: '1716960721611'\n sourceHandle: '1716960736883'\n target: '1716960834468'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716960728136-source-1716960791399-target\n source: '1716960728136'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: variable-aggregator\n id: 1716960834468-source-1716960791399-target\n source: '1716960834468'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Your Email\n max_length: 256\n options: []\n required: true\n type: text-input\n variable: email\n - label: Maximum Number of Email you want to retrieve\n max_length: 256\n options: []\n required: true\n type: number\n variable: maxResults\n height: 115\n id: '1716800588219'\n position:\n x: 30\n y: 445\n 
positionAbsolute:\n x: 30\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n answer: '{{#1716913435742.output#}}'\n desc: ''\n selected: false\n title: Direct Reply\n type: answer\n variables: []\n height: 106\n id: '1716806267180'\n position:\n x: 4700\n y: 445\n positionAbsolute:\n x: 4700\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"def main(message: str) -> dict:\\n import json\\n \\n # Parse\\\n \\ the JSON string\\n parsed_data = json.loads(message)\\n \\n # Extract\\\n \\ all the \\\"id\\\" values\\n ids = [msg['id'] for msg in parsed_data['messages']]\\n\\\n \\ \\n return {\\n \\\"result\\\": ids\\n }\"\n code_language: python3\n desc: ''\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: 'Code: Extract Email ID'\n type: code\n variables:\n - value_selector:\n - '1716946869294'\n - text\n variable: message\n height: 53\n id: '1716909112104'\n position:\n x: 638\n y: 445\n positionAbsolute:\n x: 638\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n height: 490\n iterator_selector:\n - '1716909112104'\n - result\n output_selector:\n - '1716909125498'\n - text\n output_type: array[string]\n selected: false\n startNodeType: tool\n start_node_id: '1716946889408'\n title: 'Iteraction '\n type: iteration\n width: 3393.7520359289056\n height: 490\n id: '1716909114582'\n position:\n x: 942\n y: 445\n positionAbsolute:\n x: 942\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 3394\n zIndex: 1\n - data:\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1716909114582'\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: getMessage\n tool_configurations: {}\n tool_label: getMessage\n tool_name: getMessage\n tool_parameters:\n format:\n type: mixed\n value: full\n id:\n type: mixed\n value: '{{#1716909114582.item#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716946889408'\n parentId: '1716909114582'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 1059\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n code: \"\\ndef main(email_json: dict) -> dict:\\n import json \\n email_dict\\\n \\ = json.loads(email_json)\\n base64_data = email_dict['payload']['parts'][0]['body']['data']\\n\\\n \\n return {\\n \\\"result\\\": base64_data, \\n }\\n\"\n code_language: python3\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: string\n selected: false\n title: 'Code: Extract Email Body'\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 53\n id: '1716909122343'\n parentId: '1716909114582'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1363\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Generate reply. 
'\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 982014aa-702b-4d7c-ae1f-08dbceb6e930\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 127\n id: '1716909125498'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 85\n positionAbsolute:\n x: 2567\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: fd8de569-c099-4320-955b-61aa4b054789\n role: system\n text: \"\\nYou need to transform the input data (in base64 encoding)\\\n \\ to text. Input base64. Output text. \\n\\n{{#1716909122343.result#}}\\n\\\n \"\n selected: false\n title: 'Base64 Decoder '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: false\n extent: parent\n height: 97\n id: '1716913272656'\n parentId: '1716909114582'\n position:\n x: 1025\n y: 85\n positionAbsolute:\n x: 1967\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 | join(\"\\n\\n -------------------------\\n\\n\") }}'\n title: 'Template '\n type: template-transform\n variables:\n - value_selector:\n - '1716909114582'\n - output\n variable: arg1\n height: 53\n id: '1716913435742'\n position:\n x: 4396\n y: 445\n positionAbsolute:\n x: 4396\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: listMessages\n tool_configurations: {}\n tool_label: listMessages\n tool_name: listMessages\n tool_parameters:\n maxResults:\n type: variable\n value:\n - '1716800588219'\n - maxResults\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n height: 53\n id: '1716946869294'\n position:\n x: 334\n y: 445\n positionAbsolute:\n x: 334\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: b7fd0ec5-864a-42c6-9d04-a1958bd4fc0d\n role: system\n text: \"\\nYou need to encode the input data from text to base64. Input\\\n \\ text. Output base64 encoding. 
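Both base64 nodes above delegate a deterministic transform to an LLM; a Code node would do the same work reliably and without token cost. A sketch of the two transforms, noting that the Gmail API uses the URL-safe base64 alphabet for message body data and for the raw field:

import base64

def decode_gmail_body(data: str) -> str:
    # Body data arrives base64url-encoded and sometimes unpadded; restore padding first.
    padded = data + "=" * (-len(data) % 4)
    return base64.urlsafe_b64decode(padded).decode("utf-8")

def encode_raw_message(text: str) -> str:
    # Inverse transform, producing the value expected under "raw".
    return base64.urlsafe_b64encode(text.encode("utf-8")).decode("ascii")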
Output nothing other than base64 encoding.\\\n \\ \\n\\n{{#1716951236700.output#}}\\n \"\n selected: false\n title: Base64 Encoder\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1716951159073'\n parentId: '1716909114582'\n position:\n x: 2525.7520359289056\n y: 85\n positionAbsolute:\n x: 3467.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Generate MIME email template\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: \"Content-Type: text/plain; charset=\\\"utf-8\\\"\\r\\nContent-Transfer-Encoding:\\\n \\ 7bit\\r\\nMIME-Version: 1.0\\r\\nTo: {{ emailMetadata.recipientEmail }} #\\\n \\ xiaoyi@dify.ai\\r\\nFrom: {{ emailMetadata.senderEmail }} # sxy.hj156@gmail.com\\r\\\n \\nSubject: Re: {{ emailMetadata.subject }} \\r\\n\\r\\n{{ text }}\\r\\n\"\n title: 'Template: Reply Email'\n type: template-transform\n variables:\n - value_selector:\n - '1716951357236'\n - result\n variable: emailMetadata\n - value_selector:\n - '1716960791399'\n - output\n variable: text\n extent: parent\n height: 83\n id: '1716951236700'\n parentId: '1716909114582'\n position:\n x: 2231.269960149744\n y: 85\n positionAbsolute:\n x: 3173.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n code: \"def main(email_json: dict) -> dict:\\n import json\\n if isinstance(email_json,\\\n \\ str): \\n email_json = json.loads(email_json)\\n\\n subject = None\\n\\\n \\ recipient_email = None \\n sender_email = None\\n \\n headers\\\n \\ = email_json['payload']['headers']\\n for header in headers:\\n \\\n \\ if header['name'] == 'Subject':\\n subject = header['value']\\n\\\n \\ elif header['name'] == 'To':\\n recipient_email = header['value']\\n\\\n \\ elif header['name'] == 'From':\\n sender_email = header['value']\\n\\\n \\n return {\\n \\\"result\\\": [subject, recipient_email, sender_email]\\n\\\n \\ }\\n\"\n code_language: python3\n desc: \"Recipient, Sender, Subject\\uFF0COutput Array[String]\"\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: Extract Email Metadata\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 101\n id: '1716951357236'\n parentId: '1716909114582'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1667\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: '{\"raw\": \"{{ encoded_message }}\"}'\n title: \"Template\\uFF1AEmail Request Body\"\n type: template-transform\n variables:\n - value_selector:\n - '1716951159073'\n - text\n variable: encoded_message\n extent: parent\n height: 53\n id: '1716952228079'\n parentId: '1716909114582'\n position:\n x: 2828.4325280181324\n y: 86.31950791077293\n positionAbsolute:\n x: 3770.4325280181324\n y: 531.3195079107729\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n provider_id: 038963aa-43c8-47fc-be4b-0255c19959c1\n provider_name: Draft Gmail\n provider_type: api\n selected: false\n title: createDraft\n 
tool_configurations: {}\n tool_label: createDraft\n tool_name: createDraft\n tool_parameters:\n message:\n type: mixed\n value: '{{#1716952228079.output#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716952912103'\n parentId: '1716909114582'\n position:\n x: 3133.7520359289056\n y: 85\n positionAbsolute:\n x: 4075.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n classes:\n - id: '1'\n name: 'Technical questions, related to product '\n - id: '2'\n name: Unrelated to technicals, non technical\n - id: '1716960736883'\n name: Other questions\n desc: ''\n instructions: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1716800588219'\n - sys.query\n selected: false\n title: Question Classifier\n topics: []\n type: question-classifier\n extent: parent\n height: 255\n id: '1716960721611'\n parentId: '1716909114582'\n position:\n x: 1325\n y: 85\n positionAbsolute:\n x: 2267\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - id: a639bbf8-bc58-42a2-b477-6748e80ecda2\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: 'LLM - Non technical '\n type: llm\n variables: []\n vision:\n enabled: false\n extent: parent\n height: 97\n id: '1716960728136'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 251\n positionAbsolute:\n x: 2567\n y: 696\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n output_type: string\n selected: false\n title: Variable Aggregator\n type: variable-aggregator\n variables:\n - - '1716909125498'\n - text\n - - '1716960728136'\n - text\n - - '1716960834468'\n - output\n extent: parent\n height: 164\n id: '1716960791399'\n parentId: '1716909114582'\n position:\n x: 1931.2699601497438\n y: 85\n positionAbsolute:\n x: 2873.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Other questions\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: 'Sorry, I cannot answer that. This is outside my capabilities. 
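The template and createDraft tool above assemble an RFC 2822 reply and post it as {"raw": ...}. A sketch of the equivalent direct call, assuming a valid OAuth access token; the endpoint and body shape follow the Gmail users.drafts.create method that the custom tool wraps:

import base64
import requests

def create_draft(access_token: str, user_id: str, mime_message: str) -> dict:
    # drafts.create expects the full MIME message, base64url-encoded, under "raw".
    raw = base64.urlsafe_b64encode(mime_message.encode("utf-8")).decode("ascii")
    resp = requests.post(
        f"https://gmail.googleapis.com/gmail/v1/users/{user_id}/drafts",
        headers={"Authorization": f"Bearer {access_token}"},
        json={"message": {"raw": raw}},
    )
    resp.raise_for_status()
    return resp.json()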
'\n title: 'Direct Reply '\n type: template-transform\n variables: []\n extent: parent\n height: 83\n id: '1716960834468'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 385.57142857142856\n positionAbsolute:\n x: 2567\n y: 830.5714285714286\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n author: Dify\n desc: ''\n height: 153\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 14px;\",\"text\":\"OpenAPI-Swagger for all custom tools: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"openapi:\n 3.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"info:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" title:\n Gmail API\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n OpenAPI schema for Gmail API methods `users.messages.get`, `users.messages.list`,\n and `users.drafts.create`.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" version:\n 1.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"servers:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n url: https://gmail.googleapis.com\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Gmail API Server\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"paths:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
/gmail/v1/users/{userId}/messages/{id}:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Get a message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Retrieves a specific message by ID.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n getMessage\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. 
The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: id\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the message to retrieve.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: format\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n false\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" enum:\n [full, metadata, minimal, raw]\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n full\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
description:\n The format to return the message in.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" internalDate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" payload:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" sizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n List messages.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Lists the messages in the user''s mailbox.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n listMessages\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: 
userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: maxResults\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" format:\n int32\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n 100\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Maximum number of messages to return.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" nextPageToken:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" resultSizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/drafts:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" post:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Creates a new draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n createDraft\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
tags:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n Drafts\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value \\\"me\\\" can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" requestBody:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The entire email message in an RFC 2822 formatted and base64url encoded\n string.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response with the created draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the thread the message belongs to.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n A short part of the message text.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the last history record that modified this message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''400'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Bad Request - The request is invalid.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized - Authentication is 
required.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden - The user does not have permission to create drafts.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found - The specified user does not exist.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''500'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Internal Server Error - An error occurred on the server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"components:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" securitySchemes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n oauth2\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" flows:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationCode:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationUrl:\n 
https://accounts.google.com/o/oauth2/auth\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" tokenUrl:\n https://oauth2.googleapis.com/token\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" scopes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://mail.google.com/:\n All access to Gmail.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.compose:\n Send email on your behalf.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.modify:\n Modify your email.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"security:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://mail.google.com/\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.compose\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.modify\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 367\n height: 153\n id: '1718992681576'\n position:\n x: 321.9646831030669\n y: 538.1642616264143\n positionAbsolute:\n x: 321.9646831030669\n y: 538.1642616264143\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 367\n - data:\n author: Dify\n desc: ''\n height: 158\n selected: false\n showAuthor: true\n text: 
'{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Replace\n custom tools after added this template to your own workspace. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Fill\n in \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"your\n email \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"and\n the \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"maximum\n number of results you want to retrieve from your inbox \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"to\n get started. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 287\n height: 158\n id: '1718992805687'\n position:\n x: 18.571428571428356\n y: 237.80887395992687\n positionAbsolute:\n x: 18.571428571428356\n y: 237.80887395992687\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 287\n - data:\n author: Dify\n desc: ''\n height: 375\n selected: true\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Steps within Iteraction node: \",\"type\":\"text\",\"version\":1},{\"type\":\"linebreak\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"1.\n getMessage: This step retrieves the incoming email message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"2.\n Code: Extract Email Body: Custom code is executed to extract the body of\n the email from the retrieved message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"3.\n Extract Email Metadata: Extracts metadata from the email, such as the recipient,\n sender, subject, and other relevant information.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"4.\n Base64 Decoder: Decodes the email content from Base64 encoding.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"5.\n Question Classifier (gpt-3.5-turbo): Uses a GPT-3.5-turbo model to classify\n the 
email content into different categories. For each classified question,\n the workflow uses a GPT-4.0 model to generate an appropriate reply:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"6.\n Template: Reply Email: Uses a template to generate a MIME email format for\n the reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"7.\n Base64 Encoder: Encodes the generated reply email content back to Base64.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"8.\n Template: Email Request: Prepares the email request using a template.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"9.\n createDraft: Creates a draft of the email reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"This\n workflow automates the process of reading, classifying, responding to, and\n drafting replies to incoming emails, leveraging advanced language models\n to generate contextually appropriate responses.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 640\n height: 375\n id: '1718993366836'\n position:\n x: 966.7525290975368\n y: 971.80362905854\n positionAbsolute:\n x: 966.7525290975368\n y: 971.80362905854\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 640\n - data:\n author: Dify\n desc: ''\n height: 400\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Preparation\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Enable\n Gmail API in Google Cloud Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Configure\n OAuth Client ID, OAuth Client Secrets, and OAuth Consent Screen for the\n Web Application in Google Cloud 
Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":2},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Use\n Postman to authorize and obtain the OAuth Access Token (Google''s Access\n Token will expire after 1 hour and cannot be used for a long time)\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":3}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"list\",\"version\":1,\"listType\":\"bullet\",\"start\":1,\"tag\":\"ul\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Users\n who want to try building an AI auto-reply email can refer to this document\n to use Postman (Postman.com) to obtain all the above keys: https://blog.postman.com/how-to-access-google-apis-using-oauth-in-postman/.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Developers\n who want to use Google OAuth to call the Gmail API to develop corresponding\n plugins can refer to this official document: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"https://developers.google.com/identity/protocols/oauth2/web-server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"At\n this stage, it is still a bit difficult to reproduce this example within\n the Dify platform. 
If you have development capabilities, developing the\n corresponding plugin externally and using an external database to automatically\n read and write the user''s Access Token and write the Refresh Token would\n be a better choice.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 608\n height: 400\n id: '1718993557447'\n position:\n x: 354.0157230378119\n y: -1.2732157979666\n positionAbsolute:\n x: 354.0157230378119\n y: -1.2732157979666\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 608\n viewport:\n x: 147.09446825757777\n y: 101.03530130020579\n zoom: 0.9548416039104178\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"e9d92058-7d20-4904-892f-75d90bef7587","mode":"advanced-chat","name":"Automated Email Reply "}, + "98b87f88-bd22-4d86-8b74-86beba5e0ed4":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Book Translation '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: start\n targetType: code\n id: 1711067409646-source-1717916867969-target\n source: '1711067409646'\n sourceHandle: source\n target: '1717916867969'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1717916867969-source-1717916955547-target\n source: '1717916867969'\n sourceHandle: source\n target: '1717916955547'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916961837-source-1717916977413-target\n source: '1717916961837'\n sourceHandle: source\n target: '1717916977413'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916977413-source-1717916984996-target\n source: '1717916977413'\n sourceHandle: source\n target: '1717916984996'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916984996-source-1717916991709-target\n source: '1717916984996'\n sourceHandle: source\n target: '1717916991709'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1717916955547-source-1717917057450-target\n source: '1717916955547'\n sourceHandle: source\n target: '1717917057450'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: end\n id: 1717917057450-source-1711068257370-target\n source: '1717917057450'\n sourceHandle: source\n target: '1711068257370'\n targetHandle: target\n type: custom\n zIndex: 0\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Input 
Text\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: input_text\n dragging: false\n height: 89\n id: '1711067409646'\n position:\n x: 30\n y: 301.5\n positionAbsolute:\n x: 30\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1717917057450'\n - output\n variable: final\n selected: false\n title: End\n type: end\n height: 89\n id: '1711068257370'\n position:\n x: 2291\n y: 301.5\n positionAbsolute:\n x: 2291\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"\\ndef main(input_text: str) -> str:\\n token_limit = 1000\\n overlap\\\n \\ = 100\\n chunk_size = int(token_limit * 6 * (4/3))\\n\\n # Initialize\\\n \\ variables\\n chunks = []\\n start_index = 0\\n text_length = len(input_text)\\n\\\n \\n # Loop until the end of the text is reached\\n while start_index\\\n \\ < text_length:\\n # If we are not at the beginning, adjust the start_index\\\n \\ to ensure overlap\\n if start_index > 0:\\n start_index\\\n \\ -= overlap\\n\\n # Calculate end index for the current chunk\\n \\\n \\ end_index = start_index + chunk_size\\n if end_index > text_length:\\n\\\n \\ end_index = text_length\\n\\n # Add the current chunk\\\n \\ to the list\\n chunks.append(input_text[start_index:end_index])\\n\\\n \\n # Update the start_index for the next chunk\\n start_index\\\n \\ += chunk_size\\n\\n return {\\n \\\"chunks\\\": chunks,\\n }\\n\"\n code_language: python3\n dependencies: []\n desc: 'token_limit = 1000\n\n overlap = 100'\n outputs:\n chunks:\n children: null\n type: array[string]\n selected: false\n title: Code\n type: code\n variables:\n - value_selector:\n - '1711067409646'\n - input_text\n variable: input_text\n height: 101\n id: '1717916867969'\n position:\n x: 336\n y: 301.5\n positionAbsolute:\n x: 336\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: 'Take good care on maximum number of iterations. '\n height: 203\n iterator_selector:\n - '1717916867969'\n - chunks\n output_selector:\n - '1717916991709'\n - text\n output_type: array[string]\n selected: false\n startNodeType: llm\n start_node_id: '1717916961837'\n title: Iteration\n type: iteration\n width: 1289\n height: 203\n id: '1717916955547'\n position:\n x: 638\n y: 301.5\n positionAbsolute:\n x: 638\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 1289\n zIndex: 1\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 7261280b-cb27-4f84-8363-b93e09246d16\n role: system\n text: \" Identify the technical terms in the users input. Use the following\\\n \\ format {XXX} -> {XXX} to show the corresponding technical terms before\\\n \\ and after translation. 
\\n\\n \\n{{#1717916955547.item#}}\\n\\\n \\n\\n| \\u82F1\\u6587 | \\u4E2D\\u6587 |\\n| --- | --- |\\n| Prompt\\\n \\ Engineering | \\u63D0\\u793A\\u8BCD\\u5DE5\\u7A0B |\\n| Text Generation \\_\\\n | \\u6587\\u672C\\u751F\\u6210 |\\n| Token \\_| Token |\\n| Prompt \\_| \\u63D0\\\n \\u793A\\u8BCD |\\n| Meta Prompting \\_| \\u5143\\u63D0\\u793A |\\n| diffusion\\\n \\ models \\_| \\u6269\\u6563\\u6A21\\u578B |\\n| Agent \\_| \\u667A\\u80FD\\u4F53\\\n \\ |\\n| Transformer \\_| Transformer |\\n| Zero Shot \\_| \\u96F6\\u6837\\u672C\\\n \\ |\\n| Few Shot \\_| \\u5C11\\u6837\\u672C |\\n| chat window \\_| \\u804A\\u5929\\\n \\ |\\n| context | \\u4E0A\\u4E0B\\u6587 |\\n| stock photo \\_| \\u56FE\\u5E93\\u7167\\\n \\u7247 |\\n\\n\\n \"\n selected: false\n title: 'Identify Terms '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916961837'\n parentId: '1717916955547'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 755\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 05e03f0d-c1a9-43ab-b4c0-44b55049434d\n role: system\n text: \" You are a professional translator proficient in Simplified\\\n \\ Chinese especially skilled in translating professional academic papers\\\n \\ into easy-to-understand popular science articles. Please help me translate\\\n \\ the following english paragraph into Chinese, in a style similar to\\\n \\ Chinese popular science articles .\\n \\nTranslate directly\\\n \\ based on the English content, maintain the original format and do not\\\n \\ omit any information. \\n \\n{{#1717916955547.item#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: 1st Translation\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916977413'\n parentId: '1717916955547'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1059\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 9e6cc050-465e-4632-abc9-411acb255a95\n role: system\n text: \"\\nBased on the results of the direct translation, point out\\\n \\ specific issues it have. 
Accurate descriptions are required, avoiding\\\n \\ vague statements, and there's no need to add content or formats that\\\n \\ were not present in the original text, including but not limited to:\\\n \\ \\n- inconsistent with Chinese expression habits, clearly indicate where\\\n \\ it does not conform\\n- Clumsy sentences, specify the location, no need\\\n \\ to offer suggestions for modification, which will be fixed during free\\\n \\ translation\\n- Obscure and difficult to understand, attempts to explain\\\n \\ may be made\\n- \\u65E0\\u6F0F\\u8BD1\\uFF08\\u539F\\u6587\\u4E2D\\u7684\\u5173\\
 \\u952E\\u8BCD\\u3001\\u53E5\\u5B50\\u3001\\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\
 \\u5728\\u8BD1\\u6587\\u4E2D\\uFF09\\u3002\\n- \\u65E0\\u9519\\u8BD1\\uFF08\\u770B\\
 \\u9519\\u539F\\u6587\\u3001\\u8BEF\\u89E3\\u539F\\u6587\\u610F\\u601D\\u5747\\u7B97\\
 \\u9519\\u8BD1\\uFF09\\u3002\\n- \\u65E0\\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\
 \\u5220\\u51CF\\u7684\\u539F\\u6587\\u5185\\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u975E\\
 \\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\u4F5C\\u8005\\u89C2\\u70B9\\uFF1B\\u53EF\\
 \\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n-\\
 \\ \\u8BD1\\u6587\\u6D41\\u7545\\uFF0C\\u7B26\\u5408\\u4E2D\\u6587\\u8868\\u8FBE\\u4E60\\
 \\u60EF\\u3002\\n- \\u5173\\u4E8E\\u4EBA\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\
 \\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u4EBA\\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\
 \\uFF0C\\u4F46\\u662F\\u4E00\\u4E9B\\u4F17\\u6240\\u5468\\u77E5\\u7684\\u4EBA\\u540D\\
 \\u9700\\u7528\\u4E2D\\u6587\\uFF08\\u5982\\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n-\\
 \\ \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6709\\u4E2D\\u6587\\u7248\\
 \\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u7528\\u4E2D\\u6587\\u7248\\u4E66\\u540D\\uFF1B\\
 \\u65E0\\u4E2D\\u6587\\u7248\\u7684\\u56FE\\u4E66\\uFF0C\\u76F4\\u63A5\\u7528\\u82F1\\
 \\u6587\\u4E66\\u540D\\u3002\\n- \\u5173\\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\
 \\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\
 \\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\u3002\\u56FE\\u9898\\u9700\\u8981\\u7FFB\\
 \\u8BD1\\u3002\\u754C\\u9762\\u622A\\u56FE\\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\
 \\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\
 \\u6587\\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\
 \\n- \\u5173\\u4E8E\\u82F1\\u6587\\u672F\\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\
 \\u6587\\u672F\\u8BED\\u9996\\u6B21\\u51FA\\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\
 \\u636E\\u8BE5\\u672F\\u8BED\\u7684\\u6D41\\u884C\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\
 \\u4F7F\\u7528\\u7B80\\u5199\\u5F62\\u5F0F\\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\
 \\u7528\\u62EC\\u53F7\\u52A0\\u82F1\\u6587\\u3001\\u4E2D\\u6587\\u5168\\u79F0\\u6CE8\\
 \\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08\\
 Hypertext Markup Language\\uFF0C\\u8D85\\u6587\\u672C\\u6807\\u8BC6\\u8BED\\u8A00\\
 \\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\u6587\\u4E2D\\u76F4\\u63A5\\u4F7F\\u7528\\
 \\u7B80\\u5199\\u5F62\\u5F0F\\u3002\\u5F53\\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\
 \\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\u4F7F\\u7528\\u4E2D\\u3001\\u82F1\\u6587\\
 \\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\
 \\u7801\\u7247\\u6BB5\\u3002\\u539F\\u4E66\\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\
\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\u8005\\u5F55\\u5165\\uFF0C\\u4F46\\u5E94\\
 \\u8BE5\\u4F7F\\u7528\\u201C\\u539F\\u4E66P99\\u9875\\u4EE3\\u78011\\u201D\\uFF08\\
 \\u5373\\u539F\\u4E66\\u7B2C99\\u9875\\u4E2D\\u7684\\u7B2C\\u4E00\\u6BB5\\u4EE3\\u7801\\
 \\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\
 \\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\
 \\u68C0\\u6838\\u4EE3\\u7801\\u7684\\u6B63\\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\
 \\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\
 \\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\u6CE8\\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\
 \\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\
 \\u5E94\\u8BE5\\u4EE5\\u4E00\\u53E5\\u82F1\\u6587\\uFF08\\u6CE8\\u91CA\\uFF09\\u4E00\\
 \\u53E5\\u4E2D\\u6587\\uFF08\\u6CE8\\u91CA\\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\
 \\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\
 \\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u6587\\
 \\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\u6587\\u6807\\u70B9\\u7B26\\u53F7\\u7684\\
 \\u4F7F\\u7528\\u4E60\\u60EF\\uFF0C\\u4E0D\\u80FD\\u7167\\u642C\\u539F\\u6587\\u7684\\
 \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\n{{#1717916977413.text#}}\\n\\
 \\n{{#1717916955547.item#}}\\n\\n{{#1717916961837.text#}}\\n\\
 \"\n selected: false\n title: 'Problems '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916984996'\n parentId: '1717916955547'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1363\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 4d7ae758-2d7b-4404-ad9f-d6748ee64439\n role: system\n text: \"\\nBased on the results of the direct translation in the first\\\n \\ step and the problems identified in the second step, re-translate to\\\n \\ achieve a meaning-based interpretation. Ensure the original intent of\\\n \\ the content is preserved while making it easier to understand and more\\\n \\ in line with Chinese expression habits. All the while maintaining the\\\n \\ original format unchanged. 
\\n\\n\\n- inconsistent with Chinese\\
 \\ expression habits, clearly indicate where it does not conform\\n- Clumsy\\
 \\ sentences, specify the location, no need to offer suggestions for modification,\\
 \\ which will be fixed during free translation\\n- Obscure and difficult\\
 \\ to understand, attempts to explain may be made\\n- \\u65E0\\u6F0F\\u8BD1\\
 \\uFF08\\u539F\\u6587\\u4E2D\\u7684\\u5173\\u952E\\u8BCD\\u3001\\u53E5\\u5B50\\u3001\\
 \\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\u5728\\u8BD1\\u6587\\u4E2D\\uFF09\\u3002\\
 \\n- \\u65E0\\u9519\\u8BD1\\uFF08\\u770B\\u9519\\u539F\\u6587\\u3001\\u8BEF\\u89E3\\
 \\u539F\\u6587\\u610F\\u601D\\u5747\\u7B97\\u9519\\u8BD1\\uFF09\\u3002\\n- \\u65E0\\
 \\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\u5220\\u51CF\\u7684\\u539F\\u6587\\u5185\\
 \\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u975E\\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\
 \\u4F5C\\u8005\\u89C2\\u70B9\\uFF1B\\u53EF\\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\
 \\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n- \\u8BD1\\u6587\\u6D41\\u7545\\uFF0C\\u7B26\\
 \\u5408\\u4E2D\\u6587\\u8868\\u8FBE\\u4E60\\u60EF\\u3002\\n- \\u5173\\u4E8E\\u4EBA\\
 \\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u4EBA\\
 \\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\uFF0C\\u4F46\\u662F\\u4E00\\u4E9B\\u4F17\\
 \\u6240\\u5468\\u77E5\\u7684\\u4EBA\\u540D\\u9700\\u7528\\u4E2D\\u6587\\uFF08\\u5982\\
 \\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n- \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\
 \\u8BD1\\u3002\\u6709\\u4E2D\\u6587\\u7248\\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u7528\\
 \\u4E2D\\u6587\\u7248\\u4E66\\u540D\\uFF1B\\u65E0\\u4E2D\\u6587\\u7248\\u7684\\u56FE\\
 \\u4E66\\uFF0C\\u76F4\\u63A5\\u7528\\u82F1\\u6587\\u4E66\\u540D\\u3002\\n- \\u5173\\
 \\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\
 \\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\
 \\u3002\\u56FE\\u9898\\u9700\\u8981\\u7FFB\\u8BD1\\u3002\\u754C\\u9762\\u622A\\u56FE\\
 \\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\
 \\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\u6587\\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\
 \\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\n- \\u5173\\u4E8E\\u82F1\\u6587\\u672F\\
 \\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\u6587\\u672F\\u8BED\\u9996\\u6B21\\u51FA\\
 \\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\u636E\\u8BE5\\u672F\\u8BED\\u7684\\u6D41\\
 \\u884C\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\u4F7F\\u7528\\u7B80\\u5199\\u5F62\\u5F0F\\
 \\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\u7528\\u62EC\\u53F7\\u52A0\\u82F1\\u6587\\
 \\u3001\\u4E2D\\u6587\\u5168\\u79F0\\u6CE8\\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\
 \\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08Hypertext Markup Language\\uFF0C\\u8D85\\
 \\u6587\\u672C\\u6807\\u8BC6\\u8BED\\u8A00\\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\
 \\u6587\\u4E2D\\u76F4\\u63A5\\u4F7F\\u7528\\u7B80\\u5199\\u5F62\\u5F0F\\u3002\\u5F53\\
 \\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\
 \\u4F7F\\u7528\\u4E2D\\u3001\\u82F1\\u6587\\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\
 \\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\u7801\\u7247\\u6BB5\\u3002\\u539F\\u4E66\\
 \\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\
 \\u8005\\u5F55\\u5165\\uFF0C\\u4F46\\u5E94\\u8BE5\\u4F7F\\u7528\\u201C\\u539F\\u4E66\\
 P99\\u9875\\u4EE3\\u78011\\u201D\\uFF08\\u5373\\u539F\\u4E66\\u7B2C99\\u9875\\u4E2D\\
\\u7684\\u7B2C\\u4E00\\u6BB5\\u4EE3\\u7801\\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\
 \\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\
 \\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\u68C0\\u6838\\u4EE3\\u7801\\u7684\\u6B63\\
 \\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\
 \\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\
 \\u6CE8\\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\
 \\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\u5E94\\u8BE5\\u4EE5\\u4E00\\u53E5\\u82F1\\
 \\u6587\\uFF08\\u6CE8\\u91CA\\uFF09\\u4E00\\u53E5\\u4E2D\\u6587\\uFF08\\u6CE8\\u91CA\\
 \\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\
 \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\
 \\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u6587\\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\
 \\u6587\\u6807\\u70B9\\u7B26\\u53F7\\u7684\\u4F7F\\u7528\\u4E60\\u60EF\\uFF0C\\u4E0D\\
 \\u80FD\\u7167\\u642C\\u539F\\u6587\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\
 \\n{{#1717916977413.text#}}\\n\\n{{#1717916984996.text#}}\\n\\n{{#1711067409646.input_text#}}\\n\\
 \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: '2nd Translation '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916991709'\n parentId: '1717916955547'\n position:\n x: 1029\n y: 85\n positionAbsolute:\n x: 1667\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: 'Combine all chunks of translation. '\n selected: false\n template: '{{ translated_text | join('' '') }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1717916955547'\n - output\n variable: translated_text\n height: 83\n id: '1717917057450'\n position:\n x: 1987\n y: 301.5\n positionAbsolute:\n x: 1987\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n author: Dify\n desc: ''\n height: 186\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Code\n node separates the input_text into chunks with length of token_limit. Each\n chunk overlaps with the others to make sure the texts are consistent. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n code node outputs an array of segmented texts of input_texts. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 340\n height: 186\n id: '1718990593686'\n position:\n x: 259.3026056936437\n y: 451.6924912936374\n positionAbsolute:\n x: 259.3026056936437\n y: 451.6924912936374\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 340\n - data:\n author: Dify\n desc: ''\n height: 128\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Iterate\n through all the elements in output of the code node and translate each chunk\n using a three steps translation workflow. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 355\n height: 128\n id: '1718991836605'\n position:\n x: 764.3891977435923\n y: 530.8917807505335\n positionAbsolute:\n x: 764.3891977435923\n y: 530.8917807505335\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 355\n - data:\n author: Dify\n desc: ''\n height: 126\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Avoid\n using a high token_limit, LLM''s performance decreases with longer context\n length for gpt-4o. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Recommend\n to use less than or equal to 1000 tokens. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 351\n height: 126\n id: '1718991882984'\n position:\n x: 304.49115824454367\n y: 148.4042994607805\n positionAbsolute:\n x: 304.49115824454367\n y: 148.4042994607805\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 351\n viewport:\n x: 335.92505067152274\n y: 18.806553508850584\n zoom: 0.8705505632961259\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"98b87f88-bd22-4d86-8b74-86beba5e0ed4","mode":"workflow","name":"Book Translation "}, "cae337e6-aec5-4c7b-beca-d6f1a808bd5e":{ "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: chat\n name: Python bug fixer\nmodel_config:\n agent_mode:\n enabled: false\n max_iteration: 5\n strategy: function_call\n tools: []\n annotation_reply:\n enabled: false\n chat_prompt_config: {}\n completion_prompt_config: {}\n dataset_configs:\n datasets:\n datasets: []\n retrieval_model: single\n dataset_query_variable: ''\n external_data_tools: []\n file_upload:\n image:\n detail: high\n enabled: false\n number_limits: 3\n transfer_methods:\n - remote_url\n - local_file\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n stop: []\n temperature: 0\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n more_like_this:\n enabled: false\n opening_statement: ''\n pre_prompt: Your task is to analyze the provided Python code snippet, identify any\n bugs or errors present, and provide a corrected version of the code that resolves\n these issues. Explain the problems you found in the original code and how your\n fixes address them. 
The corrected code should be functional, efficient, and adhere\n to best practices in Python programming.\n prompt_type: simple\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n configs: []\n enabled: false\n type: ''\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n user_input_form: []\n", "icon": "🤖", @@ -553,15 +553,15 @@ "name": "AI Front-end interviewer" }, "e9870913-dd01-4710-9f06-15d4180ca1ce": { - "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Knowledge Retreival + Chatbot '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: knowledge-retrieval\n id: 1711528914102-1711528915811\n source: '1711528914102'\n sourceHandle: source\n target: '1711528915811'\n targetHandle: target\n type: custom\n - data:\n sourceType: knowledge-retrieval\n targetType: llm\n id: 1711528915811-1711528917469\n source: '1711528915811'\n sourceHandle: source\n target: '1711528917469'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: answer\n id: 1711528917469-1711528919501\n source: '1711528917469'\n sourceHandle: source\n target: '1711528919501'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables: []\n height: 53\n id: '1711528914102'\n position:\n x: 79.5\n y: 2634.5\n positionAbsolute:\n x: 79.5\n y: 2634.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n dataset_ids:\n - 6084ed3f-d100-4df2-a277-b40d639ea7c6\n desc: Allows you to query text content related to user questions from the\n Knowledge\n query_variable_selector:\n - '1711528914102'\n - sys.query\n retrieval_mode: single\n selected: false\n single_retrieval_config:\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n title: Knowledge Retrieval\n type: knowledge-retrieval\n dragging: false\n height: 101\n id: '1711528915811'\n position:\n x: 362.5\n y: 2634.5\n positionAbsolute:\n x: 362.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Invoking large language models to answer questions or process natural\n language\n memory:\n role_prefix:\n assistant: ''\n user: ''\n window:\n enabled: false\n size: 50\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \"You are a helpful assistant. 
\\nUse the following context as your\\\n \\ learned knowledge, inside XML tags.\\n\\n\\\n {{#context#}}\\n\\nWhen answer to user:\\n- If you don't know,\\\n \\ just say that you don't know.\\n- If you don't know when you are not\\\n \\ sure, ask for clarification.\\nAvoid mentioning that you obtained the\\\n \\ information from the context.\\nAnd answer according to the language\\\n \\ of the user's question.\"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n height: 163\n id: '1711528917469'\n position:\n x: 645.5\n y: 2634.5\n positionAbsolute:\n x: 645.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n answer: '{{#1711528917469.text#}}'\n desc: ''\n selected: false\n title: Answer\n type: answer\n variables: []\n height: 105\n id: '1711528919501'\n position:\n x: 928.5\n y: 2634.5\n positionAbsolute:\n x: 928.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 86.31278232100044\n y: -2276.452137533831\n zoom: 0.9753554615276419\n", + "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Knowledge Retrieval + Chatbot '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: knowledge-retrieval\n id: 1711528914102-1711528915811\n source: '1711528914102'\n sourceHandle: source\n target: '1711528915811'\n targetHandle: target\n type: custom\n - data:\n sourceType: knowledge-retrieval\n targetType: llm\n id: 1711528915811-1711528917469\n source: '1711528915811'\n sourceHandle: source\n target: '1711528917469'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: answer\n id: 1711528917469-1711528919501\n source: '1711528917469'\n sourceHandle: source\n target: '1711528919501'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables: []\n height: 53\n id: '1711528914102'\n position:\n x: 79.5\n y: 2634.5\n positionAbsolute:\n x: 79.5\n y: 2634.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n dataset_ids:\n - 6084ed3f-d100-4df2-a277-b40d639ea7c6\n desc: Allows you to query text content related to user questions from the\n Knowledge\n query_variable_selector:\n - '1711528914102'\n - sys.query\n retrieval_mode: single\n selected: false\n single_retrieval_config:\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n title: Knowledge Retrieval\n type: knowledge-retrieval\n dragging: false\n height: 101\n id: '1711528915811'\n position:\n x: 362.5\n y: 2634.5\n positionAbsolute:\n x: 362.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Invoking large language models to answer questions or process natural\n language\n memory:\n role_prefix:\n assistant: ''\n user: 
''\n window:\n enabled: false\n size: 50\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \"You are a helpful assistant. \\nUse the following context as your\\\n \\ learned knowledge, inside XML tags.\\n\\n\\\n {{#context#}}\\n\\nWhen answer to user:\\n- If you don't know,\\\n \\ just say that you don't know.\\n- If you don't know when you are not\\\n \\ sure, ask for clarification.\\nAvoid mentioning that you obtained the\\\n \\ information from the context.\\nAnd answer according to the language\\\n \\ of the user's question.\"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n height: 163\n id: '1711528917469'\n position:\n x: 645.5\n y: 2634.5\n positionAbsolute:\n x: 645.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n answer: '{{#1711528917469.text#}}'\n desc: ''\n selected: false\n title: Answer\n type: answer\n variables: []\n height: 105\n id: '1711528919501'\n position:\n x: 928.5\n y: 2634.5\n positionAbsolute:\n x: 928.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 86.31278232100044\n y: -2276.452137533831\n zoom: 0.9753554615276419\n", "icon": "🤖", "icon_background": "#FFEAD5", "id": "e9870913-dd01-4710-9f06-15d4180ca1ce", "mode": "advanced-chat", - "name": "Knowledge Retreival + Chatbot " + "name": "Knowledge Retrieval + Chatbot " }, "dd5b6353-ae9b-4bce-be6a-a681a12cf709":{ - "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Email Assistant Workflow '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: question-classifier\n id: 1711511281652-1711512802873\n source: '1711511281652'\n sourceHandle: source\n target: '1711512802873'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: question-classifier\n id: 1711512802873-1711512837494\n source: '1711512802873'\n sourceHandle: '1711512813038'\n target: '1711512837494'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512911454\n source: '1711512802873'\n sourceHandle: '1711512811520'\n target: '1711512911454'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512914870\n source: '1711512802873'\n sourceHandle: '1711512812031'\n target: '1711512914870'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512916516\n source: '1711512802873'\n sourceHandle: '1711512812510'\n target: '1711512916516'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512924231\n source: '1711512837494'\n sourceHandle: '1711512846439'\n target: '1711512924231'\n targetHandle: target\n type: custom\n - 
data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512926020\n source: '1711512837494'\n sourceHandle: '1711512847112'\n target: '1711512926020'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512927569\n source: '1711512837494'\n sourceHandle: '1711512847641'\n target: '1711512927569'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512929190\n source: '1711512837494'\n sourceHandle: '1711512848120'\n target: '1711512929190'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512930700\n source: '1711512837494'\n sourceHandle: '1711512848616'\n target: '1711512930700'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512911454-1711513015189\n source: '1711512911454'\n sourceHandle: source\n target: '1711513015189'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512914870-1711513017096\n source: '1711512914870'\n sourceHandle: source\n target: '1711513017096'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512916516-1711513018759\n source: '1711512916516'\n sourceHandle: source\n target: '1711513018759'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512924231-1711513020857\n source: '1711512924231'\n sourceHandle: source\n target: '1711513020857'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512926020-1711513022516\n source: '1711512926020'\n sourceHandle: source\n target: '1711513022516'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512927569-1711513024315\n source: '1711512927569'\n sourceHandle: source\n target: '1711513024315'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512929190-1711513025732\n source: '1711512929190'\n sourceHandle: source\n target: '1711513025732'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512930700-1711513027347\n source: '1711512930700'\n sourceHandle: source\n target: '1711513027347'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513015189-1711513029058\n source: '1711513015189'\n sourceHandle: source\n target: '1711513029058'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513017096-1711513030924\n source: '1711513017096'\n sourceHandle: source\n target: '1711513030924'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513018759-1711513032459\n source: '1711513018759'\n sourceHandle: source\n target: '1711513032459'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513020857-1711513034850\n source: '1711513020857'\n sourceHandle: source\n target: '1711513034850'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513022516-1711513036356\n source: '1711513022516'\n sourceHandle: source\n target: '1711513036356'\n targetHandle: 
target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513024315-1711513037973\n source: '1711513024315'\n sourceHandle: source\n target: '1711513037973'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513025732-1711513039350\n source: '1711513025732'\n sourceHandle: source\n target: '1711513039350'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513027347-1711513041219\n source: '1711513027347'\n sourceHandle: source\n target: '1711513041219'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711513940609\n source: '1711512802873'\n sourceHandle: '1711513927279'\n target: '1711513940609'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711513940609-1711513967853\n source: '1711513940609'\n sourceHandle: source\n target: '1711513967853'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513967853-1711513974643\n source: '1711513967853'\n sourceHandle: source\n target: '1711513974643'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables:\n - label: Email\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: Input_Text\n - label: What do you need to do? (Summarize / Reply / Write / Improve)\n max_length: 48\n options:\n - Summarize\n - 'Reply '\n - Write a email\n - 'Improve writings '\n required: true\n type: select\n variable: user_request\n - label: 'How do you want it to be polished? (Optional) '\n max_length: 48\n options:\n - 'Imporve writing and clarity '\n - Shorten\n - 'Lengthen '\n - 'Simplify '\n - Rewrite in my voice\n required: false\n type: select\n variable: how_polish\n dragging: false\n height: 141\n id: '1711511281652'\n position:\n x: 79.5\n y: 409.5\n positionAbsolute:\n x: 79.5\n y: 409.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512811520'\n name: Summarize\n - id: '1711512812031'\n name: Reply to emails\n - id: '1711512812510'\n name: Help me write the email\n - id: '1711512813038'\n name: Improve writings or polish\n - id: '1711513927279'\n name: Grammer check\n desc: 'Classify users'' demands. '\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - user_request\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512802873'\n position:\n x: 362.5\n y: 409.5\n positionAbsolute:\n x: 362.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512846439'\n name: 'Improve writing and clarity '\n - id: '1711512847112'\n name: 'Shorten '\n - id: '1711512847641'\n name: 'Lengthen '\n - id: '1711512848120'\n name: 'Simplify '\n - id: '1711512848616'\n name: Rewrite in my voice\n desc: 'Improve writings. 
'\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - how_polish\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512837494'\n position:\n x: 645.5\n y: 409.5\n positionAbsolute:\n x: 645.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Summary\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Summary the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512911454'\n position:\n x: 645.5\n y: 1327.5\n positionAbsolute:\n x: 645.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Reply\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Rely the emails for me, in my own voice. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512914870'\n position:\n x: 645.5\n y: 1518.5\n positionAbsolute:\n x: 645.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Turn idea into email\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Turn my idea into email. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512916516'\n position:\n x: 645.5\n y: 1709.5\n positionAbsolute:\n x: 645.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Improve the clarity. '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \" Imporve the clarity of the email for me. \\n{{#1711511281652.Input_Text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512924231'\n position:\n x: 928.5\n y: 409.5\n positionAbsolute:\n x: 928.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Shorten. 
'\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Shorten the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512926020'\n position:\n x: 928.5\n y: 600.5\n positionAbsolute:\n x: 928.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Lengthen '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Lengthen the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512927569'\n position:\n x: 928.5\n y: 791.5\n positionAbsolute:\n x: 928.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Simplify\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Simplify the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512929190'\n position:\n x: 928.5\n y: 982.5\n positionAbsolute:\n x: 928.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Rewrite in my voice\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Rewrite the email for me. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512930700'\n position:\n x: 928.5\n y: 1173.5\n positionAbsolute:\n x: 928.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1711512911454'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513015189'\n position:\n x: 928.5\n y: 1327.5\n positionAbsolute:\n x: 928.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 2\n type: template-transform\n variables:\n - value_selector:\n - '1711512914870'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513017096'\n position:\n x: 928.5\n y: 1518.5\n positionAbsolute:\n x: 928.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 3\n type: template-transform\n variables:\n - value_selector:\n - '1711512916516'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513018759'\n position:\n x: 928.5\n y: 1709.5\n positionAbsolute:\n x: 928.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 4\n type: template-transform\n variables:\n - value_selector:\n - '1711512924231'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513020857'\n position:\n x: 1211.5\n y: 409.5\n positionAbsolute:\n x: 1211.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 5\n type: template-transform\n variables:\n - value_selector:\n - '1711512926020'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513022516'\n position:\n x: 1211.5\n y: 600.5\n positionAbsolute:\n x: 1211.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 6\n type: template-transform\n variables:\n - value_selector:\n - '1711512927569'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513024315'\n position:\n x: 1211.5\n y: 791.5\n positionAbsolute:\n x: 1211.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 7\n type: template-transform\n variables:\n - value_selector:\n - '1711512929190'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513025732'\n position:\n x: 1211.5\n y: 982.5\n positionAbsolute:\n x: 1211.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 8\n type: template-transform\n variables:\n - value_selector:\n - '1711512930700'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513027347'\n position:\n x: 1211.5\n y: 1173.5\n positionAbsolute:\n x: 
1211.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512911454'\n - text\n variable: text\n selected: false\n title: End\n type: end\n dragging: false\n height: 89\n id: '1711513029058'\n position:\n x: 1211.5\n y: 1327.5\n positionAbsolute:\n x: 1211.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512914870'\n - text\n variable: text\n selected: false\n title: End 2\n type: end\n dragging: false\n height: 89\n id: '1711513030924'\n position:\n x: 1211.5\n y: 1518.5\n positionAbsolute:\n x: 1211.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512916516'\n - text\n variable: text\n selected: false\n title: End 3\n type: end\n dragging: false\n height: 89\n id: '1711513032459'\n position:\n x: 1211.5\n y: 1709.5\n positionAbsolute:\n x: 1211.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512924231'\n - text\n variable: text\n selected: false\n title: End 4\n type: end\n dragging: false\n height: 89\n id: '1711513034850'\n position:\n x: 1494.5\n y: 409.5\n positionAbsolute:\n x: 1494.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512926020'\n - text\n variable: text\n selected: false\n title: End 5\n type: end\n dragging: false\n height: 89\n id: '1711513036356'\n position:\n x: 1494.5\n y: 600.5\n positionAbsolute:\n x: 1494.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512927569'\n - text\n variable: text\n selected: false\n title: End 6\n type: end\n dragging: false\n height: 89\n id: '1711513037973'\n position:\n x: 1494.5\n y: 791.5\n positionAbsolute:\n x: 1494.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512929190'\n - text\n variable: text\n selected: false\n title: End 7\n type: end\n dragging: false\n height: 89\n id: '1711513039350'\n position:\n x: 1494.5\n y: 982.5\n positionAbsolute:\n x: 1494.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512930700'\n - text\n variable: text\n selected: false\n title: End 8\n type: end\n dragging: false\n height: 89\n id: '1711513041219'\n position:\n x: 1494.5\n y: 1173.5\n positionAbsolute:\n x: 1494.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Grammer Check\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: 'Please check grammer of my email and comment on the grammer. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711513940609'\n position:\n x: 645.5\n y: 1900.5\n positionAbsolute:\n x: 645.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 9\n type: template-transform\n variables:\n - value_selector:\n - '1711513940609'\n - text\n variable: arg1\n height: 53\n id: '1711513967853'\n position:\n x: 928.5\n y: 1900.5\n positionAbsolute:\n x: 928.5\n y: 1900.5\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711513940609'\n - text\n variable: text\n selected: false\n title: End 9\n type: end\n height: 89\n id: '1711513974643'\n position:\n x: 1211.5\n y: 1900.5\n positionAbsolute:\n x: 1211.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 0\n y: 0\n zoom: 0.7\n", + "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Email Assistant Workflow '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: question-classifier\n id: 1711511281652-1711512802873\n source: '1711511281652'\n sourceHandle: source\n target: '1711512802873'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: question-classifier\n id: 1711512802873-1711512837494\n source: '1711512802873'\n sourceHandle: '1711512813038'\n target: '1711512837494'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512911454\n source: '1711512802873'\n sourceHandle: '1711512811520'\n target: '1711512911454'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512914870\n source: '1711512802873'\n sourceHandle: '1711512812031'\n target: '1711512914870'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512916516\n source: '1711512802873'\n sourceHandle: '1711512812510'\n target: '1711512916516'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512924231\n source: '1711512837494'\n sourceHandle: '1711512846439'\n target: '1711512924231'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512926020\n source: '1711512837494'\n sourceHandle: '1711512847112'\n target: '1711512926020'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512927569\n source: '1711512837494'\n sourceHandle: '1711512847641'\n target: '1711512927569'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512929190\n source: 
'1711512837494'\n sourceHandle: '1711512848120'\n target: '1711512929190'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512930700\n source: '1711512837494'\n sourceHandle: '1711512848616'\n target: '1711512930700'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512911454-1711513015189\n source: '1711512911454'\n sourceHandle: source\n target: '1711513015189'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512914870-1711513017096\n source: '1711512914870'\n sourceHandle: source\n target: '1711513017096'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512916516-1711513018759\n source: '1711512916516'\n sourceHandle: source\n target: '1711513018759'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512924231-1711513020857\n source: '1711512924231'\n sourceHandle: source\n target: '1711513020857'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512926020-1711513022516\n source: '1711512926020'\n sourceHandle: source\n target: '1711513022516'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512927569-1711513024315\n source: '1711512927569'\n sourceHandle: source\n target: '1711513024315'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512929190-1711513025732\n source: '1711512929190'\n sourceHandle: source\n target: '1711513025732'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512930700-1711513027347\n source: '1711512930700'\n sourceHandle: source\n target: '1711513027347'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513015189-1711513029058\n source: '1711513015189'\n sourceHandle: source\n target: '1711513029058'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513017096-1711513030924\n source: '1711513017096'\n sourceHandle: source\n target: '1711513030924'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513018759-1711513032459\n source: '1711513018759'\n sourceHandle: source\n target: '1711513032459'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513020857-1711513034850\n source: '1711513020857'\n sourceHandle: source\n target: '1711513034850'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513022516-1711513036356\n source: '1711513022516'\n sourceHandle: source\n target: '1711513036356'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513024315-1711513037973\n source: '1711513024315'\n sourceHandle: source\n target: '1711513037973'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513025732-1711513039350\n source: '1711513025732'\n sourceHandle: source\n target: '1711513039350'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513027347-1711513041219\n source: 
'1711513027347'\n sourceHandle: source\n target: '1711513041219'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711513940609\n source: '1711512802873'\n sourceHandle: '1711513927279'\n target: '1711513940609'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711513940609-1711513967853\n source: '1711513940609'\n sourceHandle: source\n target: '1711513967853'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513967853-1711513974643\n source: '1711513967853'\n sourceHandle: source\n target: '1711513974643'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables:\n - label: Email\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: Input_Text\n - label: What do you need to do? (Summarize / Reply / Write / Improve)\n max_length: 48\n options:\n - Summarize\n - 'Reply '\n - Write an email\n - 'Improve writings '\n required: true\n type: select\n variable: user_request\n - label: 'How do you want it to be polished? (Optional) '\n max_length: 48\n options:\n - 'Improve writing and clarity '\n - Shorten\n - 'Lengthen '\n - 'Simplify '\n - Rewrite in my voice\n required: false\n type: select\n variable: how_polish\n dragging: false\n height: 141\n id: '1711511281652'\n position:\n x: 79.5\n y: 409.5\n positionAbsolute:\n x: 79.5\n y: 409.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512811520'\n name: Summarize\n - id: '1711512812031'\n name: Reply to emails\n - id: '1711512812510'\n name: Help me write the email\n - id: '1711512813038'\n name: Improve writings or polish\n - id: '1711513927279'\n name: Grammar check\n desc: 'Classify users'' demands. '\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - user_request\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512802873'\n position:\n x: 362.5\n y: 409.5\n positionAbsolute:\n x: 362.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512846439'\n name: 'Improve writing and clarity '\n - id: '1711512847112'\n name: 'Shorten '\n - id: '1711512847641'\n name: 'Lengthen '\n - id: '1711512848120'\n name: 'Simplify '\n - id: '1711512848616'\n name: Rewrite in my voice\n desc: 'Improve writings. 
'\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - how_polish\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512837494'\n position:\n x: 645.5\n y: 409.5\n positionAbsolute:\n x: 645.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Summary\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Summarize the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512911454'\n position:\n x: 645.5\n y: 1327.5\n positionAbsolute:\n x: 645.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Reply\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Reply to the emails for me, in my own voice. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512914870'\n position:\n x: 645.5\n y: 1518.5\n positionAbsolute:\n x: 645.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Turn idea into email\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Turn my idea into an email. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512916516'\n position:\n x: 645.5\n y: 1709.5\n positionAbsolute:\n x: 645.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Improve the clarity. '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \" Improve the clarity of the email for me. \\n{{#1711511281652.Input_Text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512924231'\n position:\n x: 928.5\n y: 409.5\n positionAbsolute:\n x: 928.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Shorten. 
'\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Shorten the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512926020'\n position:\n x: 928.5\n y: 600.5\n positionAbsolute:\n x: 928.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Lengthen '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Lengthen the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512927569'\n position:\n x: 928.5\n y: 791.5\n positionAbsolute:\n x: 928.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Simplify\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Simplify the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512929190'\n position:\n x: 928.5\n y: 982.5\n positionAbsolute:\n x: 928.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Rewrite in my voice\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Rewrite the email for me. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512930700'\n position:\n x: 928.5\n y: 1173.5\n positionAbsolute:\n x: 928.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1711512911454'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513015189'\n position:\n x: 928.5\n y: 1327.5\n positionAbsolute:\n x: 928.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 2\n type: template-transform\n variables:\n - value_selector:\n - '1711512914870'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513017096'\n position:\n x: 928.5\n y: 1518.5\n positionAbsolute:\n x: 928.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 3\n type: template-transform\n variables:\n - value_selector:\n - '1711512916516'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513018759'\n position:\n x: 928.5\n y: 1709.5\n positionAbsolute:\n x: 928.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 4\n type: template-transform\n variables:\n - value_selector:\n - '1711512924231'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513020857'\n position:\n x: 1211.5\n y: 409.5\n positionAbsolute:\n x: 1211.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 5\n type: template-transform\n variables:\n - value_selector:\n - '1711512926020'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513022516'\n position:\n x: 1211.5\n y: 600.5\n positionAbsolute:\n x: 1211.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 6\n type: template-transform\n variables:\n - value_selector:\n - '1711512927569'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513024315'\n position:\n x: 1211.5\n y: 791.5\n positionAbsolute:\n x: 1211.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 7\n type: template-transform\n variables:\n - value_selector:\n - '1711512929190'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513025732'\n position:\n x: 1211.5\n y: 982.5\n positionAbsolute:\n x: 1211.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 8\n type: template-transform\n variables:\n - value_selector:\n - '1711512930700'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513027347'\n position:\n x: 1211.5\n y: 1173.5\n positionAbsolute:\n x: 
1211.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512911454'\n - text\n variable: text\n selected: false\n title: End\n type: end\n dragging: false\n height: 89\n id: '1711513029058'\n position:\n x: 1211.5\n y: 1327.5\n positionAbsolute:\n x: 1211.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512914870'\n - text\n variable: text\n selected: false\n title: End 2\n type: end\n dragging: false\n height: 89\n id: '1711513030924'\n position:\n x: 1211.5\n y: 1518.5\n positionAbsolute:\n x: 1211.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512916516'\n - text\n variable: text\n selected: false\n title: End 3\n type: end\n dragging: false\n height: 89\n id: '1711513032459'\n position:\n x: 1211.5\n y: 1709.5\n positionAbsolute:\n x: 1211.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512924231'\n - text\n variable: text\n selected: false\n title: End 4\n type: end\n dragging: false\n height: 89\n id: '1711513034850'\n position:\n x: 1494.5\n y: 409.5\n positionAbsolute:\n x: 1494.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512926020'\n - text\n variable: text\n selected: false\n title: End 5\n type: end\n dragging: false\n height: 89\n id: '1711513036356'\n position:\n x: 1494.5\n y: 600.5\n positionAbsolute:\n x: 1494.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512927569'\n - text\n variable: text\n selected: false\n title: End 6\n type: end\n dragging: false\n height: 89\n id: '1711513037973'\n position:\n x: 1494.5\n y: 791.5\n positionAbsolute:\n x: 1494.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512929190'\n - text\n variable: text\n selected: false\n title: End 7\n type: end\n dragging: false\n height: 89\n id: '1711513039350'\n position:\n x: 1494.5\n y: 982.5\n positionAbsolute:\n x: 1494.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512930700'\n - text\n variable: text\n selected: false\n title: End 8\n type: end\n dragging: false\n height: 89\n id: '1711513041219'\n position:\n x: 1494.5\n y: 1173.5\n positionAbsolute:\n x: 1494.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Grammar Check\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: 'Please check grammar of my email and comment on the grammar. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711513940609'\n position:\n x: 645.5\n y: 1900.5\n positionAbsolute:\n x: 645.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 9\n type: template-transform\n variables:\n - value_selector:\n - '1711513940609'\n - text\n variable: arg1\n height: 53\n id: '1711513967853'\n position:\n x: 928.5\n y: 1900.5\n positionAbsolute:\n x: 928.5\n y: 1900.5\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711513940609'\n - text\n variable: text\n selected: false\n title: End 9\n type: end\n height: 89\n id: '1711513974643'\n position:\n x: 1211.5\n y: 1900.5\n positionAbsolute:\n x: 1211.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 0\n y: 0\n zoom: 0.7\n", "icon": "🤖", "icon_background": "#FFEAD5", "id": "dd5b6353-ae9b-4bce-be6a-a681a12cf709", diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 3f5e1adca2..35ac42a14c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource): def post(self, resource_id): resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() current_key_count = ( diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index cc9c8b31cb..1b46a3a7d3 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -174,6 +174,7 @@ class AppApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") parser.add_argument("max_active_requests", type=int, location="json") + parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") args = parser.parse_args() app_service = AppService() diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 753a6be20c..df7bd352af 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -20,7 +20,7 @@ from fields.conversation_fields import ( conversation_pagination_fields, conversation_with_summary_pagination_fields, ) -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation @@ -36,8 +36,8 @@ class CompletionConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) @@ -143,8 +143,8 @@ class 
ChatConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) @@ -201,7 +201,11 @@ class ChatConversationApi(Resource): start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - query = query.where(Conversation.created_at >= start_datetime_utc) + match args["sort_by"]: + case "updated_at" | "-updated_at": + query = query.where(Conversation.updated_at >= start_datetime_utc) + case "created_at" | "-created_at" | _: + query = query.where(Conversation.created_at >= start_datetime_utc) if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") @@ -210,7 +214,11 @@ class ChatConversationApi(Resource): end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - query = query.where(Conversation.created_at < end_datetime_utc) + match args["sort_by"]: + case "updated_at" | "-updated_at": + query = query.where(Conversation.updated_at <= end_datetime_utc) + case "created_at" | "-created_at" | _: + query = query.where(Conversation.created_at <= end_datetime_utc) if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index f936642acd..26da1ef26d 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -34,6 +34,7 @@ def parse_app_site_args(): ) parser.add_argument("prompt_public", type=bool, required=False, location="json") parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") + parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") return parser.parse_args() @@ -68,6 +69,7 @@ class AppSite(Resource): "customize_token_strategy", "prompt_public", "show_workflow_steps", + "use_icon_as_answer_icon", ]: value = args.get(attr_name) if value is not None: diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 81826a20d0..4806b02b55 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode @@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", 
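The `match` blocks above keep the start/end date filters consistent with the requested sort order: sorting by `updated_at` now also filters on `updated_at`, while `created_at` remains the fallback for every other `sort_by` value. A minimal, self-contained sketch of that dispatch (requires Python 3.10+ for `match`/`case`; plain column-name strings stand in for the SQLAlchemy attributes):

def pick_filter_column(sort_by: str) -> str:
    # Mirrors the controller logic: updated_at sorts filter on updated_at;
    # created_at sorts and any unrecognized value fall back to created_at.
    match sort_by:
        case "updated_at" | "-updated_at":
            return "updated_at"
        case "created_at" | "-created_at" | _:
            return "created_at"

assert pick_filter_column("-updated_at") == "updated_at"
assert pick_filter_column("anything-else") == "created_at"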
type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + 
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index db2f683589..942271a634 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode from models.workflow import WorkflowRunTriggeredFrom @@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index e3402329c1..017f643781 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -7,7 +7,8 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from libs.helper import email, get_remote_ip, str_len, 
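The `datetime_string` factory becomes a `DatetimeString` callable class throughout these statistics controllers. Flask-RESTful's reqparse only requires that `type=` be a callable returning the validated value or raising `ValueError`, so a class instance works the same as a closure. The real `libs.helper.DatetimeString` is not included in this patch; the following is a plausible sketch under that assumption, not the actual implementation:

from datetime import datetime

class DatetimeString:
    """Reqparse-style validator: accept strings matching fmt, else raise ValueError."""

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, value: str) -> str:
        try:
            datetime.strptime(value, self.fmt)
        except ValueError:
            raise ValueError(f"{value!r} does not match the format {self.fmt!r}")
        return value

# Usage mirroring the controllers above:
# parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")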
timezone +from libs.helper import StrLen, email, get_remote_ip, timezone +from libs.password import valid_password from models.account import AccountStatus, Tenant from services.account_service import AccountService, RegisterService @@ -45,7 +46,8 @@ class ActivateApi(Resource): parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument( "interface_language", type=supported_language, required=True, nullable=False, location="json" ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index d369730594..6ccacc78ee 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.extract_setting import ExtractSetting -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields @@ -122,6 +122,7 @@ class DatasetListApi(Resource): name=args["name"], indexing_technique=args["indexing_technique"], account=current_user, + permission=DatasetPermissionEnum.ONLY_ME, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6bc29a8643..076f3cd44d 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -302,6 +302,8 @@ class DatasetInitApi(Resource): "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -309,6 +311,8 @@ class DatasetInitApi(Resource): raise Forbidden() if args["indexing_technique"] == "high_quality": + if args["embedding_model"] is None or args["embedding_model_provider"] is None: + raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() model_manager.get_default_model_instance( diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index d6a464545e..846aa70e86 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -39,7 +39,7 @@ class 
FileApi(Resource): @login_required @account_initialization_required @marshal_with(file_fields) - @cloud_edition_billing_resource_check(resource="documents") + @cloud_edition_billing_resource_check("documents") def post(self): # get file from request file = request.files["file"] diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index b71078760c..3f1e64a247 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -35,6 +35,7 @@ class InstalledAppsListApi(Resource): "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps + if installed_app.app is not None ] installed_apps.sort( key=lambda app: ( diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 7d3ae677ee..ae759bb752 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -4,7 +4,7 @@ from flask import session from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import str_len +from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -28,7 +28,7 @@ class InitValidateAPI(Resource): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument("password", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=StrLen(30), required=True, location="json") input_password = parser.parse_args()["password"] if input_password != os.environ.get("INIT_PASSWORD"): diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 827695e00f..46b4ef5d87 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask import request from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import email, get_remote_ip, str_len +from libs.helper import StrLen, email, get_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService @@ -40,7 +40,7 @@ class SetupApi(Resource): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("name", type=StrLen(30), required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 7293aeeb34..de30547e93 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -13,7 +13,7 @@ from services.tag_service import TagService def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: + if not name or len(name) < 1 or len(name) > 50: raise ValueError("Name must be between 1 to 50 characters.") return name diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 5a964c84fa..7667b30e34 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -46,9 +46,7 @@ def only_edition_self_hosted(view): return decorated -def cloud_edition_billing_resource_check( - resource: str, error_msg: str = "You have reached the limit of your subscription." 
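`str_len(30)` is replaced the same way by the callable class `StrLen(30)` in activate.py, init_validate.py, and setup.py. Again, the patch does not show `libs.helper` itself, so this is a minimal sketch rather than the actual implementation:

class StrLen:
    """Reqparse-style validator that rejects strings longer than max_length."""

    def __init__(self, max_length: int) -> None:
        self.max_length = max_length

    def __call__(self, value: str) -> str:
        if len(value) > self.max_length:
            raise ValueError(f"Invalid string length. Must be {self.max_length} characters or fewer.")
        return value

# Usage mirroring the controllers above:
# parser.add_argument("name", type=StrLen(30), required=True, location="json")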
-): +def cloud_edition_billing_resource_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -60,22 +58,22 @@ def cloud_edition_billing_resource_check( documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit if resource == "members" and 0 < members.limit <= members.size: - abort(403, error_msg) + abort(403, "The number of members has reached the limit of your subscription.") elif resource == "apps" and 0 < apps.limit <= apps.size: - abort(403, error_msg) + abort(403, "The number of apps has reached the limit of your subscription.") elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: - abort(403, error_msg) + abort(403, "The capacity of the vector space has reached the limit of your subscription.") elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets source = request.args.get("source") if source == "datasets": - abort(403, error_msg) + abort(403, "The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) elif resource == "workspace_custom" and not features.can_replace_logo: - abort(403, error_msg) + abort(403, "The workspace custom feature has reached the limit of your subscription.") elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: - abort(403, error_msg) + abort(403, "The annotation quota has reached the limit of your subscription.") else: return view(*args, **kwargs) @@ -86,10 +84,7 @@ def cloud_edition_billing_resource_check( return interceptor -def cloud_edition_billing_knowledge_limit_check( - resource: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", -): +def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -97,7 +92,10 @@ def cloud_edition_billing_knowledge_limit_check( if features.billing.enabled: if resource == "add_segment": if features.billing.subscription.plan == "sandbox": - abort(403, error_msg) + abort( + 403, + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", + ) else: return view(*args, **kwargs) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 5e10f3b48c..e68f6b4dc4 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") + if document.indexing_status != "completed": + raise NotFound("Document is not completed.") + if not document.enabled: + raise NotFound("Document is disabled.") # check embedding model setting if dataset.indexing_technique == "high_quality": try: @@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource): segments = SegmentService.multi_create_segment(args["segments"], document, dataset) return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: - return {"error": "Segemtns is required"}, 400 + return {"error": "Segments is required"}, 400 def get(self, tenant_id, dataset_id, document_id): """Create single segment.""" diff --git a/api/controllers/service_api/wraps.py 
b/api/controllers/service_api/wraps.py index a596c6f287..b935b23ed6 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -83,9 +83,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio return decorator(view) -def cloud_edition_billing_resource_check( - resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription." -): +def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) @@ -98,13 +96,13 @@ def cloud_edition_billing_resource_check( documents_upload_quota = features.documents_upload_quota if resource == "members" and 0 < members.limit <= members.size: - raise Forbidden(error_msg) + raise Forbidden("The number of members has reached the limit of your subscription.") elif resource == "apps" and 0 < apps.limit <= apps.size: - raise Forbidden(error_msg) + raise Forbidden("The number of apps has reached the limit of your subscription.") elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: - raise Forbidden(error_msg) + raise Forbidden("The capacity of the vector space has reached the limit of your subscription.") elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - raise Forbidden(error_msg) + raise Forbidden("The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) @@ -115,11 +113,7 @@ def cloud_edition_billing_resource_check( return interceptor -def cloud_edition_billing_knowledge_limit_check( - resource: str, - api_token_type: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", -): +def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -128,7 +122,9 @@ def cloud_edition_billing_knowledge_limit_check( if features.billing.enabled: if resource == "add_segment": if features.billing.subscription.plan == "sandbox": - raise Forbidden(error_msg) + raise Forbidden( + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan." 
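In both the console and service-api wrappers, `cloud_edition_billing_resource_check` drops its injectable `error_msg` parameter: the message is now selected per resource inside the decorator. A condensed sketch of that decorator-factory shape, with a hypothetical `_over_limit` helper standing in for Dify's billing/features lookup and `PermissionError` standing in for `abort(403, ...)`/`Forbidden`:

from functools import wraps

_LIMIT_MESSAGES = {
    "members": "The number of members has reached the limit of your subscription.",
    "apps": "The number of apps has reached the limit of your subscription.",
    "documents": "The number of documents has reached the limit of your subscription.",
}

def _over_limit(resource: str) -> bool:
    return False  # placeholder for the per-resource quota comparison

def cloud_edition_billing_resource_check(resource: str):
    def interceptor(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            # Each resource now carries its own hard-coded message.
            if _over_limit(resource):
                raise PermissionError(_LIMIT_MESSAGES.get(resource, "Subscription limit reached."))
            return view(*args, **kwargs)
        return decorated
    return interceptor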
+ ) else: return view(*args, **kwargs) diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 2b4d0e7630..0564b15ea3 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -39,6 +39,7 @@ class AppSiteApi(WebApiResource): "default_language": fields.String, "prompt_public": fields.Boolean, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, } app_fields = { diff --git a/api/core/__init__.py b/api/core/__init__.py index 8c986fc8bd..6eaea7b1c8 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +1 @@ -import core.moderation.base \ No newline at end of file +import core.moderation.base diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index d8290ca608..d09a9956a4 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Optional, Union, cast @@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) + class BaseAgentRunner(AppRunner): - def __init__(self, tenant_id: str, - application_generate_entity: AgentChatAppGenerateEntity, - conversation: Conversation, - app_config: AgentChatAppConfig, - model_config: ModelConfigWithCredentialsEntity, - config: AgentEntity, - queue_manager: AppQueueManager, - message: Message, - user_id: str, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, - variables_pool: Optional[ToolRuntimeVariablePool] = None, - db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None - ) -> None: + def __init__( + self, + tenant_id: str, + application_generate_entity: AgentChatAppGenerateEntity, + conversation: Conversation, + app_config: AgentChatAppConfig, + model_config: ModelConfigWithCredentialsEntity, + config: AgentEntity, + queue_manager: AppQueueManager, + message: Message, + user_id: str, + memory: Optional[TokenBufferMemory] = None, + prompt_messages: Optional[list[PromptMessage]] = None, + variables_pool: Optional[ToolRuntimeVariablePool] = None, + db_variables: Optional[ToolConversationVariables] = None, + model_instance: ModelInstance = None, + ) -> None: """ Agent runner :param tenant_id: tenant id @@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner): self.message = message self.user_id = user_id self.memory = memory - self.history_prompt_messages = self.organize_agent_history( - prompt_messages=prompt_messages or [] - ) + self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.variables_pool = variables_pool self.db_variables_pool = db_variables self.model_instance = model_instance @@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner): retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # get how many agent thoughts have been created - self.agent_thought_count = db.session.query(MessageAgentThought).filter( - MessageAgentThought.message_id == self.message.id, - ).count() + self.agent_thought_count = ( + db.session.query(MessageAgentThought) + .filter( + MessageAgentThought.message_id == self.message.id, + ) + .count() + ) db.session.close() # check if 
model supports stream tool call @@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner): self.query = None self._current_thoughts: list[PromptMessage] = [] - def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ - -> AgentChatAppGenerateEntity: + def _repack_app_generate_entity( + self, app_generate_entity: AgentChatAppGenerateEntity + ) -> AgentChatAppGenerateEntity: """ Repack app generate entity """ if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: - app_generate_entity.app_config.prompt_template.simple_prompt_template = '' + app_generate_entity.app_config.prompt_template.simple_prompt_template = "" return app_generate_entity - + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: """ - convert tool to prompt message tool + convert tool to prompt message tool """ tool_entity = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, - invoke_from=self.application_generate_entity.invoke_from + invoke_from=self.application_generate_entity.invoke_from, ) tool_entity.load_variables(self.variables_pool) @@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner): "type": "object", "properties": {}, "required": [], - } + }, ) parameters = tool_entity.get_all_runtime_parameters() @@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner): if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - message_tool.parameters['properties'][parameter.name] = { + message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - message_tool.parameters['properties'][parameter.name]['enum'] = enum + message_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - message_tool.parameters['required'].append(parameter.name) + message_tool.parameters["required"].append(parameter.name) return message_tool, tool_entity - + def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: """ convert dataset retriever tool to prompt message tool @@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner): "type": "object", "properties": {}, "required": [], - } + }, ) for parameter in tool.get_runtime_parameters(): - parameter_type = 'string' - - prompt_tool.parameters['properties'][parameter.name] = { + parameter_type = "string" + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + + def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: """ Init tools """ @@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner): enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - - prompt_tool.parameters['properties'][parameter.name] = { + + prompt_tool.parameters["properties"][parameter.name] = { 
"type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + prompt_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def create_agent_thought(self, message_id: str, message: str, - tool_name: str, tool_input: str, messages_ids: list[str] - ) -> MessageAgentThought: + + def create_agent_thought( + self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] + ) -> MessageAgentThought: """ Create agent thought """ thought = MessageAgentThought( message_id=message_id, message_chain_id=None, - thought='', + thought="", tool=tool_name, - tool_labels_str='{}', - tool_meta_str='{}', + tool_labels_str="{}", + tool_meta_str="{}", tool_input=tool_input, message=message, message_token=0, message_unit_price=0, message_price_unit=0, - message_files=json.dumps(messages_ids) if messages_ids else '', - answer='', - observation='', + message_files=json.dumps(messages_ids) if messages_ids else "", + answer="", + observation="", answer_token=0, answer_unit_price=0, answer_price_unit=0, tokens=0, total_price=0, position=self.agent_thought_count + 1, - currency='USD', + currency="USD", latency=0, - created_by_role='account', + created_by_role="account", created_by=self.user_id, ) @@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner): return thought - def save_agent_thought(self, - agent_thought: MessageAgentThought, - tool_name: str, - tool_input: Union[str, dict], - thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], - answer: str, - messages_ids: list[str], - llm_usage: LLMUsage = None) -> MessageAgentThought: + def save_agent_thought( + self, + agent_thought: MessageAgentThought, + tool_name: str, + tool_input: Union[str, dict], + thought: str, + observation: Union[str, dict], + tool_invoke_meta: Union[str, dict], + answer: str, + messages_ids: list[str], + llm_usage: LLMUsage = None, + ) -> MessageAgentThought: """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter( - MessageAgentThought.id == agent_thought.id - ).first() + agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() if thought is not None: agent_thought.thought = thought @@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner): observation = json.dumps(observation, ensure_ascii=False) except Exception as e: observation = json.dumps(observation) - + agent_thought.observation = observation if answer is not None: @@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner): if messages_ids is not None and len(messages_ids) > 0: agent_thought.message_files = json.dumps(messages_ids) - + if llm_usage: agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_price_unit = llm_usage.prompt_price_unit @@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner): # check if tool labels is not empty labels = agent_thought.tool_labels or {} - tools = agent_thought.tool.split(';') if agent_thought.tool else [] + tools = agent_thought.tool.split(";") if agent_thought.tool else [] for tool in tools: if not tool: continue @@ -386,7 +393,7 @@ class 
BaseAgentRunner(AppRunner): if tool_label: labels[tool] = tool_label.to_dict() else: - labels[tool] = {'en_US': tool, 'zh_Hans': tool} + labels[tool] = {"en_US": tool, "zh_Hans": tool} agent_thought.tool_labels_str = json.dumps(labels) @@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner): db.session.commit() db.session.close() - + def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ convert tool variables to db variables """ - db_variables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == self.message.conversation_id, - ).first() + db_variables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ) + .first() + ) db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) @@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner): if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = db.session.query(Message).filter( - Message.conversation_id == self.message.conversation_id, - ).order_by(Message.created_at.asc()).all() + messages: list[Message] = ( + db.session.query(Message) + .filter( + Message.conversation_id == self.message.conversation_id, + ) + .order_by(Message.created_at.asc()) + .all() + ) for message in messages: if message.id == self.message.id: @@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner): for agent_thought in agent_thoughts: tools = agent_thought.tool if tools: - tools = tools.split(';') + tools = tools.split(";") tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_call_response: list[ToolPromptMessage] = [] try: tool_inputs = json.loads(agent_thought.tool_input) except Exception as e: - tool_inputs = { tool: {} for tool in tools } + tool_inputs = {tool: {} for tool in tools} try: tool_responses = json.loads(agent_thought.observation) except Exception as e: @@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner): for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) - tool_calls.append(AssistantPromptMessage.ToolCall( - id=tool_call_id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool, - arguments=json.dumps(tool_inputs.get(tool, {})), + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ), ) - )) - tool_call_response.append(ToolPromptMessage( - content=tool_responses.get(tool, agent_thought.observation), - name=tool, - tool_call_id=tool_call_id, - )) + ) + tool_call_response.append( + ToolPromptMessage( + content=tool_responses.get(tool, agent_thought.observation), + name=tool, + tool_call_id=tool_call_id, + ) + ) - result.extend([ - AssistantPromptMessage( - content=agent_thought.thought, - tool_calls=tool_calls, - ), - *tool_call_response - ]) + result.extend( + [ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response, + ] + ) if not tools: result.append(AssistantPromptMessage(content=agent_thought.thought)) else: @@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner): file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if file_extra_config: - file_objs = 
message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 89c948d2e2..29b428a7c3 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,17 +25,19 @@ from models.model import Message class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True - _ignore_observation_providers = ['wenxin'] + _ignore_observation_providers = ["wenxin"] _historic_prompt_messages: list[PromptMessage] = None _agent_scratchpad: list[AgentScratchpadUnit] = None _instruction: str = None _query: str = None _prompt_messages_tools: list[PromptMessage] = None - def run(self, message: Message, - query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + def run( + self, + message: Message, + query: str, + inputs: dict[str, str], + ) -> Union[Generator, LLMResult]: """ Run Cot agent application """ @@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC): trace_manager = app_generate_entity.trace_manager # check model mode - if 'Observation' not in app_generate_entity.model_conf.stop: + if "Observation" not in app_generate_entity.model_conf.stop: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append('Observation') + app_generate_entity.model_conf.stop.append("Observation") app_config = self.app_config # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools( - instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 @@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # recalc llm max tokens prompt_messages = self._organize_prompt_messages() @@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output( - 
chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( - agent_response='', - thought='', - action_str='', - observation='', + agent_response="", + thought="", + action_str="", + observation="", action=None, ) # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) for chunk in react_chunks: if isinstance(chunk, AgentScratchpadUnit.Action): @@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC): yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=chunk - ), - usage=None - ) + system_fingerprint="", + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - scratchpad.thought = scratchpad.thought.strip( - ) or 'I am thinking about how to help you' + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) # get llm usage - if 'usage' in usage_dict: - increase_usage(llm_usage, usage_dict['usage']) + if "usage" in usage_dict: + increase_usage(llm_usage, usage_dict["usage"]) else: - usage_dict['usage'] = LLMUsage.empty_usage() + usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else {}, + tool_name=scratchpad.action.action_name if scratchpad.action else "", + tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, thought=scratchpad.thought, - observation='', + observation="", answer=scratchpad.agent_response, messages_ids=[], - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) if not scratchpad.is_final(): - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) if not scratchpad.action: # failed to extract action, return final answer directly - final_answer = '' + final_answer = "" else: if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps( - scratchpad.action.action_input) + final_answer = json.dumps(scratchpad.action.action_input) elif isinstance(scratchpad.action.action_input, str): final_answer = scratchpad.action.action_input else: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" except json.JSONDecodeError: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True # action is tool call, invoke tool @@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, - tool_input={ - 
scratchpad.action.action_name: scratchpad.action.action_input}, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought, - observation={ - scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={ - scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, messages_ids=message_file_ids, - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # update prompt tool message for prompt_tool in self._prompt_messages_tools: @@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC): model=model_instance.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=final_answer - ), - usage=llm_usage['usage'] + index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] ), - system_fingerprint='' + system_fingerprint="", ) # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name='', + tool_name="", tool_input={}, tool_invoke_meta={}, thought=final_answer, observation={}, answer=final_answer, - messages_ids=[] + messages_ids=[], ) self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) - def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, - tool_instances: dict[str, Tool], - message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None - ) -> tuple[str, ToolInvokeMeta]: + def _handle_invoke_action( + self, + action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str], + trace_manager: Optional[TraceQueueManager] = None, + ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action :param action: action @@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file( - tool_name=tool_call_name, value=message_file_id, name=save_as) + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) @@ -342,10 +323,7 @@ class 
CotAgentRunner(BaseAgentRunner, ABC): """ convert dict to action """ - return AgentScratchpadUnit.Action( - action_name=action['action'], - action_input=action['action_input'] - ) + return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ for key, value in inputs.items(): try: - instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) except Exception as e: continue @@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): @abstractmethod def _organize_prompt_messages(self) -> list[PromptMessage]: """ - organize prompt messages + organize prompt messages """ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ - format assistant message + format assistant message """ - message = '' + message = "" for scratchpad in agent_scratchpad: if scratchpad.is_final(): message += f"Final Answer: {scratchpad.agent_response}" @@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC): return message - def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_historic_prompt_messages( + self, current_session_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ - organize historic prompt messages + organize historic prompt messages """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] @@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not current_scratchpad: current_scratchpad = AgentScratchpadUnit( agent_response=message.content, - thought=message.content or 'I am thinking about how to help you', - action_str='', + thought=message.content or "I am thinking about how to help you", + action_str="", action=None, observation=None, ) @@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC): try: current_scratchpad.action = AgentScratchpadUnit.Action( action_name=message.tool_calls[0].function.name, - action_input=json.loads( - message.tool_calls[0].function.arguments) - ) - current_scratchpad.action_str = json.dumps( - current_scratchpad.action.to_dict() + action_input=json.loads(message.tool_calls[0].function.arguments), ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) except: pass elif isinstance(message, ToolPromptMessage): @@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): current_scratchpad.observation = message.content elif isinstance(message, UserPromptMessage): if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) scratchpads = [] current_scratchpad = None result.append(message) if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) historic_prompts = AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=current_session_messages or [], history_messages=result, - memory=self.memory + memory=self.memory, ).get_prompt() return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 
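The chained `str.replace` calls in `_organize_system_prompt` below inject the agent instruction, the JSON-encoded tool schemas, and the comma-separated tool names into the first-iteration ReAct prompt. A runnable sketch of the same substitution with placeholder data:

import json

def build_system_prompt(first_prompt: str, instruction: str, tools: list[dict]) -> str:
    # Same three substitutions as CotChatAgentRunner._organize_system_prompt.
    return (
        first_prompt.replace("{{instruction}}", instruction)
        .replace("{{tools}}", json.dumps(tools))
        .replace("{{tool_names}}", ", ".join(tool["name"] for tool in tools))
    )

print(build_system_prompt("Tools: {{tool_names}}. {{instruction}}", "Be brief.", [{"name": "search"}]))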
index 8debbe5c5d..bdec6b7ed1 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return SystemPromptMessage(content=system_prompt)

-    def _organize_user_query(self, query,  prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):

     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-            Organize
+        Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):

         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
             # organize historic prompt messages
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index 9e6eb54f4f..9dab956f9a 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
-
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )
+
         return system_prompt

     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):

         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"

         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
             .replace("{{query}}", query_prompt)
+        )

-        return [UserPromptMessage(content=prompt)]
\ No newline at end of file
+        return [UserPromptMessage(content=prompt)]
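Both runners above assemble their system prompt by chaining `str.replace` over a template; the reformat only swaps backslash continuations for a parenthesized chain. A self-contained sketch of the pattern (the template text here is made up, not the real ReAct prompt):

```python
import json

# Toy stand-in for the real template; only the placeholder names match the diff.
FIRST_PROMPT = "Instruction: {{instruction}}\nTools: {{tools}}\nNames: {{tool_names}}"
tools = [{"name": "weather"}, {"name": "search"}]

system_prompt = (
    FIRST_PROMPT.replace("{{instruction}}", "Answer politely.")
    .replace("{{tools}}", json.dumps(tools))
    .replace("{{tool_names}}", ", ".join(t["name"] for t in tools))
)
print(system_prompt)
```

diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py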
""" - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' + + CHAIN_OF_THOUGHT = "chain-of-thought" + FUNCTION_CALLING = "function-calling" provider: str model: str diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3ee6e47742..27cf561e3d 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -24,11 +24,9 @@ from models.model import Message logger = logging.getLogger(__name__) -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, - message: Message, query: str, **kwargs: Any - ) -> Generator[LLMResultChunk, None, None]: +class FunctionCallAgentRunner(BaseAgentRunner): + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application """ @@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner): # continue to run until there is not any tool call function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" # get tracing instance trace_manager = app_generate_entity.trace_manager - + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) # recalc llm max tokens @@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response - response = '' + response = "" # save tool call names and inputs - tool_call_names = '' - tool_call_inputs = '' + tool_call_names = "" + tool_call_inputs = "" current_llm_usage = None @@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True tool_calls.extend(self.extract_tool_calls(chunk)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and 
index 3ee6e47742..27cf561e3d 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -24,11 +24,9 @@ from models.model import Message

 logger = logging.getLogger(__name__)

-class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self,
-            message: Message, query: str, **kwargs: Any
-    ) -> Generator[LLMResultChunk, None, None]:
+
+class FunctionCallAgentRunner(BaseAgentRunner):
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):

         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         # get tracing instance
         trace_manager = app_generate_entity.trace_manager
-
+
         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             # recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             tool_calls: list[tuple[str, str, dict[str, Any]]] = []

             # save full response
-            response = ''
+            response = ""

             # save tool call names and inputs
-            tool_call_names = ''
-            tool_call_inputs = ''
+            tool_call_names = ""
+            tool_call_inputs = ""

             current_llm_usage = None

@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
                         is_first_chunk = False
                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     if self.check_blocking_tool_calls(result):
                         function_call_state = True
                         tool_calls.extend(self.extract_blocking_tool_calls(result))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                     if result.usage:
                         increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         response += result.message.content

                     if not result.message.content:
-                        result.message.content = ''
+                        result.message.content = ""
+
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )

-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
-
                     yield LLMResultChunk(
                         model=model_instance.model,
                         prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             index=0,
                             message=result.message,
                             usage=result.usage,
-                        )
+                        ),
                     )

-            assistant_message = AssistantPromptMessage(
-                content='',
-                tool_calls=[]
-            )
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
-                assistant_message.tool_calls=[
+                assistant_message.tool_calls = [
                     AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
-                        type='function',
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=tool_call[1],
-                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
-                        )
-                    ) for tool_call in tool_calls
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
                 ]
             else:
                 assistant_message.content = response
-
+
             self._current_thoughts.append(assistant_message)

             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought=agent_thought,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 observation=None,
                 answer=response,
                 messages_ids=[],
-                llm_usage=current_llm_usage
+                llm_usage=current_llm_usage,
             )
-            self.queue_manager.publish(QueueAgentThoughtEvent(
-                agent_thought_id=agent_thought.id
-            ), PublishFrom.APPLICATION_MANAGER)
-
-            final_answer += response + '\n'
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )
+
+            final_answer += response + "\n"

             # call tools
             tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                     }
                 else:
                     # invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                         # publish message file
-                        self.queue_manager.publish(QueueMessageFileEvent(
-                            message_file_id=message_file_id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
                         # add message file ids
                         message_file_ids.append(message_file_id)
-
+
                     tool_response = {
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict()
+                        "meta": tool_invoke_meta.to_dict(),
                     }
-
+
                 tool_responses.append(tool_response)
-                if tool_response['tool_response'] is not None:
+                if tool_response["tool_response"] is not None:
                     self._current_thoughts.append(
                         ToolPromptMessage(
-                            content=tool_response['tool_response'],
+                            content=tool_response["tool_response"],
                             tool_call_id=tool_call_id,
                             name=tool_call_name,
                         )
-                    )
+                    )

             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought=agent_thought,
                     tool_name=None,
                     tool_input=None,
-                    thought=None,
+                    thought=None,
                     tool_invoke_meta={
-                        tool_response['tool_call_name']: tool_response['meta']
-                        for tool_response in tool_responses
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                     },
                     observation={
-                        tool_response['tool_call_name']: tool_response['tool_response']
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
                         for tool_response in tool_responses
                     },
                     answer=None,
-                    messages_ids=message_file_ids
+                    messages_ids=message_file_ids,
+                )
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                 )
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)

             # update prompt tool
             for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):

         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+            PublishFrom.APPLICATION_MANAGER,
+        )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
         """
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False
-
+
     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
         """
         Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return True
         return False

-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk
@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls
-
+
     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

-    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
         Initialize system message
         """
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return [
                 SystemPromptMessage(content=prompt_template),
             ]
-
+
         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

         return prompt_messages

-    def _organize_user_query(self, query,  prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             prompt_messages.append(UserPromptMessage(content=query))

         return prompt_messages
-
+
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, UserPromptMessage):
                 if isinstance(prompt_message.content, list):
-                    prompt_message.content = '\n'.join([
-                        content.data if content.type == PromptMessageContentType.TEXT else
-                        '[image]' if content.type == PromptMessageContentType.IMAGE else
-                        '[file]'
-                        for content in prompt_message.content
-                    ])
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )

         return prompt_messages

     def _organize_prompt_messages(self):
-        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
         query_prompt_messages = self._organize_user_query(self.query, [])

@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             model_config=self.model_config,
             prompt_messages=[*query_prompt_messages, *self._current_thoughts],
             history_messages=self.history_prompt_messages,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()

-        prompt_messages = [
-            *self.history_prompt_messages,
-            *query_prompt_messages,
-            *self._current_thoughts
-        ]
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
         if len(self._current_thoughts) != 0:
             # clear messages after the first iteration
             prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
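The `increase_usage` helper reformatted above either seeds the shared dict or accumulates token counts in place; the dict indirection is what lets a nested function mutate the caller's running total. A minimal sketch with a stand-in usage type (the real `LLMUsage` tracks prices and more fields):

```python
from dataclasses import dataclass

@dataclass
class Usage:  # stand-in for LLMUsage
    prompt_tokens: int = 0
    completion_tokens: int = 0

def increase_usage(total: dict, usage: Usage) -> None:
    if not total["usage"]:
        total["usage"] = usage  # first chunk: adopt the usage object
    else:
        total["usage"].prompt_tokens += usage.prompt_tokens
        total["usage"].completion_tokens += usage.completion_tokens

llm_usage = {"usage": None}
increase_usage(llm_usage, Usage(10, 5))
increase_usage(llm_usage, Usage(7, 3))
assert (llm_usage["usage"].prompt_tokens, llm_usage["usage"].completion_tokens) == (17, 8)
```

diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py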
index c53fa5000e..1a161677dd 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk

 class CotAgentOutputParser:
     @classmethod
-    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
-        Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(json_str):
             try:
                 action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
                     action = action[0]

                 for key, value in action.items():
-                    if 'input' in key.lower():
+                    if "input" in key.lower():
                         action_input = value
                     else:
                         action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
                         action_input=action_input,
                     )
                 else:
-                    return json_str or ''
+                    return json_str or ""
             except:
-                return json_str or ''
-
+                return json_str or ""
+
         def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return
             for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                 yield parse_action(json_text)
-
-        code_block_cache = ''
+
+        code_block_cache = ""
         code_block_delimiter_count = 0
         in_code_block = False
-        json_cache = ''
+        json_cache = ""
         json_quote_count = 0
         in_json = False
         got_json = False

-        action_cache = ''
-        action_str = 'action:'
+        action_cache = ""
+        action_str = "action:"
         action_idx = 0

-        thought_cache = ''
-        thought_str = 'thought:'
+        thought_cache = ""
+        thought_str = "thought:"
         thought_idx = 0

         for response in llm_response:
             if response.delta.usage:
-                usage_dict['usage'] = response.delta.usage
+                usage_dict["usage"] = response.delta.usage
             response = response.delta.message.content
             if not isinstance(response, str):
                 continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
             index = 0
             while index < len(response):
                 steps = 1
-                delta = response[index:index+steps]
-                last_character = response[index-1] if index > 0 else ''
+                delta = response[index : index + steps]
+                last_character = response[index - 1] if index > 0 else ""

-                if delta == '`':
+                if delta == "`":
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
                             yield code_block_cache
-                            code_block_cache = ''
+                            code_block_cache = ""
                     else:
                         code_block_cache += delta
                     code_block_delimiter_count = 0

                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
                     else:
                         if action_cache:
                             yield action_cache
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
-
+
                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
                     else:
                         if thought_cache:
                             yield thought_cache
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0

                 if code_block_delimiter_count == 3:
                     if in_code_block:
                         yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
-
+                        code_block_cache = ""
+
                     in_code_block = not in_code_block
                     code_block_delimiter_count = 0

                 if not in_code_block:
                     # handle single json
-                    if delta == '{':
+                    if delta == "{":
                         json_quote_count += 1
                         in_json = True
                         json_cache += delta
-                    elif delta == '}':
+                    elif delta == "}":
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
                 if got_json:
                     got_json = False
                     yield parse_action(json_cache)
-                    json_cache = ''
+                    json_cache = ""
                     json_quote_count = 0
                     in_json = False
-
+
                 if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
+                    yield delta.replace("`", "")

                 index += steps

@@ -186,4 +187,3 @@ class CotAgentOutputParser:

         if json_cache:
             yield parse_action(json_cache)
-
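`extra_json_from_code_block` above pulls fenced blocks out of the model's ReAct output, strips an optional language tag, and parses the JSON action. The same two regexes in isolation (the sample text is illustrative):

```python
import json
import re

text = 'Thought: use a tool\n```json\n{"action": "weather", "action_input": {"city": "Kyoto"}}\n```'

for block in re.findall(r"```(.*?)```", text, re.DOTALL):
    # Drop a leading language tag such as "json" before parsing.
    json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
    action = json.loads(json_text)
    print(action["action"], action["action_input"])
```

diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py
index b0cf1a77fb..cb98f5501d 100644
--- a/api/core/agent/prompt/template.py
+++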
b/api/core/agent/prompt/template.py
@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

 REACT_PROMPT_TEMPLATES = {
-    'english': {
-        'chat': {
-            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
+        },
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
         },
-        'completion': {
-            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
-        }
     }
-}
\ No newline at end of file
+}
diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py
index 3dea305e98..0fd2a779a4 100644
--- a/api/core/app/app_config/base_app_config_manager.py
+++ b/api/core/app/app_config/base_app_config_manager.py
@@ -26,34 +26,24 @@ class BaseAppConfigManager:
         config_dict = dict(config_dict.items())

         additional_features = AppAdditionalFeatures()
-        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)

         additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict,
-            is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
         )

-        additional_features.opening_statement, additional_features.suggested_questions = \
-            OpeningStatementConfigManager.convert(
-                config=config_dict
-            )
+        additional_features.opening_statement, additional_features.suggested_questions = (
+            OpeningStatementConfigManager.convert(config=config_dict)
+        )

         additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
             config=config_dict
         )

-        additional_features.more_like_this = MoreLikeThisConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)

-        additional_features.speech_to_text = SpeechToTextConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)

-        additional_features.text_to_speech = TextToSpeechConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)

         return additional_features
diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
index 1ca8b1e3b8..037037e6ca 100644
--- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
+++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
@@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
 class SensitiveWordAvoidanceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
-        sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
+        sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
         if not sensitive_word_avoidance_dict:
             return None

-        if sensitive_word_avoidance_dict.get('enabled'):
+        if sensitive_word_avoidance_dict.get("enabled"):
             return SensitiveWordAvoidanceEntity(
-                type=sensitive_word_avoidance_dict.get('type'),
-                config=sensitive_word_avoidance_dict.get('config'),
+                type=sensitive_word_avoidance_dict.get("type"),
+                config=sensitive_word_avoidance_dict.get("config"),
             )
         else:
             return None

     @classmethod
-    def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-            -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(
+        cls, tenant_id, config: dict, only_structure_validate: bool = False
+    ) -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
-            config["sensitive_word_avoidance"] = {
-                "enabled": False
-            }
+            config["sensitive_word_avoidance"] = {"enabled": False}

         if not isinstance(config["sensitive_word_avoidance"], dict):
             raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
             typ = config["sensitive_word_avoidance"]["type"]
             sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]

-            ModerationFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=sensitive_word_avoidance_config
-            )
+            ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

         return config, ["sensitive_word_avoidance"]
diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
index dc65d4439b..6e89f19508 100644
--- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
@@ -12,67 +12,70 @@ class AgentConfigManager:

         :param config: model config args
         """
-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode']:
+        if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
+            agent_dict = config.get("agent_mode", {})
+            agent_strategy = agent_dict.get("strategy", "cot")

-            agent_dict = config.get('agent_mode', {})
-            agent_strategy = agent_dict.get('strategy', 'cot')
-
-            if agent_strategy == 'function_call':
+            if agent_strategy == "function_call":
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == 'cot' or agent_strategy == 'react':
+            elif agent_strategy == "cot" or agent_strategy == "react":
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
             else:
                 # old configs, try to detect default strategy
-                if config['model']['provider'] == 'openai':
+                if config["model"]["provider"] == "openai":
                     strategy = AgentEntity.Strategy.FUNCTION_CALLING
                 else:
                     strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

             agent_tools = []
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) >= 4:
                     if "enabled" not in tool or not tool["enabled"]:
                         continue

                     agent_tool_properties = {
-                        'provider_type': tool['provider_type'],
-                        'provider_id': tool['provider_id'],
-                        'tool_name': tool['tool_name'],
-                        'tool_parameters': tool.get('tool_parameters', {})
+                        "provider_type": tool["provider_type"],
+                        "provider_id": tool["provider_id"],
+                        "tool_name": tool["tool_name"],
+                        "tool_parameters": tool.get("tool_parameters", {}),
                     }

                     agent_tools.append(AgentToolEntity(**agent_tool_properties))

-            if 'strategy' in config['agent_mode'] and \
-                    config['agent_mode']['strategy'] not in ['react_router', 'router']:
-                agent_prompt = agent_dict.get('prompt', None) or {}
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+                "react_router",
+                "router",
+            ]:
+                agent_prompt = agent_dict.get("prompt", None) or {}
                 # check model mode
-                model_mode = config.get('model', {}).get('mode', 'completion')
-                if model_mode == 'completion':
+                model_mode = config.get("model", {}).get("mode", "completion")
+                if model_mode == "completion":
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['completion'][
-                                                            'agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
+                        ),
                     )
                 else:
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
+                        ),
                     )

                 return AgentEntity(
-                    provider=config['model']['provider'],
-                    model=config['model']['name'],
+                    provider=config["model"]["provider"],
+                    model=config["model"]["name"],
                     strategy=strategy,
                     prompt=agent_prompt_entity,
                     tools=agent_tools,
-                    max_iteration=agent_dict.get('max_iteration', 5)
+                    max_iteration=agent_dict.get("max_iteration", 5),
                 )

         return None
index a8eb1f9f76..ff131b62e2 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -15,39 +15,38 @@ class DatasetConfigManager:
         :param config: model config args
         """
         dataset_ids = []
-        if 'datasets' in config.get('dataset_configs', {}):
-            datasets = config.get('dataset_configs', {}).get('datasets', {
-                'strategy': 'router',
-                'datasets': []
-            })
+        if "datasets" in config.get("dataset_configs", {}):
+            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})

-            for dataset in datasets.get('datasets', []):
+            for dataset in datasets.get("datasets", []):
                 keys = list(dataset.keys())
-                if len(keys) == 0 or keys[0] != 'dataset':
+                if len(keys) == 0 or keys[0] != "dataset":
                     continue

-                dataset = dataset['dataset']
+                dataset = dataset["dataset"]

-                if 'enabled' not in dataset or not dataset['enabled']:
+                if "enabled" not in dataset or not dataset["enabled"]:
                     continue

-                dataset_id = dataset.get('id', None)
+                dataset_id = dataset.get("id", None)
                 if dataset_id:
                     dataset_ids.append(dataset_id)

-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode'] \
-                and config['agent_mode']['enabled']:
+        if (
+            "agent_mode" in config
+            and config["agent_mode"]
+            and "enabled" in config["agent_mode"]
+            and config["agent_mode"]["enabled"]
+        ):
+            agent_dict = config.get("agent_mode", {})

-            agent_dict = config.get('agent_mode', {})
-
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) == 1:
                     # old standard
                     key = list(tool.keys())[0]

-                    if key != 'dataset':
+                    if key != "dataset":
                         continue

                     tool_item = tool[key]
@@ -55,30 +54,28 @@ class DatasetConfigManager:
                     if "enabled" not in tool_item or not tool_item["enabled"]:
                         continue

-                    dataset_id = tool_item['id']
+                    dataset_id = tool_item["id"]
                     dataset_ids.append(dataset_id)

         if len(dataset_ids) == 0:
             return None

         # dataset configs
-        if 'dataset_configs' in config and config.get('dataset_configs'):
-            dataset_configs = config.get('dataset_configs')
+        if "dataset_configs" in config and config.get("dataset_configs"):
+            dataset_configs = config.get("dataset_configs")
         else:
-            dataset_configs = {
-                'retrieval_model': 'multiple'
-            }
-        query_variable = config.get('dataset_query_variable')
+            dataset_configs = {"retrieval_model": "multiple"}
+        query_variable = config.get("dataset_query_variable")

-        if dataset_configs['retrieval_model'] == 'single':
+        if dataset_configs["retrieval_model"] == "single":
             return DatasetEntity(
                 dataset_ids=dataset_ids,
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
-                    )
-                )
+                        dataset_configs["retrieval_model"]
+                    ),
+                ),
             )
         else:
             return DatasetEntity(
@@ -86,15 +83,15 @@ class DatasetConfigManager:
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
+                        dataset_configs["retrieval_model"]
                     ),
-                    top_k=dataset_configs.get('top_k', 4),
-                    score_threshold=dataset_configs.get('score_threshold'),
-                    reranking_model=dataset_configs.get('reranking_model'),
-                    weights=dataset_configs.get('weights'),
-                    reranking_enabled=dataset_configs.get('reranking_enabled', True),
-                    rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
-                )
+                    top_k=dataset_configs.get("top_k", 4),
+                    score_threshold=dataset_configs.get("score_threshold"),
+                    reranking_model=dataset_configs.get("reranking_model"),
+                    weights=dataset_configs.get("weights"),
+                    reranking_enabled=dataset_configs.get("reranking_enabled", True),
+                    rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                ),
             )

     @classmethod
@@ -111,13 +108,10 @@ class DatasetConfigManager:

         # dataset_configs
         if not config.get("dataset_configs"):
-            config["dataset_configs"] = {'retrieval_model': 'single'}
+            config["dataset_configs"] = {"retrieval_model": "single"}

         if not config["dataset_configs"].get("datasets"):
-            config["dataset_configs"]["datasets"] = {
-                "strategy": "router",
-                "datasets": []
-            }
+            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")
@@ -125,8 +119,9 @@ class DatasetConfigManager:
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")

-        need_manual_query_datasets = (config.get("dataset_configs")
-                                      and config["dataset_configs"].get("datasets", {}).get("datasets"))
+        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
+            "datasets", {}
+        ).get("datasets")

         if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
             # Only check when mode is completion
@@ -148,10 +143,7 @@ class DatasetConfigManager:
         """
         # Extract dataset config for legacy compatibility
         if not config.get("agent_mode"):
-            config["agent_mode"] = {
-                "enabled": False,
-                "tools": []
-            }
+            config["agent_mode"] = {"enabled": False, "tools": []}

         if not isinstance(config["agent_mode"], dict):
             raise ValueError("agent_mode must be of object type")
@@ -188,7 +180,7 @@ class DatasetConfigManager:
             if not isinstance(tool_item["enabled"], bool):
                 raise ValueError("enabled in agent_mode.tools must be of boolean type")

-            if 'id' not in tool_item:
+            if "id" not in tool_item:
                 raise ValueError("id is required in dataset")

             try:
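Two extraction paths above feed `dataset_ids`: the `dataset_configs.datasets` list, and (for legacy configs) single-key `dataset` tools under `agent_mode`. A condensed sketch of the first path (the config literal is a made-up example):

```python
config = {
    "dataset_configs": {
        "datasets": {
            "strategy": "router",
            "datasets": [
                {"dataset": {"id": "ds-1", "enabled": True}},
                {"dataset": {"id": "ds-2", "enabled": False}},
            ],
        }
    }
}

dataset_ids = []
for dataset in config.get("dataset_configs", {}).get("datasets", {}).get("datasets", []):
    keys = list(dataset.keys())
    if len(keys) == 0 or keys[0] != "dataset":
        continue  # entries must be wrapped in a single "dataset" key
    item = dataset["dataset"]
    if item.get("enabled"):
        dataset_ids.append(item.get("id"))

assert dataset_ids == ["ds-1"]
```

diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py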
index 5c9b2cfec7..a91b9f0f02 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager

 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig,
-                skip_check: bool = False) \
-            -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -25,9 +23,7 @@ class ModelConfigConverter:

         provider_manager = ProviderManager()
         provider_model_bundle = provider_manager.get_provider_model_bundle(
-            tenant_id=app_config.tenant_id,
-            provider=model_config.provider,
-            model_type=ModelType.LLM
+            tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
         )

         provider_name = provider_model_bundle.configuration.provider.provider
@@ -38,8 +34,7 @@ class ModelConfigConverter:

         # check model credentials
         model_credentials = provider_model_bundle.configuration.get_current_credentials(
-            model_type=ModelType.LLM,
-            model=model_config.model
+            model_type=ModelType.LLM, model=model_config.model
         )

         if model_credentials is None:
@@ -51,8 +46,7 @@ class ModelConfigConverter:
         if not skip_check:
             # check model
             provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model,
-                model_type=ModelType.LLM
+                model=model_config.model, model_type=ModelType.LLM
             )

             if provider_model is None:
@@ -69,24 +63,18 @@ class ModelConfigConverter:
         # model config
         completion_params = model_config.parameters
         stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]

         # get model mode
         model_mode = model_config.mode
         if not model_mode:
-            mode_enum = model_type_instance.get_model_mode(
-                model=model_config.model,
-                credentials=model_credentials
-            )
+            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
             model_mode = mode_enum.value

-        model_schema = model_type_instance.get_model_schema(
-            model_config.model,
-            model_credentials
-        )
+        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

         if not skip_check and not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
index 730a9527cf..b5e4554181 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
@@ -13,23 +13,23 @@ class ModelConfigManager:
         :param config: model config args
         """
         # model config
-        model_config = config.get('model')
+        model_config = config.get("model")

         if not model_config:
             raise ValueError("model is required")

-        completion_params = model_config.get('completion_params')
+        completion_params = model_config.get("completion_params")
         stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]

         # get model mode
-        model_mode = model_config.get('mode')
+        model_mode = model_config.get("mode")

         return ModelConfigEntity(
-            provider=config['model']['provider'],
-            model=config['model']['name'],
+            provider=config["model"]["provider"],
+            model=config["model"]["name"],
             mode=model_mode,
             parameters=completion_params,
             stop=stop,
@@ -43,7 +43,7 @@ class ModelConfigManager:
         :param tenant_id: tenant id
         :param config: app model config args
         """
-        if 'model' not in config:
+        if "model" not in config:
             raise ValueError("model is required")

         if not isinstance(config["model"], dict):
@@ -52,17 +52,16 @@ class ModelConfigManager:
         # model.provider
         provider_entities = model_provider_factory.get_providers()
         model_provider_names = [provider.provider for provider in provider_entities]
-        if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
+        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

         # model.name
-        if 'name' not in config["model"]:
+        if "name" not in config["model"]:
             raise ValueError("model.name is required")

         provider_manager = ProviderManager()
         models = provider_manager.get_configurations(tenant_id).get_models(
-            provider=config["model"]["provider"],
-            model_type=ModelType.LLM
+            provider=config["model"]["provider"], model_type=ModelType.LLM
         )

         if not models:
@@ -80,12 +79,12 @@ class ModelConfigManager:

         # model.mode
         if model_mode:
-            config['model']["mode"] = model_mode
+            config["model"]["mode"] = model_mode
         else:
-            config['model']["mode"] = "completion"
+            config["model"]["mode"] = "completion"

         # model.completion_params
-        if 'completion_params' not in config["model"]:
+        if "completion_params" not in config["model"]:
             raise ValueError("model.completion_params is required")

         config["model"]["completion_params"] = cls.validate_model_completion_params(
@@ -101,7 +100,7 @@ class ModelConfigManager:
             raise ValueError("model.completion_params must be of object type")

         # stop
-        if 'stop' not in cp:
+        if "stop" not in cp:
             cp["stop"] = []
         elif not isinstance(cp["stop"], list):
             raise ValueError("stop in model.completion_params must be of list type")
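Both the converter and the manager above pop `stop` out of `completion_params` before handing the remaining parameters to the model, since stop sequences travel as a separate argument. The pattern in isolation (sample values are illustrative):

```python
completion_params = {"temperature": 0.7, "stop": ["\nObservation:"]}

stop = []
if "stop" in completion_params:
    stop = completion_params["stop"]
    del completion_params["stop"]  # the model call receives params without "stop"

assert completion_params == {"temperature": 0.7}
assert stop == ["\nObservation:"]
```

diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 1f410758aa..de91c9a065 100644
--- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
         if not config.get("prompt_type"):
             raise ValueError("prompt_type is required")

-        prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
+        prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
         if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
             simple_prompt_template = config.get("pre_prompt", "")
-            return PromptTemplateEntity(
-                prompt_type=prompt_type,
-                simple_prompt_template=simple_prompt_template
-            )
+            return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
         else:
             advanced_chat_prompt_template = None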
             chat_prompt_config = config.get("chat_prompt_config", {})
             if chat_prompt_config:
                 chat_prompt_messages = []
                 for message in chat_prompt_config.get("prompt", []):
-                    chat_prompt_messages.append({
-                        "text": message["text"],
-                        "role": PromptMessageRole.value_of(message["role"])
-                    })
+                    chat_prompt_messages.append(
+                        {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )

-                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
-                    messages=chat_prompt_messages
-                )
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

             advanced_completion_prompt_template = None
             completion_prompt_config = config.get("completion_prompt_config", {})
             if completion_prompt_config:
                 completion_prompt_template_params = {
-                    'prompt': completion_prompt_config['prompt']['text'],
+                    "prompt": completion_prompt_config["prompt"]["text"],
                 }

-                if 'conversation_histories_role' in completion_prompt_config:
-                    completion_prompt_template_params['role_prefix'] = {
-                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
-                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
+                if "conversation_histories_role" in completion_prompt_config:
+                    completion_prompt_template_params["role_prefix"] = {
+                        "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
+                        "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
                     }

                 advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
             return PromptTemplateEntity(
                 prompt_type=prompt_type,
                 advanced_chat_prompt_template=advanced_chat_prompt_template,
-                advanced_completion_prompt_template=advanced_completion_prompt_template
+                advanced_completion_prompt_template=advanced_completion_prompt_template,
             )

     @classmethod
@@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
             config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value

         prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
-        if config['prompt_type'] not in prompt_type_vals:
+        if config["prompt_type"] not in prompt_type_vals:
             raise ValueError(f"prompt_type must be in {prompt_type_vals}")

         # chat_prompt_config
@@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
         if not isinstance(config["completion_prompt_config"], dict):
             raise ValueError("completion_prompt_config must be of object type")

-        if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
-            if not config['chat_prompt_config'] and not config['completion_prompt_config']:
-                raise ValueError("chat_prompt_config or completion_prompt_config is required "
-                                 "when prompt_type is advanced")
+        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
+            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
+                raise ValueError(
+                    "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
+                )

             model_mode_vals = [mode.value for mode in ModelMode]
-            if config['model']["mode"] not in model_mode_vals:
+            if config["model"]["mode"] not in model_mode_vals:
                 raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")

-            if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
-                user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
-                assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
+            if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
+                user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
+                assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]

                 if not user_prefix:
-                    config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
+                    config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"

                 if not assistant_prefix:
-                    config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
+                    config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"

-            if config['model']["mode"] == ModelMode.CHAT.value:
-                prompt_list = config['chat_prompt_config']['prompt']
+            if config["model"]["mode"] == ModelMode.CHAT.value:
+                prompt_list = config["chat_prompt_config"]["prompt"]

                 if len(prompt_list) > 10:
                     raise ValueError("prompt messages must be less than 10")
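The validation hunk above backfills `Human`/`Assistant` prefixes when a chat app drives a completion-mode model, so conversation history can be serialized as alternating prefixed lines. A compact sketch of that defaulting (the config literal is illustrative):

```python
config = {"completion_prompt_config": {"conversation_histories_role": {"user_prefix": "", "assistant_prefix": ""}}}

role = config["completion_prompt_config"]["conversation_histories_role"]
role["user_prefix"] = role["user_prefix"] or "Human"
role["assistant_prefix"] = role["assistant_prefix"] or "Assistant"

assert role == {"user_prefix": "Human", "assistant_prefix": "Assistant"}
```

diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
index 15fa4d99fd..2c0232c743 100644
--- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
@@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
         variable_entities = []

         # old external_data_tools
-        external_data_tools = config.get('external_data_tools', [])
+        external_data_tools = config.get("external_data_tools", [])
         for external_data_tool in external_data_tools:
-            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
                 continue

             external_data_variables.append(
                 ExternalDataVariableEntity(
-                    variable=external_data_tool['variable'],
-                    type=external_data_tool['type'],
-                    config=external_data_tool['config']
+                    variable=external_data_tool["variable"],
+                    type=external_data_tool["type"],
+                    config=external_data_tool["config"],
                 )
             )

         # variables and external_data_tools
-        for variables in config.get('user_input_form', []):
+        for variables in config.get("user_input_form", []):
             variable_type = list(variables.keys())[0]
             if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                 variable = variables[variable_type]
-                if 'config' not in variable:
+                if "config" not in variable:
                     continue

                 external_data_variables.append(
                     ExternalDataVariableEntity(
-                        variable=variable['variable'],
-                        type=variable['type'],
-                        config=variable['config']
+                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                     )
                 )
             elif variable_type in [
@@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
                 variable_entities.append(
                     VariableEntity(
                         type=variable_type,
-                        variable=variable.get('variable'),
-                        description=variable.get('description'),
-                        label=variable.get('label'),
-                        required=variable.get('required', False),
-                        max_length=variable.get('max_length'),
-                        options=variable.get('options'),
-                        default=variable.get('default'),
+                        variable=variable.get("variable"),
+                        description=variable.get("description"),
+                        label=variable.get("label"),
+                        required=variable.get("required", False),
+                        max_length=variable.get("max_length"),
+                        options=variable.get("options"),
+                        default=variable.get("default"),
                     )
                 )

@@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")

             form_item = item[key]
-            if 'label' not in form_item: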
if "label" not in form_item: raise ValueError("label is required in user_input_form") if not isinstance(form_item["label"], str): raise ValueError("label in user_input_form must be of string type") - if 'variable' not in form_item: + if "variable" not in form_item: raise ValueError("variable is required in user_input_form") if not isinstance(form_item["variable"], str): @@ -117,26 +115,24 @@ class BasicVariablesConfigManager: pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") + raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number") variables.append(form_item["variable"]) - if 'required' not in form_item or not form_item["required"]: + if "required" not in form_item or not form_item["required"]: form_item["required"] = False if not isinstance(form_item["required"], bool): raise ValueError("required in user_input_form must be of boolean type") if key == "select": - if 'options' not in form_item or not form_item["options"]: + if "options" not in form_item or not form_item["options"]: form_item["options"] = [] if not isinstance(form_item["options"], list): raise ValueError("options in user_input_form must be a list of strings") - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: + if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]: raise ValueError("default value in user_input_form must be in the options list") return config, ["user_input_form"] @@ -168,10 +164,6 @@ class BasicVariablesConfigManager: typ = tool["type"] config = tool["config"] - ExternalDataToolFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) + ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config) return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index bbb10d3d76..d208db2b01 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel): """ Model Config Entity. """ + provider: str model: str mode: Optional[str] = None @@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel): """ Advanced Chat Message Entity. """ + text: str role: PromptMessageRole @@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel): """ Advanced Chat Prompt Template Entity. """ + messages: list[AdvancedChatMessageEntity] @@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): """ Role Prefix Entity. """ + user: str assistant: str @@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel): Prompt Type. 'simple', 'advanced' """ - SIMPLE = 'simple' - ADVANCED = 'advanced' + + SIMPLE = "simple" + ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> 'PromptType': + def value_of(cls, value: str) -> "PromptType": """ Get value of given mode. @@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt type value {value}') + raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType simple_prompt_template: Optional[str] = None @@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. 
""" + variable: str type: str config: dict[str, Any] = {} @@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = 'single' - MULTIPLE = 'multiple' + + SINGLE = "single" + MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> 'RetrieveStrategy': + def value_of(cls, value: str) -> "RetrieveStrategy": """ Get value of given mode. @@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid retrieve strategy value {value}') + raise ValueError(f"invalid retrieve strategy value {value}") query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = .0 - rerank_mode: Optional[str] = 'reranking_model' + score_threshold: Optional[float] = 0.0 + rerank_mode: Optional[str] = "reranking_model" reranking_model: Optional[dict] = None weights: Optional[dict] = None reranking_enabled: Optional[bool] = True - - class DatasetEntity(BaseModel): """ Dataset Config Entity. """ + dataset_ids: list[str] retrieve_config: DatasetRetrieveConfigEntity @@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + type: str config: dict[str, Any] = {} @@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + enabled: bool voice: Optional[str] = None language: Optional[str] = None @@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel): """ Tracing Config Entity. """ + enabled: bool tracing_provider: str - - class AppAdditionalFeatures(BaseModel): file_upload: Optional[FileExtraConfig] = None opening_statement: Optional[str] = None @@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel): text_to_speech: Optional[TextToSpeechEntity] = None trace_config: Optional[TracingConfigEntity] = None + class AppConfig(BaseModel): """ Application Config Entity. """ + tenant_id: str app_id: str app_mode: AppMode @@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum): """ App Model Config From. """ - ARGS = 'args' - APP_LATEST_CONFIG = 'app-latest-config' - CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' + + ARGS = "args" + APP_LATEST_CONFIG = "app-latest-config" + CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" class EasyUIBasedAppConfig(AppConfig): """ Easy UI Based App Config Entity. """ + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str app_model_config_dict: dict @@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. 
""" + workflow_id: str diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 3da3c2eddb..5f7fc99151 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -13,21 +13,19 @@ class FileUploadConfigManager: :param config: model config args :param is_vision: if True, the feature is vision feature """ - file_upload_dict = config.get('file_upload') + file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get('image'): - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: + if file_upload_dict.get("image"): + if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: image_config = { - 'number_limits': file_upload_dict['image']['number_limits'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] + "number_limits": file_upload_dict["image"]["number_limits"], + "transfer_methods": file_upload_dict["image"]["transfer_methods"], } if is_vision: - image_config['detail'] = file_upload_dict['image']['detail'] + image_config["detail"] = file_upload_dict["image"]["detail"] - return FileExtraConfig( - image_config=image_config - ) + return FileExtraConfig(image_config=image_config) return None @@ -49,21 +47,21 @@ class FileUploadConfigManager: if not config["file_upload"].get("image"): config["file_upload"]["image"] = {"enabled": False} - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] + if config["file_upload"]["image"]["enabled"]: + number_limits = config["file_upload"]["image"]["number_limits"] if number_limits < 1 or number_limits > 6: raise ValueError("number_limits must be in [1, 6]") if is_vision: - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: + detail = config["file_upload"]["image"]["detail"] + if detail not in ["high", "low"]: raise ValueError("detail must be in ['high', 'low']") - transfer_methods = config['file_upload']['image']['transfer_methods'] + transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ['remote_url', 'local_file']: + if method not in ["remote_url", "local_file"]: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index 2ba99a5c40..496e1beeec 100644 --- a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -7,9 +7,9 @@ class MoreLikeThisConfigManager: :param config: model config args """ more_like_this = False - more_like_this_dict = config.get('more_like_this') + more_like_this_dict = config.get("more_like_this") if more_like_this_dict: - if more_like_this_dict.get('enabled'): + if more_like_this_dict.get("enabled"): more_like_this = True return more_like_this @@ -22,9 +22,7 @@ class MoreLikeThisConfigManager: :param config: app model config args """ if not config.get("more_like_this"): - config["more_like_this"] = { - "enabled": False - } + config["more_like_this"] = {"enabled": False} if not isinstance(config["more_like_this"], dict): raise ValueError("more_like_this must be of dict 
type") diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 0d8a71bfcf..b4dacbc409 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,5 +1,3 @@ - - class OpeningStatementConfigManager: @classmethod def convert(cls, config: dict) -> tuple[str, list]: @@ -9,10 +7,10 @@ class OpeningStatementConfigManager: :param config: model config args """ # opening statement - opening_statement = config.get('opening_statement') + opening_statement = config.get("opening_statement") # suggested questions - suggested_questions_list = config.get('suggested_questions') + suggested_questions_list = config.get("suggested_questions") return opening_statement, suggested_questions_list diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index fca58e12e8..d098abac2f 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -2,9 +2,9 @@ class RetrievalResourceConfigManager: @classmethod def convert(cls, config: dict) -> bool: show_retrieve_source = False - retriever_resource_dict = config.get('retriever_resource') + retriever_resource_dict = config.get("retriever_resource") if retriever_resource_dict: - if retriever_resource_dict.get('enabled'): + if retriever_resource_dict.get("enabled"): show_retrieve_source = True return show_retrieve_source @@ -17,9 +17,7 @@ class RetrievalResourceConfigManager: :param config: app model config args """ if not config.get("retriever_resource"): - config["retriever_resource"] = { - "enabled": False - } + config["retriever_resource"] = {"enabled": False} if not isinstance(config["retriever_resource"], dict): raise ValueError("retriever_resource must be of dict type") diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index 88b4be25d3..e10ae03e04 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -7,9 +7,9 @@ class SpeechToTextConfigManager: :param config: model config args """ speech_to_text = False - speech_to_text_dict = config.get('speech_to_text') + speech_to_text_dict = config.get("speech_to_text") if speech_to_text_dict: - if speech_to_text_dict.get('enabled'): + if speech_to_text_dict.get("enabled"): speech_to_text = True return speech_to_text @@ -22,9 +22,7 @@ class SpeechToTextConfigManager: :param config: app model config args """ if not config.get("speech_to_text"): - config["speech_to_text"] = { - "enabled": False - } + config["speech_to_text"] = {"enabled": False} if not isinstance(config["speech_to_text"], dict): raise ValueError("speech_to_text must be of dict type") diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index c6cab01220..9ac5114d12 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager: :param config: model config args """ suggested_questions_after_answer = False - suggested_questions_after_answer_dict = 
config.get('suggested_questions_after_answer') + suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer") if suggested_questions_after_answer_dict: - if suggested_questions_after_answer_dict.get('enabled'): + if suggested_questions_after_answer_dict.get("enabled"): suggested_questions_after_answer = True return suggested_questions_after_answer @@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager: :param config: app model config args """ if not config.get("suggested_questions_after_answer"): - config["suggested_questions_after_answer"] = { - "enabled": False - } + config["suggested_questions_after_answer"] = {"enabled": False} if not isinstance(config["suggested_questions_after_answer"], dict): raise ValueError("suggested_questions_after_answer must be of dict type") - if "enabled" not in config["suggested_questions_after_answer"] or not \ - config["suggested_questions_after_answer"]["enabled"]: + if ( + "enabled" not in config["suggested_questions_after_answer"] + or not config["suggested_questions_after_answer"]["enabled"] + ): config["suggested_questions_after_answer"]["enabled"] = False if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index f11e268e73..1c75981785 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -10,13 +10,13 @@ class TextToSpeechConfigManager: :param config: model config args """ text_to_speech = None - text_to_speech_dict = config.get('text_to_speech') + text_to_speech_dict = config.get("text_to_speech") if text_to_speech_dict: - if text_to_speech_dict.get('enabled'): + if text_to_speech_dict.get("enabled"): text_to_speech = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), + enabled=text_to_speech_dict.get("enabled"), + voice=text_to_speech_dict.get("voice"), + language=text_to_speech_dict.get("language"), ) return text_to_speech @@ -29,11 +29,7 @@ class TextToSpeechConfigManager: :param config: app model config args """ if not config.get("text_to_speech"): - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } + config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""} if not isinstance(config["text_to_speech"], dict): raise ValueError("text_to_speech must be of dict type") diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index c3d0e8ba03..b52f235849 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,4 +1,3 @@ - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): """ Advanced Chatbot App Config Entity. 
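# Every feature manager reformatted above (more_like_this, retriever_resource,
# speech_to_text, suggested_questions_after_answer) normalizes its config the
# same way: default to {"enabled": False}, then type-check the dict and the flag.
# A generic sketch of that shared shape; `normalize_bool_feature` is a
# hypothetical helper, not a function in the diff.
def normalize_bool_feature(config: dict, feature: str) -> tuple[dict, list[str]]:
    if not config.get(feature):
        config[feature] = {"enabled": False}
    if not isinstance(config[feature], dict):
        raise ValueError(f"{feature} must be of dict type")
    if not isinstance(config[feature].get("enabled", False), bool):
        raise ValueError(f"enabled in {feature} must be of boolean type")
    config[feature].setdefault("enabled", False)
    return config, [feature]

config, keys = normalize_bool_feature({}, "speech_to_text")
assert config == {"speech_to_text": {"enabled": False}} and keys == ["speech_to_text"]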
""" + pass class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - workflow: Workflow) -> AdvancedChatAppConfig: + def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: features_dict = workflow.features_dict app_mode = AppMode.value_of(app_model.mode) @@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) @@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): filtered_config = {key: config.get(key) for key in related_config_keys} return filtered_config - diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5a1e5973cd..88e1256ed5 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,12 +4,10 @@ import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -17,36 +15,54 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import 
MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, -) +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models.workflow import Workflow logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): + @overload def generate( - self, app_model: App, + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: bool = True, - ): + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -57,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', False) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} # get conversation conversation = None - conversation_id = args.get('conversation_id') + conversation_id = args.get("conversation_id") if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) + conversation = self._get_conversation_by_user( + app_model=app_model, conversation_id=conversation_id, user=user + ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) 
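# The @overload stubs added to AdvancedChatAppGenerator.generate() let type
# checkers narrow the return type from the stream flag: Literal[True] maps to a
# Generator, Literal[False] to a dict. A stripped-down, standalone sketch of the
# same idiom (not the real method):
from collections.abc import Generator
from typing import Literal, Union, overload

@overload
def generate(stream: Literal[True] = ...) -> Generator[str, None, None]: ...
@overload
def generate(stream: Literal[False]) -> dict: ...
def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
    if stream:
        return (chunk for chunk in ("Hel", "lo"))  # streamed chunks
    return {"answer": "Hello"}  # single blocking payload

chunks = generate()                # checkers infer Generator[str, None, None]
blocking = generate(stream=False)  # checkers infer dict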
+ app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -116,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -126,15 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, application_generate_entity=application_generate_entity, conversation=conversation, - stream=stream + stream=stream, ) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True): + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -146,43 +152,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') - - extras = { - "auto_generate_conversation_name": False - } - - # get conversation - conversation = None - conversation_id = args.get('conversation_id') - if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - conversation_id=conversation.id if conversation else None, + conversation_id=None, inputs={}, - query='', + query="", files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={"auto_generate_conversation_name": False}, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -191,32 +183,42 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - conversation=conversation, - stream=stream + conversation=None, + stream=stream, ) - def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation | None = None, - stream: bool = True): + def _generate( + self, + *, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: + """ + Generate App response. 
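# single_iteration_generate() above now runs one workflow node in isolation for
# debugging: no conversation is loaded, query stays empty, and only node_id plus
# the caller's inputs matter. A hedged sketch of that argument contract;
# `validate_single_iteration_args` is a hypothetical helper mirroring the two
# checks in the diff.
def validate_single_iteration_args(node_id: str, args: dict) -> dict:
    if not node_id:
        raise ValueError("node_id is required")
    if args.get("inputs") is None:
        raise ValueError("inputs is required")
    return {"node_id": node_id, "inputs": args["inputs"]}

payload = validate_single_iteration_args("llm-1", {"inputs": {"query": "hi"}})
assert payload == {"node_id": "llm-1", "inputs": {"query": "hi"}}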
+ + :param workflow: Workflow + :param user: account or end user + :param invoke_from: invoke from source + :param application_generate_entity: application generate entity + :param conversation: conversation + :param stream: is stream + """ is_first_conversation = False if not conversation: is_first_conversation = True # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - # db.session.refresh(conversation) + db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -225,73 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - # Create conversation variables if they don't exist. - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] - - session.commit() - - # Increment dialogue count. - conversation.dialogue_count += 1 - - conversation_id = conversation.id - conversation_dialogue_count = conversation.dialogue_count - db.session.commit() - db.session.refresh(conversation) - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id - - # Create a variable pool. 
- system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - contexts.workflow_variable_pool.set(variable_pool) - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - 'context': contextvars.copy_context(), - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": contextvars.copy_context(), + }, + ) worker_thread.start() @@ -306,16 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return AdvancedChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str, - context: contextvars.Context) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + context: contextvars.Context, + ) -> None: """ Generate worker in a new thread. 
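# The worker thread above is handed contextvars.copy_context() so request-scoped
# state (e.g. the tenant id set via contexts.tenant_id) survives into the
# generator thread, which replays it with var.set(val) before doing work. A
# minimal standalone sketch of that hand-off:
import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")

def worker(context: contextvars.Context) -> None:
    for var, val in context.items():  # replay the captured request context
        var.set(val)
    print("worker sees tenant:", tenant_id.get())

tenant_id.set("tenant-123")
t = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
t.start()
t.join()  # prints: worker sees tenant: tenant-123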
:param flask_app: Flask app @@ -329,40 +280,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): var.set(val) with flask_app.app_context(): try: - runner = AdvancedChatAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - # get message - message = self._get_message(message_id) + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) - # chatbot app - runner = AdvancedChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - message=message - ) - except GenerateTaskStoppedException: + # chatbot app + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + ) + + runner.run() + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -408,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 0caff4a2e3..18b115dfe4 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -21,14 +21,11 @@ class AudioTrunk: self.status = status -def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( - content_text=text_content.strip(), - user="responding_tts", - tenant_id=tenant_id, - voice=voice + content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice ) @@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue): except Exception as e: logging.getLogger(__name__).warning(e) break - audio_queue.put(AudioTrunk("finish", b'')) + audio_queue.put(AudioTrunk("finish", b"")) class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id - self.msg_text = '' + self.msg_text = "" self._audio_queue = 
queue.Queue() self._msg_queue = queue.Queue() - self.match = re.compile(r'[。.!?]') + self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.TTS + tenant_id=self.tenant_id, model_type=ModelType.TTS ) self.voices = self.model_instance.get_tts_voices() - values = [voice.get('value') for voice in self.voices] + values = [voice.get("value") for voice in self.voices] self.voice = voice if not voice or voice not in values: - self.voice = self.voices[0].get('value') + self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 self._last_audio_event = None self._runtime_thread = threading.Thread(target=self._runtime).start() @@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher: message = self._msg_queue.get() if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: - futures_result = self.executor.submit(_invoiceTTS, self.msg_text, - self.model_instance, self.tenant_id, self.voice) + futures_result = self.executor.submit( + _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): @@ -94,28 +90,27 @@ class AppGeneratorTTSPublisher: elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): - self.msg_text += message.event.outputs.get('output', '') + self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): self.MAX_SENTENCE += 1 - text_content = ''.join(sentence_arr) - futures_result = self.executor.submit(_invoiceTTS, text_content, - self.model_instance, - self.tenant_id, - self.voice) + text_content = "".join(sentence_arr) + futures_result = self.executor.submit( + _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) if text_tmp: self.msg_text = text_tmp else: - self.msg_text = '' + self.msg_text = "" except Exception as e: self.logger.warning(e) break future_queue.put(None) - def checkAndGetAudio(self) -> AudioTrunk | None: + def check_and_get_audio(self) -> AudioTrunk | None: try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 5dc03979cf..c4cdba6441 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,145 +1,197 @@ import logging import os -import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig -from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( 
AdvancedChatAppGenerateEntity, InvokeFrom, ) -from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent -from core.moderation.base import ModerationException +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueStopEvent, + QueueTextChunkEvent, +) +from core.moderation.base import ModerationError from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.entities.node_entities import UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from models import App, Message, Workflow +from models.model import App, Conversation, EndUser, Message +from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) -class AdvancedChatAppRunner(AppRunner): +class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ AdvancedChat Application Runner """ - def run( + def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, + conversation: Conversation, message: Message, ) -> None: """ - Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :param conversation: conversation :param message: message + """ + super().__init__(queue_manager) + + self.application_generate_entity = application_generate_entity + self.conversation = conversation + self.message = message + + def run(self) -> None: + """ + Run application :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") - inputs = application_generate_entity.inputs - query = application_generate_entity.query + user_id = None + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id - # moderation - if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id, - ): - return + workflow_callbacks: list[WorkflowCallback] = [] + if bool(os.environ.get("DEBUG", "False").lower() == "true"): + workflow_callbacks.append(WorkflowLoggingCallback()) - # annotation reply - if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - ): - return + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + 
workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files + + # moderation + if self.handle_input_moderation( + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id, + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity, + ): + return + + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + self.conversation.dialogue_count += 1 + + conversation_dialogue_count = self.conversation.dialogue_count + db.session.commit() + + # Create a variable pool. + system_inputs = { + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: self.conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + } + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) db.session.close() - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] - - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): - workflow_callbacks.append(WorkflowLoggingCallback()) - # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + ) + + generator = workflow_entry.run( 
callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, ) - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') - - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] - - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks - ) - - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) def handle_input_moderation( self, - queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, inputs: Mapping[str, Any], @@ -148,7 +200,6 @@ class AdvancedChatAppRunner(AppRunner): ) -> bool: """ Handle input moderation - :param queue_manager: application queue manager :param app_record: app record :param app_generate_entity: application generate entity :param inputs: inputs @@ -166,31 +217,20 @@ class AdvancedChatAppRunner(AppRunner): query=query, message_id=message_id, ) - except ModerationException as e: - self._stream_output( - queue_manager=queue_manager, - text=str(e), - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, - ) + except ModerationError as e: + self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True return False def handle_annotation_reply( - self, - app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity, + self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity ) -> bool: """ Handle annotation reply :param app_record: app record :param message: message :param query: query - :param queue_manager: application queue manager :param app_generate_entity: application generate entity """ # annotation reply @@ -203,37 +243,21 @@ class AdvancedChatAppRunner(AppRunner): ) if annotation_reply: - queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER - ) + self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)) - self._stream_output( - queue_manager=queue_manager, - text=annotation_reply.content, - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, + self._complete_with_stream_output( + text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _stream_output( - self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy - ) -> None: + def _complete_with_stream_output(self, text: str, stopped_by: 
QueueStopEvent.StopBy) -> None: """ Direct output - :param queue_manager: application queue manager :param text: text - :param stream: stream :return: """ - if stream: - index = 0 - for token in text: - queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) - index += 1 - time.sleep(0.01) - else: - queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) + self._publish_event(QueueTextChunkEvent(text=text)) - queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) + self._publish_event(QueueStopEvent(stopped_by=stopped_by)) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index ef579827b4..5fbd3e9a94 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream full response. 
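# The response converter above flattens typed response objects into plain dicts
# and, on the streaming side, one JSON string per chunk with bare "ping"
# keep-alives passed through untouched. A reduced sketch of the streaming side;
# chunks here are plain dicts standing in for the real StreamResponse types.
import json
from collections.abc import Generator, Iterable

def convert_stream(chunks: Iterable[dict]) -> Generator[str, None, None]:
    for chunk in chunks:
        if chunk.get("event") == "ping":
            yield "ping"  # keep-alive token, no JSON body
            continue
        yield json.dumps(chunk)

assert list(convert_stream([{"event": "ping"}, {"event": "message", "answer": "hi"}])) == [
    "ping",
    '{"event": "message", "answer": "hi"}',
]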
:param stream_response: stream response @@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream simple response. :param stream_response: stream response @@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2b3596ded2..94206a1b1c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,9 +2,8 @@ import json import logging import time from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union -import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -22,6 +21,9 @@ from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, @@ -31,34 +33,28 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, 
ChatbotAppBlockingResponse, ChatbotAppStreamResponse, - ChatflowStreamGenerateRoute, ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, + WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage -from core.file.file_obj import FileVar -from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account from models.model import Conversation, EndUser, Message from models.workflow import ( Workflow, - WorkflowNodeExecution, WorkflowRunStatus, ) @@ -69,22 +65,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: AdvancedChatTaskState + + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] - # Deprecated _workflow_system_variables: dict[SystemVariableKey, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. 
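# The pipeline keeps a _workflow_system_variables: dict[SystemVariableKey, Any]
# mapping that downstream workflow nodes read (query, files, conversation id,
# user id). A stand-in sketch of how such a mapping is keyed, with a plain Enum
# in place of the real SystemVariableKey:
from enum import Enum

class SystemVariableKey(Enum):
    QUERY = "query"
    FILES = "files"
    CONVERSATION_ID = "conversation_id"
    USER_ID = "user_id"

workflow_system_variables = {
    SystemVariableKey.QUERY: "hello",
    SystemVariableKey.FILES: [],
    SystemVariableKey.CONVERSATION_ID: "conv-1",
    SystemVariableKey.USER_ID: "user-1",
}
assert workflow_system_variables[SystemVariableKey.QUERY] == "hello"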
@@ -106,7 +102,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._workflow = workflow self._conversation = conversation self._message = message - # Deprecated self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, @@ -114,12 +109,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc SystemVariableKey.USER_ID: user_id, } - self._task_state = AdvancedChatTaskState( - usage=LLMUsage.empty_usage() - ) + self._task_state = WorkflowTaskState() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) - self._stream_generate_routes = self._get_stream_generate_routes() self._conversation_name_generate_thread = None def process(self): @@ -133,13 +124,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) + if self._stream: return self._to_stream_response(generator) else: @@ -156,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: - extras['metadata'] = stream_response.metadata + extras["metadata"] = stream_response.metadata return ChatbotAppBlockingResponse( task_id=stream_response.task_id, @@ -167,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc message_id=self._message.id, answer=self._task_state.answer, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[ChatbotAppStreamResponse, Any, None]: """ To stream response. 
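# process() above either hands the generator straight to the stream path or, in
# blocking mode, drains it until a MessageEndStreamResponse arrives and folds the
# accumulated answer into one response. A toy sketch of that drain-until-end
# branch, with (kind, payload) tuples standing in for the real response types:
def to_blocking(chunks):
    answer = []
    for kind, payload in chunks:
        if kind == "error":
            raise RuntimeError(payload)
        if kind == "message_end":
            return {"event": "message", "answer": "".join(answer), "metadata": payload}
        if kind == "text":
            answer.append(payload)
    raise Exception("Queue listening stopped unexpectedly.")

resp = to_blocking([("text", "Hel"), ("text", "lo"), ("message_end", {"usage": {}})])
assert resp["answer"] == "Hello"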
:return: @@ -185,31 +176,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - - publisher = None + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -220,9 +215,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -236,38 +231,38 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
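# After the main event stream ends, the wrapper above keeps draining the TTS
# publisher until it reports "finish" or TTS_AUTO_PLAY_TIMEOUT elapses, sleeping
# briefly whenever no audio is ready so the poll loop yields the CPU. A
# standalone sketch with toy constants and (status, audio) tuples standing in
# for the real AudioTrunk objects:
import time

TTS_AUTO_PLAY_TIMEOUT = 2.0          # stand-in, seconds
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02  # ~20ms per idle poll

def drain(check_and_get_audio) -> list[bytes]:
    out: list[bytes] = []
    start = time.time()
    while time.time() - start < TTS_AUTO_PLAY_TIMEOUT:
        trunk = check_and_get_audio()
        if trunk is None:
            time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)  # nothing ready yet, release cpu
            continue
        status, audio = trunk
        if status == "finish":
            break
        out.append(audio)
    return out

queue = [None, ("responding", b"\x01"), ("finish", b"")]
assert drain(lambda: queue.pop(0)) == [b"\x01"]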
:return: """ - for message in self._queue_manager.listen(): - if (message.event - and getattr(message.event, 'metadata', None) - and message.event.metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - elif (hasattr(message.event, 'execution_metadata') - and message.event.execution_metadata - and message.event.execution_metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - event = message.event + # init fake graph runtime state + graph_runtime_state = None + workflow_run = None - if isinstance(event, QueueErrorEvent): + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + # init workflow run + workflow_run = self._handle_workflow_run_start() + + self._refetch_message() self._message.workflow_run_id = workflow_run.id db.session.commit() @@ -275,137 +270,231 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc db.session.close() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception("Workflow run not initialized.") - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: - self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] - # reset current route position to 0 - self._task_state.current_stream_generate_state.current_route_position = 0 + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() - - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - # stream outputs when node finished - generator = self._generate_stream_outputs_when_node_finished() - if generator: - yield from generator + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - if isinstance(event, QueueNodeFailedEvent): - yield from 
self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) + + response = self._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if event.outputs else None, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowFailedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + + err_event = 
QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + yield self._error_to_stream_response(self._handle_error(err_event, self._message)) + break + elif isinstance(event, QueueStopEvent): + if workflow_run and graph_runtime_state: + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation.id, + trace_manager=trace_manager, ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, conversation_id=self._conversation.id, trace_manager=trace_manager - ) - if workflow_run: yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - if workflow_run.status == WorkflowRunStatus.FAILED.value: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - - if isinstance(event, QueueStopEvent): - # Save message - self._save_message() - - yield self._message_end_to_stream_response() - break - else: - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message() + self._save_message(graph_runtime_state=graph_runtime_state) yield self._message_end_to_stream_response() + break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) + + self._refetch_message() + + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) + + self._refetch_message() + + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue - # handle output moderation chunk should_direct_answer = self._handle_output_moderation_chunk(delta_text) 
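# note: the listener removed earlier in this method published audio based on
# the "is_answer_previous_node" metadata of each queue message; after this
# change tts_publisher is fed only from QueueTextChunkEvent messages that
# survive output moderation, so the audio stream mirrors the text actually
# streamed to the client.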
if should_direct_answer: continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) + self._task_state.answer += delta_text - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response( + answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + ) elif isinstance(event, QueueMessageReplaceEvent): + # published by moderation yield self._message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + elif isinstance(event, QueueAdvancedChatMessageEndEvent): + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) + + # Save message + self._save_message(graph_runtime_state=graph_runtime_state) + + yield self._message_end_to_stream_response() else: continue - if publisher: - publisher.publish(None) + + # publish None when task finished + if tts_publisher: + tts_publisher.publish(None) + if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self) -> None: + def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: """ Save message. :return: """ - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._refetch_message() self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None - - if self._task_state.metadata and self._task_state.metadata.get('usage'): - usage = LLMUsage(**self._task_state.metadata['usage']) + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + if graph_runtime_state and graph_runtime_state.llm_usage: + usage = graph_runtime_state.llm_usage self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -422,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -432,331 +521,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata.copy() + + if "annotation_reply" in extras["metadata"]: + del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) - def _get_stream_generate_routes(self) -> dict[str, 
ChatflowStreamGenerateRoute]: - """ - Get stream generate routes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - answer_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.ANSWER.value - ] - - # parse stream output node value selectors of answer nodes - stream_generate_routes = {} - for node_config in answer_node_configs: - # get generate route for stream output - answer_node_id = node_config['id'] - generate_route = AnswerNode.extract_generate_route_selectors(node_config) - start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute( - answer_node_id=answer_node_id, - generate_route=generate_route - ) - - return stream_generate_routes - - def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get answer start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - # check if it's the first node in the iteration - target_node = next((node for node in nodes if node.get('id') == target_node_id), None) - if not target_node: - return [] - - node_iteration_id = target_node.get('data', {}).get('iteration_id') - # get iteration start node id - for node in nodes: - if node.get('id') == node_iteration_id: - if node.get('data', {}).get('start_node_id') == target_node_id: - return [target_node_id] - - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value, - NodeType.ITERATION.value, - NodeType.LOOP.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. 
- :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. - :return: - """ - if self._task_state.current_stream_generate_state: - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position: - ] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text) - if should_direct_answer: - continue - - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - break - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: - """ - Generate stream outputs. - :return: - """ - if not self._task_state.current_stream_generate_state: - return - - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position:] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - value = None - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - if not value_selector: - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - route_chunk_node_id = value_selector[0] - - if route_chunk_node_id == 'sys': - # system variable - value = contexts.workflow_variable_pool.get().get(value_selector) - if value: - value = value.text - elif route_chunk_node_id in self._iteration_nested_relations: - # it's a iteration variable - if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: - continue - iteration_state = self._iteration_state.current_iterations[route_chunk_node_id] - iterator = iteration_state.inputs - if not iterator: - continue - iterator_selector = iterator.get('iterator_selector', []) - if value_selector[1] == 'index': - value = iteration_state.current_index - elif value_selector[1] == 'item': - value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len( - iterator_selector - ) else None - else: - # check chunk node id is before current node id or equal to current node id - if route_chunk_node_id not in self._task_state.ran_node_execution_infos: - break - - latest_node_execution_info = self._task_state.latest_node_execution_info - - # get route chunk node execution info - route_chunk_node_execution_info = 
self._task_state.ran_node_execution_infos[route_chunk_node_id] - if (route_chunk_node_execution_info.node_type == NodeType.LLM - and latest_node_execution_info.node_type == NodeType.LLM): - # only LLM support chunk stream output - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - # get route chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id - ).first() - - outputs = route_chunk_node_execution.outputs_dict - - # get value from outputs - value = None - for key in value_selector[1:]: - if not value: - value = outputs.get(key) if outputs else None - else: - value = value.get(key) - - if value is not None: - text = '' - if isinstance(value, str | int | float): - text = str(value) - elif isinstance(value, FileVar): - # convert file to markdown - text = value.to_markdown() - elif isinstance(value, dict): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - if file_vars: - file_var = file_vars[0] - try: - file_var_obj = FileVar(**file_var) - - # convert file to markdown - text = file_var_obj.to_markdown() - except Exception as e: - logger.error(f'Error creating file var: {e}') - - if not text: - # other types - text = json.dumps(value, ensure_ascii=False) - elif isinstance(value, list): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - for file_var in file_vars: - try: - file_var_obj = FileVar(**file_var) - except Exception as e: - logger.error(f'Error creating file var: {e}') - continue - - # convert file to markdown - text = file_var_obj.to_markdown() + ' ' - - text = text.strip() - - if not text and value: - # other types - text = json.dumps(value, ensure_ascii=False) - - if text: - self._task_state.answer += text - yield self._message_to_stream_response(text, self._message.id) - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return True - - if 'node_id' not in event.metadata: - return True - - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - route_chunk = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position] - - if route_chunk.type != 'var': - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - - # check chunk node id is before current node id or equal to current node id - if value_selector != stream_output_value_selector: - return False - - return True - def _handle_output_moderation_chunk(self, text: str) -> bool: """ Handle output moderation chunk. 
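With the stream-generate-route machinery deleted above, _process_stream_response reduces to a flat event loop: two nullable locals, workflow_run and graph_runtime_state, are set when QueueWorkflowStartedEvent arrives, and every node-, branch-, and iteration-level handler guards on them first. A minimal sketch of that skeleton, using only names that appear in the hunks above:

    graph_runtime_state = None  # populated by QueueWorkflowStartedEvent
    workflow_run = None

    for queue_message in self._queue_manager.listen():
        event = queue_message.event
        if isinstance(event, QueueWorkflowStartedEvent):
            # the started event carries the runtime state for the whole run
            graph_runtime_state = event.graph_runtime_state
            workflow_run = self._handle_workflow_run_start()
        elif isinstance(event, QueueNodeStartedEvent):
            if not workflow_run:
                raise Exception("Workflow run not initialized.")
            # ... node, parallel-branch, and iteration handlers all apply
            # the same guard before emitting their stream responses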
@@ -768,17 +541,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: self._output_moderation_handler.append_new_token(text) return False + + def _refetch_message(self) -> None: + """ + Refetch message. + :return: + """ + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if message: + self._message = message diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py deleted file mode 100644 index 8d43155a08..0000000000 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - 
QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager._publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager._publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self._queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index f495ebbf35..9040f18bfd 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): """ Agent Chatbot App Config Entity. 
""" + agent: Optional[AgentEntity] = None class AgentChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config :param app_model: app model @@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - agent=AgentConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + agent=AgentConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # dataset configs # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): :param config: app model config args """ if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config["agent_mode"].get("strategy"): config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [member.value for member in - list(PlanningStrategy.__members__.values())]: + if config["agent_mode"]["strategy"] not in [ 
+ member.value for member in list(PlanningStrategy.__members__.values()) + ]: raise ValueError("strategy in agent_mode must be in the specified strategy list") if not config["agent_mode"].get("tools"): @@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): raise ValueError("enabled in agent_mode.tools must be of boolean type") if key == "dataset": - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 53780bdfb0..abf8a332ab 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -3,7 +3,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom @@ -28,12 +28,29 @@ logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[dict, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[dict, None, None]]: """ Generate App response. 
@@ -44,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): :param stream: is stream """ if not stream: - raise ValueError('Agent Chat App does not support blocking mode') + raise ValueError("Agent Chat App does not support blocking mode") - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = AgentChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -106,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -127,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, extras=extras, call_depth=0, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -143,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): 
invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -167,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return AgentChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( - self, flask_app: Flask, + self, + flask_app: Flask, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, @@ -202,18 +205,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d1bbf679c5..45b1bf0093 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,7 +15,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db from models.model import App, Conversation, Message, MessageAgentThought @@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner): """ def run( - self, application_generate_entity: AgentChatAppGenerateEntity, + self, + application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner): # get memory 
of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -103,15 +101,15 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner): message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # reorganize all inputs and template to prompt messages @@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: @@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner): agent_entity = app_config.agent # load tool variables - tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, - user_id=application_generate_entity.user_id, - tenant_id=app_config.tenant_id) + tool_conversation_variables = self._load_tool_variables( + conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id + ) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) @@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner): # init model instance model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, @@ -238,7 +236,7 @@ class 
AgentChatAppRunner(AppRunner): prompt_messages=prompt_message, variables_pool=tool_variables, db_variables=tool_conversation_variables, - model_instance=model_instance + model_instance=model_instance, ) invoke_result = runner.run( @@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner): invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream, - agent=True + agent=True, ) def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: """ load tool variables from database """ - tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == conversation_id, - ToolConversationVariables.tenant_id == tenant_id - ).first() + tool_variables: ToolConversationVariables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == conversation_id, + ToolConversationVariables.tenant_id == tenant_id, + ) + .first() + ) if tool_variables: # save tool variables to session, so that we can update it later @@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner): conversation_id=conversation_id, user_id=user_id, tenant_id=tenant_id, - variables_str='[]', + variables_str="[]", ) db.session.add(tool_variables) db.session.commit() return tool_variables - - def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: + + def _convert_db_variables_to_tool_variables( + self, db_variables: ToolConversationVariables + ) -> ToolRuntimeVariablePool: """ convert db variables to tool variables """ - return ToolRuntimeVariablePool(**{ - 'conversation_id': db_variables.conversation_id, - 'user_id': db_variables.user_id, - 'tenant_id': db_variables.tenant_id, - 'pool': db_variables.variables - }) + return ToolRuntimeVariablePool( + **{ + "conversation_id": db_variables.conversation_id, + "user_id": db_variables.user_id, + "tenant_id": db_variables.tenant_id, + "pool": db_variables.variables, + } + ) - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, - message: Message) -> LLMUsage: + def _get_usage_of_all_agent_thoughts( + self, model_config: ModelConfigWithCredentialsEntity, message: Message + ) -> LLMUsage: """ Get usage of all agent thoughts :param model_config: model config :param message: message :return: """ - agent_thoughts = (db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == message.id).all()) + agent_thoughts = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() + ) all_message_tokens = 0 all_answer_tokens = 0 @@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner): model_type_instance = cast(LargeLanguageModel, model_type_instance) return model_type_instance._calc_response_usage( - model_config.model, - model_config.credentials, - all_message_tokens, - all_answer_tokens + model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens ) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 118d82c495..629c309c06 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': 
blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response @@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 1165314a7f..73025d99d0 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC): _blocking_response_type: type[AppBlockingResponse] @classmethod - def convert(cls, response: Union[ - AppBlockingResponse, - Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom): + def convert( + cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom + ) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: + def _generate_full_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_full_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_full_response() else: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_simple_response(response) else: + def _generate_simple_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_simple_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_simple_response() @@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC): @classmethod @abstractmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @abstractmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @@ -72,24 
+75,26 @@ class AppGenerateResponseConverter(ABC): :return: """ # show_retrieve_source - if 'retriever_resources' in metadata: - metadata['retriever_resources'] = [] - for resource in metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) + if "retriever_resources" in metadata: + metadata["retriever_resources"] = [] + for resource in metadata["retriever_resources"]: + metadata["retriever_resources"].append( + { + "segment_id": resource["segment_id"], + "position": resource["position"], + "document_name": resource["document_name"], + "score": resource["score"], + "content": resource["content"], + } + ) # show annotation reply - if 'annotation_reply' in metadata: - del metadata['annotation_reply'] + if "annotation_reply" in metadata: + del metadata["annotation_reply"] # show usage - if 'usage' in metadata: - del metadata['usage'] + if "usage" in metadata: + del metadata["usage"] return metadata @@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC): :return: """ error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + ValueError: {"code": "invalid_param", "status": 400}, + ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 + "code": "provider_quota_exceeded", + "message": "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + "status": 400, }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} + ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400}, + InvokeError: {"code": "completion_request_error", "status": 400}, } # Determine the response based on the type of exception @@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC): data = v if data: - data.setdefault('message', getattr(e, 'description', str(e))) + data.setdefault("message", getattr(e, "description", str(e))) else: logging.error(e) data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 + "code": "internal_server_error", + "message": "Internal Server Error, please contact support.", + "status": 500, } return data diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 9e331dff4d..ce6f7d4338 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -16,10 +16,10 @@ class BaseAppGenerator: def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.variable} is required in input form') + raise ValueError(f"{var.variable} is required in input form") if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? 
- return var.default or '' + return var.default or "" if ( var.type in ( @@ -34,7 +34,7 @@ class BaseAppGenerator: if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: - if '.' in user_input_value: + if "." in user_input_value: return float(user_input_value) else: return int(user_input_value) @@ -43,14 +43,14 @@ class BaseAppGenerator: if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + raise ValueError(f"{var.variable} in input form must be one of the following: {options}") elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') + raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") return user_input_value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): - return value.replace('\x00', '') + return value.replace("\x00", "") return value diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f929a979f1..f3c3199354 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -24,9 +24,7 @@ class PublishFrom(Enum): class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: if not user_id: raise ValueError("user is required") @@ -34,9 +32,10 @@ class AppQueueManager: self._user_id = user_id self._invoke_from = invoke_from - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, - f"{user_prefix}-{self._user_id}") + user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + redis_client.setex( + AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" + ) q = queue.Queue() @@ -66,8 +65,7 @@ class AppQueueManager: # publish two messages to make sure the client can receive the stop signal # and stop listening after the stop signal processed self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE ) if elapsed_time // 10 > last_ping_time: @@ -88,9 +86,7 @@ class AppQueueManager: :param pub_from: publish from :return: """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) + self.publish(QueueErrorEvent(error=e), pub_from) def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ @@ -122,8 +118,8 @@ class AppQueueManager: if result is None: return - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') != f"{user_prefix}-{user_id}": + user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + if result.decode("utf-8") != f"{user_prefix}-{user_id}": return stopped_cache_key = cls._generate_stopped_cache_key(task_id) @@ -168,10 +164,12 @@ class 
AppQueueManager: for item in data: self._check_for_sqlalchemy_models(item) else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model instances " - "that cause thread safety issues is not allowed.") + if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"): + raise TypeError( + "Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed." + ) -class GenerateTaskStoppedException(Exception): +class GenerateTaskStoppedError(Exception): pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 2c5feaaaaf..aadb43ad39 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time -from collections.abc import Generator -from typing import TYPE_CHECKING, Optional, Union +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -31,12 +31,15 @@ if TYPE_CHECKING: class AppRunner: - def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None) -> int: + def get_pre_calculate_rest_tokens( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + ) -> int: """ Get pre calculate rest tokens :param app_record: app record @@ -49,18 +52,20 @@ class AppRunner: """ # Invoke model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -75,36 +80,39 @@ class AppRunner: prompt_template_entity=prompt_template_entity, inputs=inputs, files=files, - query=query + query=query, ) - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) rest_tokens = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: - raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size.") + raise InvokeBadRequestError( + "Query or prefix prompt is too long, you can reduce the prefix prompt, " + "or shrink the max token, or switch to a llm with a larger 
token limit size." + ) return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage]): + def recalc_llm_max_tokens( + self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] + ): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -112,27 +120,28 @@ class AppRunner: if max_tokens is None: max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: max_tokens = max(model_context_tokens - prompt_tokens, 16) for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): model_config.parameters[parameter_rule.name] = max_tokens - def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def organize_prompt_messages( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + context: Optional[str] = None, + memory: Optional[TokenBufferMemory] = None, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -152,60 +161,54 @@ class AppRunner: app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template - prompt_template = 
CompletionModelPromptTemplate( - text=advanced_completion_prompt_template.prompt - ) + prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: memory_config.role_prefix = MemoryConfig.RolePrefix( user=advanced_completion_prompt_template.role_prefix.user, - assistant=advanced_completion_prompt_template.role_prefix.assistant + assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: - prompt_template.append(ChatModelMessage( - text=message.text, - role=message.role - )) + prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def direct_output(self, queue_manager: AppQueueManager, - app_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list, - text: str, - stream: bool, - usage: Optional[LLMUsage] = None) -> None: + def direct_output( + self, + queue_manager: AppQueueManager, + app_generate_entity: EasyUIBasedAppGenerateEntity, + prompt_messages: list, + text: str, + stream: bool, + usage: Optional[LLMUsage] = None, + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -222,17 +225,10 @@ class AppRunner: chunk = LLMResultChunk( model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=token) - ) + delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)), ) - queue_manager.publish( - QueueLLMChunkEvent( - chunk=chunk - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) @@ -242,15 +238,19 @@ class AppRunner: model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() + usage=usage if usage else LLMUsage.empty_usage(), ), - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: AppQueueManager, - stream: bool, - agent: bool = False) -> None: + def _handle_invoke_result( + self, + invoke_result: Union[LLMResult, Generator], + queue_manager: AppQueueManager, + stream: bool, + agent: bool = False, + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -260,21 +260,13 @@ class AppRunner: :return: """ if not stream: - self._handle_invoke_result_direct( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) else: - self._handle_invoke_result_stream( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: 
AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_direct( + self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result direct :param invoke_result: invoke result @@ -285,12 +277,13 @@ class AppRunner: queue_manager.publish( QueueMessageEndEvent( llm_result=invoke_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_stream( + self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -300,21 +293,13 @@ class AppRunner: """ model = None prompt_messages = [] - text = '' + text = "" usage = None for result in invoke_result: if not agent: - queue_manager.publish( - QueueLLMChunkEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) else: - queue_manager.publish( - QueueAgentMessageEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) text += result.delta.message.content @@ -331,25 +316,24 @@ class AppRunner: usage = LLMUsage.empty_usage() llm_result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) queue_manager.publish( QueueMessageEndEvent( llm_result=llm_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) def moderation_for_inputs( - self, app_id: str, - tenant_id: str, - app_generate_entity: AppGenerateEntity, - inputs: dict, - query: str, - message_id: str, + self, + app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
@@ -367,14 +351,17 @@ class AppRunner: tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query if query else '', + query=query if query else "", message_id=message_id, - trace_manager=app_generate_entity.trace_manager + trace_manager=app_generate_entity.trace_manager, ) - def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - prompt_messages: list[PromptMessage]) -> bool: + def check_hosting_moderation( + self, + application_generate_entity: EasyUIBasedAppGenerateEntity, + queue_manager: AppQueueManager, + prompt_messages: list[PromptMessage], + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -384,8 +371,7 @@ class AppRunner: """ hosting_moderation_feature = HostingModerationFeature() moderation_result = hosting_moderation_feature.check( - application_generate_entity=application_generate_entity, - prompt_messages=prompt_messages + application_generate_entity=application_generate_entity, prompt_messages=prompt_messages ) if moderation_result: @@ -393,18 +379,20 @@ class AppRunner: queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ - "but I'm an AI assistant to be helpful, harmless, and honest.", - stream=application_generate_entity.stream + text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.", + stream=application_generate_entity.stream, ) return moderation_result - def fill_in_inputs_from_external_data_tools(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fill_in_inputs_from_external_data_tools( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -417,18 +405,12 @@ class AppRunner: """ external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( - tenant_id=tenant_id, - app_id=app_id, - external_data_tools=external_data_tools, - inputs=inputs, - query=query + tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query ) - def query_app_annotations_to_reply(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query_app_annotations_to_reply( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -440,9 +422,5 @@ class AppRunner: """ annotation_reply_feature = AnnotationReplyFeature() return annotation_reply_feature.query( - app_record=app_record, - message=message, - query=query, - user_id=user_id, - invoke_from=invoke_from + app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from ) diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index a286c349b2..96dc7dda79 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig): """ Chatbot App Config Entity. 
""" + pass class ChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> ChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> ChatAppConfig: """ Convert app model config to chat app config :param app_model: app model @@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager): config_dict = app_model_config_dict.copy() else: if not override_config_dict: - raise Exception('override_config_dict is required when config_from is ARGS') + raise Exception("override_config_dict is required when config_from is ARGS") config_dict = override_config_dict @@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # opening_statement @@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5b896e2845..032556ec4c 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -3,14 +3,14 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, 
overload from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter @@ -28,13 +28,34 @@ logger = logging.getLogger(__name__) class ChatAppGenerator(MessageBasedAppGenerator): + @overload def generate( - self, app_model: App, + self, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -44,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if 
args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -104,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -123,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -139,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -163,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return ChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + ) -> None: """ Generate worker in a new thread. 
:param flask_app: Flask app @@ -194,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 89a498eb36..425f1ab7ef 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Conversation, Message @@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner): # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -96,15 +96,15 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, 
app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner): message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner): app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_retrieval = DatasetRetrieval(application_generate_entity) @@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner): files=files, query=query, context=context, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 625e14c9c3..0fa7af0a7f 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + 
"message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response @@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index a771198324..1193c4b7a4 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig): """ Completion App Config Entity. """ + pass class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - override_config_dict: Optional[dict] = None) -> CompletionAppConfig: + def get_app_config( + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + ) -> CompletionAppConfig: """ Convert app model config to completion app config :param app_model: app model @@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # text_to_speech @@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): 
related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c4e1caf65a..7fce296f2b 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -3,14 +3,14 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter @@ -30,12 +30,29 @@ logger = logging.getLogger(__name__) class CompletionAppGenerator(MessageBasedAppGenerator): - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. 
@@ -45,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] extras = {} @@ -58,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator): conversation = None # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # get tracing instance @@ -110,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -126,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -149,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, ) - 
return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -176,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - message=message + message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -198,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator): finally: db.session.close() - def generate_more_like_this(self, app_model: App, - message_id: str, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + def generate_more_like_this( + self, + app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. 
@@ -213,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -232,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator): app_model_config = message.app_model_config override_model_config_dict = app_model_config.to_dict() - model_dict = override_model_config_dict['model'] - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - override_model_config_dict['model'] = model_dict + model_dict = override_model_config_dict["model"] + completion_params = model_dict.get("completion_params") + completion_params["temperature"] = 0.9 + model_dict["completion_params"] = completion_params + override_model_config_dict["model"] = model_dict # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - message.files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # init application generate entity @@ -268,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=stream, invoke_from=invoke_from, - extras={} + extras={}, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -284,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, 
+ "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -307,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index f0e5f9ae17..908d74ff53 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Message @@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message) -> None: + def run( + self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # organize all inputs and template to prompt messages @@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # moderation @@ -77,15 +77,15 @@ class CompletionAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner): app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_config = app_config.dataset @@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner): invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, - message_id=message.id + message_id=message.id, ) # reorganize all inputs and template to prompt messages @@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner): inputs=inputs, files=files, query=query, - context=context + context=context, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + 
prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) - \ No newline at end of file diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 14db74dbd0..697f0273a5 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. 
:param stream_response: stream response @@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index fceed95b91..f629c5c8b7 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,7 +8,7 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -35,23 +35,23 @@ logger = logging.getLogger(__name__) class MessageBasedAppGenerator(BaseAppGenerator): - def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: 
Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Handle response. @@ -70,24 +70,25 @@ class MessageBasedAppGenerator(BaseAppGenerator): conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e - def _get_conversation_by_user(self, app_model: App, conversation_id: str, - user: Union[Account, EndUser]) -> Conversation: + def _get_conversation_by_user( + self, app_model: App, conversation_id: str, user: Union[Account, EndUser] + ) -> Conversation: conversation_filter = [ Conversation.id == conversation_id, Conversation.app_id == app_model.id, - Conversation.status == 'normal' + Conversation.status == "normal", ] if isinstance(user, Account): @@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator): if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() return conversation - def _get_app_model_config(self, app_model: App, - conversation: Optional[Conversation] = None) \ - -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .first() + ) if not app_model_config: raise AppModelConfigBrokenError() @@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator): return app_model_config - def _init_generate_records(self, - application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - conversation: Optional[Conversation] = None) \ - -> tuple[Conversation, Message]: + def _init_generate_records( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + conversation: Optional[Conversation] = None, + ) -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity @@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): end_user_id = None account_id = None if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' + from_source = "api" end_user_id = application_generate_entity.user_id else: - from_source = 'console' + from_source = "console" account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if 
app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ + AppMode.AGENT_CHAT, + AppMode.CHAT, + AppMode.COMPLETION, + ]: override_model_configs = app_config.app_model_config_dict # get conversation introduction @@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, - name='New conversation', + name="New conversation", inputs=application_generate_entity.inputs, introduction=introduction, system_instruction="", system_instruction_tokens=0, - status='normal', + status="normal", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, @@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, - from_account_id=account_id + from_account_id=account_id, ) db.session.add(message) @@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): message_id=message.id, type=file.type.value, transfer_method=file.transfer_method.value, - belongs_to='user', + belongs_to="user", url=file.url, upload_file_id=file.related_id, - created_by_role=('account' if account_id else 'end_user'), + created_by_role=("account" if account_id else "end_user"), created_by=account_id or end_user_id, ) db.session.add(message_file) @@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param conversation_id: conversation id :return: conversation """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError() @@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param message_id: message id :return: message """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) + message = db.session.query(Message).filter(Message.id == message_id).first() return message diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index f4ff44ddda..363c3c82bb 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -12,12 +12,9 @@ from core.app.entities.queue_entities import ( class MessageBasedAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: + def __init__( + self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str + ) -> None: super().__init__(task_id, 
user_id, invoke_from) self._conversation_id = str(conversation_id) @@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager): message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: @@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager): message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueAdvancedChatMessageEndEvent): + if isinstance( + event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() - + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 36d3696d60..8b98e74b85 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): """ Workflow App Config Entity. """ + pass @@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager): app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index df40aec154..57a77591a0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -12,7 +12,7 @@ from pydantic import ValidationError import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from 
core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -32,14 +32,42 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): + @overload def generate( - self, app_model: App, + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> dict: ... + + def generate( + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, ): """ Generate App response. @@ -51,27 +79,21 @@ class WorkflowAppGenerator(BaseAppGenerator): :param invoke_from: invoke from source :param stream: is stream :param call_depth: call depth + :param workflow_thread_pool_id: workflow thread pool id """ - inputs = args['inputs'] + inputs = args["inputs"] # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -87,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator): stream=stream, invoke_from=invoke_from, call_depth=call_depth, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -98,16 +120,20 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, + workflow_thread_pool_id=workflow_thread_pool_id, ) def _generate( - self, app_model: App, + self, + *, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + workflow_thread_pool_id: Optional[str] = None, + ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. 
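Aside on the `@overload` stubs introduced in `WorkflowAppGenerator.generate` above: pairing `stream: Literal[True]` / `Literal[False]` with distinct return types lets type checkers narrow the result of `generate()` from the literal argument, while a single concrete implementation keeps accepting a plain `bool`. A minimal, self-contained sketch of the pattern — `fake_generate` and its payloads are illustrative stand-ins, not Dify's API:

from collections.abc import Generator
from typing import Literal, Union, overload

@overload
def fake_generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def fake_generate(stream: Literal[False] = False) -> dict: ...
def fake_generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
    # Only this body runs; the stubs above exist purely for static typing
    # and are replaced at runtime by @overload.
    if stream:
        return (chunk for chunk in ("hello ", "world"))
    return {"answer": "hello world"}

With this in place, a call site such as `fake_generate(stream=False)["answer"]` type-checks without a cast, and iterating `fake_generate(stream=True)` is known to yield `str`.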
@@ -117,22 +143,27 @@ class WorkflowAppGenerator(BaseAppGenerator): :param application_generate_entity: application generate entity :param invoke_from: invoke from source :param stream: is stream + :param workflow_thread_pool_id: workflow thread pool id """ # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode + app_mode=app_model.mode, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'context': contextvars.copy_context() - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() @@ -145,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator): stream=stream, ) - return WorkflowAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True): + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. 
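Aside on `contextvars.copy_context()` in the worker-thread kwargs above: a new thread starts with an empty context, so values such as `contexts.tenant_id` set by the request thread are invisible in the worker unless they are replayed. The generator snapshots the caller's context and re-sets each variable inside the thread, mirroring the `for var, val in context.items(): var.set(val)` loop in `_generate_worker`. A runnable sketch under illustrative names:

import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")

def worker(context: contextvars.Context) -> None:
    # Replay the caller's snapshot into this thread's initially empty context.
    for var, val in context.items():
        var.set(val)
    print(tenant_id.get())  # -> "tenant-123"

tenant_id.set("tenant-123")
t = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
t.start()
t.join()

Replaying variable by variable (rather than calling `Context.run()`) presumably leaves the worker free to enter its own Flask app context around the actual work, as the runner does here.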
@@ -167,20 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator): :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') - - extras = { - "auto_generate_conversation_name": False - } + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( @@ -191,11 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={"auto_generate_conversation_name": False}, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -205,18 +222,23 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - stream=stream + stream=stream, ) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - context: contextvars.Context) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: """ Generate worker in a new thread. 
:param flask_app: Flask app :param application_generate_entity: application generate entity :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id :return: """ for var, val in context.items(): @@ -224,50 +246,40 @@ class WorkflowAppGenerator(BaseAppGenerator): with flask_app.app_context(): try: # workflow app - runner = WorkflowAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager - ) - except GenerateTaskStoppedException: + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + runner.run() + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() - def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool = False) -> Union[ - WorkflowAppBlockingResponse, - Generator[WorkflowAppStreamResponse, None, None] - ]: + def _handle_response( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Handle response. 
:param application_generate_entity: application generate entity @@ -283,14 +295,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream + stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index f448138b53..76371f800b 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -12,10 +12,7 @@ from core.app.entities.queue_entities import ( class WorkflowAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - app_mode: str) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: super().__init__(task_id, user_id, invoke_from) self._app_mode = app_mode @@ -27,20 +24,19 @@ class WorkflowAppQueueManager(AppQueueManager): :param pub_from: :return: """ - message = WorkflowQueueMessage( - task_id=self._task_id, - app_mode=self._app_mode, - event=event - ) + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent, + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index e388d0184b..81c8463dd5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -4,129 +4,125 @@ from typing import Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig -from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, EndUser -from models.workflow import Workflow +from 
models.workflow import WorkflowType logger = logging.getLogger(__name__) -class WorkflowAppRunner: +class WorkflowAppRunner(WorkflowBasedAppRunner): """ Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = workflow_thread_pool_id + + def run(self) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: - user_id = application_generate_entity.user_id + user_id = self.application_generate_entity.user_id app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') - - inputs = application_generate_entity.inputs - files = application_generate_entity.files + raise ValueError("Workflow not initialized") db.session.close() - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] - - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): + workflow_callbacks: list[WorkflowCallback] = [] + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) - # Create a variable pool. - system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. 
+ system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + } + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id, ) - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') + generator = workflow_entry.run(callbacks=workflow_callbacks) - if not app_record.workflow_id: - raise ValueError('Workflow not initialized') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') - - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] - - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks - ) - - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 88bde58ba0..08d00ee180 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): return cls.convert_blocking_full_response(blocking_response) @classmethod - def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, 
None]: """ Convert stream full response. :param stream_response: stream response @@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index de8542d7b9..93edf8e0e8 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Generator @@ -15,10 +16,12 @@ from core.app.entities.queue_entities import ( QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, - QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, @@ -32,19 +35,16 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, StreamResponse, TextChunkStreamResponse, - TextReplaceStreamResponse, WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowStreamGenerateNodes, + WorkflowStartStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -52,8 +52,8 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, + WorkflowRunStatus, ) logger = logging.getLogger(__name__) @@ -63,18 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state 
management for Application. """ + _workflow: Workflow _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] - _iteration_nested_relations: dict[str, list[str]] - def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -93,14 +96,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id + SystemVariableKey.USER_ID: user_id, } - self._task_state = WorkflowTaskState( - iteration_nested_node_ids=[] - ) - self._stream_generate_nodes = self._get_stream_generate_nodes() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) + self._task_state = WorkflowTaskState() def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -111,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa db.session.refresh(self._user) db.session.close() - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> WorkflowAppBlockingResponse: + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: """ To blocking response. 
:return: @@ -129,66 +125,69 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, WorkflowFinishStreamResponse): - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status, - outputs=workflow_run.outputs_dict, - error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()) - ) + id=stream_response.data.id, + workflow_id=stream_response.data.workflow_id, + status=stream_response.data.status, + outputs=stream_response.data.outputs, + error=stream_response.data.error, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=int(stream_response.data.created_at), + finished_at=int(stream_response.data.finished_at), + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[WorkflowAppStreamResponse, None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[WorkflowAppStreamResponse, None, None]: """ To stream response. 
:return: """ + workflow_run_id = None for stream_response in generator: - yield WorkflowAppStreamResponse( - workflow_run_id=self._task_state.workflow_run_id, - stream_response=stream_response - ) + if isinstance(stream_response, WorkflowStartStreamResponse): + workflow_run_id = stream_response.workflow_run_id - def _listenAudioMsg(self, publisher, task_id: str): + yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) + + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - - publisher = None + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -198,9 +197,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa start_listener_time = time.time() while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -213,105 +212,178 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) - + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
:return: """ - for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) - event = message.event + graph_runtime_state = None + workflow_run = None - if isinstance(event, QueueErrorEvent): + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run + workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception("Workflow run not initialized.") - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes: - self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id] + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() - - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - yield self._workflow_node_finish_to_stream_response( + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) + + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = 
self._handle_workflow_finished( - event, trace_manager=trace_manager + response = self._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) + if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs + else None, + conversation_id=None, + trace_manager=trace_manager, ) # save workflow app log self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + yield self._workflow_finish_to_stream_response( + 
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue + # only publish TTS messages while text chunks are streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text - yield self._text_chunk_to_stream_response(delta_text) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._text_replace_to_stream_response(event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self._text_chunk_to_stream_response( + delta_text, from_variable_selector=event.from_variable_selector + ) else: continue - if publisher: - publisher.publish(None) - + if tts_publisher: + tts_publisher.publish(None) def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ @@ -329,20 +401,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # do not save the log when debugging return - workflow_app_log = WorkflowAppLog( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, - created_from=created_from.value, - created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), - created_by=self._user.id, - ) + workflow_app_log = WorkflowAppLog() + workflow_app_log.tenant_id = workflow_run.tenant_id + workflow_app_log.app_id = workflow_run.app_id + workflow_app_log.workflow_id = workflow_run.workflow_id + workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.created_from = created_from.value + workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" + workflow_app_log.created_by = self._user.id + db.session.add(workflow_app_log) db.session.commit() db.session.close() - def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: + def _text_chunk_to_stream_response( + self, text: str, from_variable_selector: Optional[list[str]] = None + ) -> TextChunkStreamResponse: """ Text chunk to stream response. :param text: text chunk @@ -350,184 +424,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ response = TextChunkStreamResponse( task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text) + data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), ) return response - - def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: - """ - Text replace to stream response. - :param text: text - :return: - """ - return TextReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - text=TextReplaceStreamResponse.Data(text=text) - ) - - def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]: - """ - Get stream generate nodes. 
- :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - end_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.END.value - ] - - # parse stream output node value selectors of end nodes - stream_generate_routes = {} - for node_config in end_node_configs: - # get generate route for stream output - end_node_id = node_config['id'] - generate_nodes = EndNode.extract_generate_nodes(graph, node_config) - start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes( - end_node_id=end_node_id, - stream_node_ids=generate_nodes - ) - - return stream_generate_routes - - def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get end start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. 
- :return: - """ - if self._task_state.current_stream_generate_state: - stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids - - for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items(): - if node_id not in stream_node_ids: - continue - - node_execution_info = self._task_state.ran_node_execution_infos[node_id] - - # get chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first() - - if not route_chunk_node_execution: - continue - - outputs = route_chunk_node_execution.outputs_dict - - if not outputs: - continue - - # get value from outputs - text = outputs.get('text') - - if text: - self._task_state.answer += text - yield self._text_chunk_to_stream_response(text) - - db.session.close() - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return False - - if 'node_id' not in event.metadata: - return False - - node_id = event.metadata.get('node_id') - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - if node_id not in self._task_state.current_stream_generate_state.stream_node_ids: - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - return True - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py deleted file mode 100644 index 4472a7e9b5..0000000000 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - 
Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager.publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager.publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - 
PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - pass diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py new file mode 100644 index 0000000000..ce266116a7 --- /dev/null +++ b/api/core/app/apps/workflow_app_runner.py @@ -0,0 +1,371 @@ +from collections.abc import Mapping +from typing import Any, Optional, cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, + QueueRetrieverResourcesEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.node_mapping import node_classes +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowBasedAppRunner(AppRunner): + def __init__(self, queue_manager: AppQueueManager): + self.queue_manager = queue_manager + + def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: + """ + Init graph + """ + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") + + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") + + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") + # init graph + graph = Graph.init(graph_config=graph_config) + + if not graph: + raise ValueError("graph not found in workflow") + + return graph + + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + # fetch workflow graph + graph_config = workflow.graph_dict + if not graph_config: + raise ValueError("workflow graph not found") + + graph_config = cast(dict[str, Any], graph_config) + + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") + + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") + + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in 
workflow graph must be a list") + + # filter nodes only in iteration + node_configs = [ + node + for node in graph_config.get("nodes", []) + if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id + ] + + graph_config["nodes"] = node_configs + + node_ids = [node.get("id") for node in node_configs] + + # filter edges only in iteration + edge_configs = [ + edge + for edge in graph_config.get("edges", []) + if (edge.get("source") is None or edge.get("source") in node_ids) + and (edge.get("target") is None or edge.get("target") in node_ids) + ] + + graph_config["edges"] = edge_configs + + # init graph + graph = Graph.init(graph_config=graph_config, root_node_id=node_id) + + if not graph: + raise ValueError("graph not found in workflow") + + # fetch node config from node id + iteration_node_config = None + for node in node_configs: + if node.get("id") == node_id: + iteration_node_config = node + break + + if not iteration_node_config: + raise ValueError("iteration node id not found in workflow graph") + + # Get node class + node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=iteration_node_config + ) + except NotImplementedError: + variable_mapping = {} + + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=IterationNodeData(**iteration_node_config.get("data", {})), + ) + + return graph, variable_pool + + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + """ + Handle event + :param workflow_entry: workflow entry + :param event: event + """ + if isinstance(event, GraphRunStartedEvent): + self._publish_event( + QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state) + ) + elif isinstance(event, GraphRunSucceededEvent): + self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) + elif isinstance(event, GraphRunFailedEvent): + self._publish_event(QueueWorkflowFailedEvent(error=event.error)) + elif isinstance(event, NodeRunStartedEvent): + self._publish_event( + QueueNodeStartedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunSucceededEvent): + self._publish_event( + QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + 
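# node_run_result may be None here; each payload field below falls back to an empty mapping + 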
start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error + else "Unknown error", + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content, + from_variable_selector=event.from_variable_selector, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunSucceededEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + 
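# node_run_index below reuses the engine's global step counter (graph_runtime_state.node_run_steps) + 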
parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata, + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None, + ) + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 2e6431d6d0..cdd21bf7c2 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -1,10 +1,24 @@ from typing import Optional -from core.app.entities.queue_entities import AppQueueEvent from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -16,138 +30,184 @@ _TEXT_COLOR_MAPPING = { class WorkflowLoggingCallback(WorkflowCallback): - def __init__(self) -> None: self.current_node_id = None - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - 
self.print_text("\n[on_workflow_run_started]", color='pink') + def on_event(self, event: GraphEngineEvent) -> None: + if isinstance(event, GraphRunStartedEvent): + self.print_text("\n[GraphRunStartedEvent]", color="pink") + elif isinstance(event, GraphRunSucceededEvent): + self.print_text("\n[GraphRunSucceededEvent]", color="green") + elif isinstance(event, GraphRunFailedEvent): + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") + elif isinstance(event, NodeRunStartedEvent): + self.on_workflow_node_execute_started(event=event) + elif isinstance(event, NodeRunSucceededEvent): + self.on_workflow_node_execute_succeeded(event=event) + elif isinstance(event, NodeRunFailedEvent): + self.on_workflow_node_execute_failed(event=event) + elif isinstance(event, NodeRunStreamChunkEvent): + self.on_node_text_chunk(event=event) + elif isinstance(event, ParallelBranchRunStartedEvent): + self.on_workflow_parallel_started(event=event) + elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): + self.on_workflow_parallel_completed(event=event) + elif isinstance(event, IterationRunStartedEvent): + self.on_workflow_iteration_started(event=event) + elif isinstance(event, IterationRunNextEvent): + self.on_workflow_iteration_next(event=event) + elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): + self.on_workflow_iteration_completed(event=event) + else: + self.print_text(f"\n[{event.__class__.__name__}]", color="blue") - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self.print_text("\n[on_workflow_run_succeeded]", color='green') - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self.print_text("\n[on_workflow_run_failed]", color='red') - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: """ Workflow node execute started """ - self.print_text("\n[on_workflow_node_execute_started]", color='yellow') - self.print_text(f"Node ID: {node_id}", color='yellow') - self.print_text(f"Type: {node_type.value}", color='yellow') - self.print_text(f"Index: {node_run_index}", color='yellow') - if predecessor_node_id: - self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow') + self.print_text("\n[NodeRunStartedEvent]", color="yellow") + self.print_text(f"Node ID: {event.node_id}", color="yellow") + self.print_text(f"Node Title: {event.node_data.title}", color="yellow") + self.print_text(f"Type: {event.node_type.value}", color="yellow") - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: """ Workflow node execute succeeded """ - self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') - self.print_text(f"Node ID: {node_id}", color='green') - self.print_text(f"Type: {node_type.value}", color='green') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green') - 
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green') - self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}", - color='green') + route_node_state = event.route_node_state - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: + self.print_text("\n[NodeRunSucceededEvent]", color="green") + self.print_text(f"Node ID: {event.node_id}", color="green") + self.print_text(f"Node Title: {event.node_data.title}", color="green") + self.print_text(f"Type: {event.node_type.value}", color="green") + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green" + ) + self.print_text( + f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="green", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="green", + ) + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", + color="green", + ) + + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: """ Workflow node execute failed """ - self.print_text("\n[on_workflow_node_execute_failed]", color='red') - self.print_text(f"Node ID: {node_id}", color='red') - self.print_text(f"Type: {node_type.value}", color='red') - self.print_text(f"Error: {error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red') + route_node_state = event.route_node_state - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: + self.print_text("\n[NodeRunFailedEvent]", color="red") + self.print_text(f"Node ID: {event.node_id}", color="red") + self.print_text(f"Node Title: {event.node_data.title}", color="red") + self.print_text(f"Type: {event.node_type.value}", color="red") + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Error: {node_run_result.error}", color="red") + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red" + ) + self.print_text( + f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="red", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red" + ) + + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ Publish text chunk """ - if not self.current_node_id or self.current_node_id != node_id: - self.current_node_id = node_id - self.print_text('\n[on_node_text_chunk]') - self.print_text(f"Node ID: {node_id}") - self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}") + route_node_state = event.route_node_state + if not self.current_node_id or self.current_node_id != route_node_state.node_id: + self.current_node_id = route_node_state.node_id + 
self.print_text("\n[NodeRunStreamChunkEvent]") + self.print_text(f"Node ID: {route_node_state.node_id}") - self.print_text(text, color="pink", end="") + node_run_result = route_node_state.node_run_result + if node_run_result: + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}" + ) - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: + self.print_text(event.chunk_content, color="pink", end="") + + def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None: + """ + Publish parallel started + """ + self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue") + self.print_text(f"Parallel ID: {event.parallel_id}", color="blue") + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue") + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue") + + def on_workflow_parallel_completed( + self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + ) -> None: + """ + Publish parallel completed + """ + if isinstance(event, ParallelBranchRunSucceededEvent): + color = "blue" + elif isinstance(event, ParallelBranchRunFailedEvent): + color = "red" + + self.print_text( + "\n[ParallelBranchRunSucceededEvent]" + if isinstance(event, ParallelBranchRunSucceededEvent) + else "\n[ParallelBranchRunFailedEvent]", + color=color, + ) + self.print_text(f"Parallel ID: {event.parallel_id}", color=color) + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color) + + if isinstance(event, ParallelBranchRunFailedEvent): + self.print_text(f"Error: {event.error}", color=color) + + def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None: """ Publish iteration started """ - self.print_text("\n[on_workflow_iteration_started]", color='blue') - self.print_text(f"Node ID: {node_id}", color='blue') + self.print_text("\n[IterationRunStartedEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[dict]) -> None: + def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: """ Publish iteration next """ - self.print_text("\n[on_workflow_iteration_next]", color='blue') + self.print_text("\n[IterationRunNextEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + self.print_text(f"Iteration Index: {event.index}", color="blue") - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: """ Publish iteration completed """ - self.print_text("\n[on_workflow_iteration_completed]", color='blue') + self.print_text( + "\n[IterationRunSucceededEvent]" + if isinstance(event, IterationRunSucceededEvent) + else "\n[IterationRunFailedEvent]", + color="blue", + ) + self.print_text(f"Node ID: {event.iteration_id}", color="blue") - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - 
self.print_text("\n[on_workflow_event]", color='blue') - self.print_text(f"Event: {jsonable_encoder(event)}", color='blue') - - def print_text( - self, text: str, color: Optional[str] = None, end: str = "\n" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text - print(f'{text_to_print}', end=end) + print(f"{text_to_print}", end=end) def _get_colored_text(self, text: str, color: str) -> str: """Get colored text.""" diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6a1ab23041..ab8d4e374e 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -15,13 +15,14 @@ class InvokeFrom(Enum): """ Invoke From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': + def value_of(cls, value: str) -> "InvokeFrom": """ Get value of given mode. @@ -31,7 +32,7 @@ class InvokeFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid invoke from value {value}') + raise ValueError(f"invalid invoke from value {value}") def to_source(self) -> str: """ @@ -40,21 +41,22 @@ class InvokeFrom(Enum): :return: source """ if self == InvokeFrom.WEB_APP: - return 'web_app' + return "web_app" elif self == InvokeFrom.DEBUGGER: - return 'dev' + return "dev" elif self == InvokeFrom.EXPLORE: - return 'explore_app' + return "explore_app" elif self == InvokeFrom.SERVICE_API: - return 'api' + return "api" - return 'dev' + return "dev" class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. """ + provider: str model: str model_schema: AIModelEntity @@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel): """ App Generate Entity. """ + task_id: str # app config @@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ Chat Application Generate Entity. """ + # app config app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity @@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Completion Application Generate Entity. """ + pass @@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Agent Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Advanced Chat Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Single Iteration Run Entity. """ + node_id: str inputs: dict single_iteration_run: Optional[SingleIterationRunEntity] = None + class WorkflowAppGenerateEntity(AppGenerateEntity): """ Workflow Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ Single Iteration Run Entity. 
""" + node_id: str inputs: dict diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15348251f2..4577e28535 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,3 +1,4 @@ +from datetime import datetime from enum import Enum from typing import Any, Optional @@ -5,13 +6,15 @@ from pydantic import BaseModel, field_validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState class QueueEvent(str, Enum): """ QueueEvent enum """ + LLM_CHUNK = "llm_chunk" TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" @@ -31,6 +34,9 @@ class QueueEvent(str, Enum): ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" MESSAGE_FILE = "message_file" + PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" + PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" + PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" ERROR = "error" PING = "ping" STOP = "stop" @@ -38,46 +44,73 @@ class QueueEvent(str, Enum): class AppQueueEvent(BaseModel): """ - QueueEvent entity + QueueEvent abstract entity """ + event: QueueEvent class QueueLLMChunkEvent(AppQueueEvent): """ QueueLLMChunkEvent entity + Only for basic mode apps """ + event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + class QueueIterationStartEvent(AppQueueEvent): """ QueueIterationStartEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_START + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - inputs: dict = None + inputs: Optional[dict[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict] = None + metadata: Optional[dict[str, Any]] = None + class QueueIterationNextEvent(AppQueueEvent): """ QueueIterationNextEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_NEXT index: int + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" node_run_index: int - output: Optional[Any] = None # output for the current iteration + output: Optional[Any] = None # output for the current iteration - @field_validator('output', mode='before') + @field_validator("output", mode="before") @classmethod def set_output(cls, v): """ @@ -87,41 +120,66 @@ class QueueIterationNextEvent(AppQueueEvent): return None if isinstance(v, int | float | str | bool | dict | list): return v - raise ValueError('output must be a 
valid type') + raise ValueError("output must be a valid type") + class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ - event:QueueEvent = QueueEvent.ITERATION_COMPLETED + event: QueueEvent = QueueEvent.ITERATION_COMPLETED + + node_execution_id: str node_id: str node_type: NodeType - + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime + node_run_index: int - outputs: dict + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + error: Optional[str] = None + class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity """ + event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - metadata: Optional[dict] = None + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAgentMessageEvent(AppQueueEvent): """ QueueAgentMessageEvent entity """ + event: QueueEvent = QueueEvent.AGENT_MESSAGE chunk: LLMResultChunk - + class QueueMessageReplaceEvent(AppQueueEvent): """ QueueMessageReplaceEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_REPLACE text: str @@ -130,14 +188,18 @@ """ QueueRetrieverResourcesEvent entity """ + event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAnnotationReplyEvent(AppQueueEvent): """ QueueAnnotationReplyEvent entity """ + event: QueueEvent = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -146,6 +208,7 @@ """ QueueMessageEndEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_END llm_result: Optional[LLMResult] = None @@ -154,6 +217,7 @@ """ QueueAdvancedChatMessageEndEvent entity """ + event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END @@ -161,20 +225,25 @@ """ QueueWorkflowStartedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_STARTED + graph_runtime_state: GraphRuntimeState class QueueWorkflowSucceededEvent(AppQueueEvent): """ QueueWorkflowSucceededEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED + outputs: Optional[dict[str, Any]] = None class QueueWorkflowFailedEvent(AppQueueEvent): """ QueueWorkflowFailedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str @@ -183,29 +252,55 @@ """ QueueNodeStartedEvent entity """ + event: QueueEvent = QueueEvent.NODE_STARTED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData node_run_index: int = 1 predecessor_node_id: Optional[str] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + 
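# the parent_parallel_* fields identify the enclosing branch when parallels are nested + 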
parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime class QueueNodeSucceededEvent(AppQueueEvent): """ QueueNodeSucceededEvent entity """ + event: QueueEvent = QueueEvent.NODE_SUCCEEDED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None - execution_metadata: Optional[dict] = None + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: Optional[str] = None @@ -214,15 +309,28 @@ """ QueueNodeFailedEvent entity """ + event: QueueEvent = QueueEvent.NODE_FAILED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime - inputs: Optional[dict] = None - outputs: Optional[dict] = None - process_data: Optional[dict] = None + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None error: str @@ -231,6 +339,7 @@ """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.AGENT_THOUGHT agent_thought_id: str @@ -239,6 +348,7 @@ """ QueueMessageFileEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_FILE message_file_id: str @@ -247,6 +357,7 @@ """ QueueErrorEvent entity """ + event: QueueEvent = QueueEvent.ERROR error: Any = None @@ -255,6 +366,7 @@ """ QueuePingEvent entity """ + event: QueueEvent = QueueEvent.PING @@ -262,10 +374,12 @@ """ QueueStopEvent entity """ + class StopBy(Enum): """ Stop by enum """ + USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" @@ -274,11 +388,25 @@ event: QueueEvent = QueueEvent.STOP stopped_by: StopBy + def get_stop_reason(self) -> str: + """ + Get the stop reason + """ + reason_mapping = { + QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.", + QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.", + QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.", + 
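# stop sources missing from this mapping fall back to the default passed to .get() below + 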
QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.", + } + + return reason_mapping.get(self.stopped_by, "Stopped for an unknown reason.") + class QueueMessage(BaseModel): """ - QueueMessage entity + QueueMessage abstract entity """ + task_id: str app_mode: str event: AppQueueEvent @@ -288,6 +416,7 @@ """ MessageQueueMessage entity """ + message_id: str conversation_id: str @@ -296,4 +425,57 @@ """ WorkflowQueueMessage entity """ + pass + + +class QueueParallelBranchRunStartedEvent(AppQueueEvent): + """ + QueueParallelBranchRunStartedEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunSucceededEvent(AppQueueEvent): + """ + QueueParallelBranchRunSucceededEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunFailedEvent(AppQueueEvent): + """ + QueueParallelBranchRunFailedEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7bc5598984..49e5f55ebc 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -3,44 +3,16 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.entities import GenerateRouteChunk from models.workflow import WorkflowNodeExecutionStatus -class WorkflowStreamGenerateNodes(BaseModel): - """ - WorkflowStreamGenerateNodes entity - """ - end_node_id: str - stream_node_ids: list[str] - - -class ChatflowStreamGenerateRoute(BaseModel): - """ - ChatflowStreamGenerateRoute entity - """ - answer_node_id: str - generate_route: list[GenerateRouteChunk] - current_route_position: int = 0 - - -class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - node_type: NodeType - start_at: float - - class TaskState(BaseModel): """ TaskState entity """ + metadata: dict = {} @@ -48,6 +20,7 @@ """ EasyUITaskState entity 
""" + llm_result: LLMResult @@ -55,34 +28,15 @@ class WorkflowTaskState(TaskState): """ WorkflowTaskState entity """ + answer: str = "" - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None - - iteration_nested_node_ids: list[str] = None - - -class AdvancedChatTaskState(WorkflowTaskState): - """ - AdvancedChatTaskState entity - """ - usage: LLMUsage - - current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None - class StreamEvent(Enum): """ Stream event """ + PING = "ping" ERROR = "error" MESSAGE = "message" @@ -97,6 +51,8 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + PARALLEL_BRANCH_STARTED = "parallel_branch_started" + PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" @@ -108,6 +64,7 @@ class StreamResponse(BaseModel): """ StreamResponse entity """ + event: StreamEvent task_id: str @@ -119,6 +76,7 @@ class ErrorStreamResponse(StreamResponse): """ ErrorStreamResponse entity """ + event: StreamEvent = StreamEvent.ERROR err: Exception model_config = ConfigDict(arbitrary_types_allowed=True) @@ -128,15 +86,18 @@ class MessageStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE id: str answer: str + from_variable_selector: Optional[list[str]] = None class MessageAudioStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE audio: str @@ -145,6 +106,7 @@ class MessageAudioEndStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE_END audio: str @@ -153,6 +115,7 @@ class MessageEndStreamResponse(StreamResponse): """ MessageEndStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} @@ -162,6 +125,7 @@ class MessageFileStreamResponse(StreamResponse): """ MessageFileStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_FILE id: str type: str @@ -173,6 +137,7 @@ class MessageReplaceStreamResponse(StreamResponse): """ MessageReplaceStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_REPLACE answer: str @@ -181,6 +146,7 @@ class AgentThoughtStreamResponse(StreamResponse): """ AgentThoughtStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int @@ -196,6 +162,7 @@ class AgentMessageStreamResponse(StreamResponse): """ AgentMessageStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_MESSAGE id: str answer: str @@ -210,6 +177,7 @@ class WorkflowStartStreamResponse(StreamResponse): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -230,6 +198,7 @@ class WorkflowFinishStreamResponse(StreamResponse): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -258,6 +227,7 @@ class NodeStartStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -267,6 +237,11 @@ class NodeStartStreamResponse(StreamResponse): inputs: Optional[dict] = None created_at: int extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: 
Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -286,8 +261,13 @@ class NodeStartStreamResponse(StreamResponse): "predecessor_node_id": self.data.predecessor_node_id, "inputs": None, "created_at": self.data.created_at, - "extras": {} - } + "extras": {}, + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + }, } @@ -300,6 +280,7 @@ class NodeFinishStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -316,6 +297,11 @@ class NodeFinishStreamResponse(StreamResponse): created_at: int finished_at: int files: Optional[list[dict]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -342,11 +328,62 @@ class NodeFinishStreamResponse(StreamResponse): "execution_metadata": None, "created_at": self.data.created_at, "finished_at": self.data.finished_at, - "files": [] - } + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + }, } +class ParallelBranchStartStreamResponse(StreamResponse): + """ + ParallelBranchStartStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED + workflow_run_id: str + data: Data + + +class ParallelBranchFinishedStreamResponse(StreamResponse): + """ + ParallelBranchFinishedStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + status: str + error: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED + workflow_run_id: str + data: Data + + class IterationNodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity @@ -356,6 +393,7 @@ class IterationNodeStartStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -364,6 +402,8 @@ class IterationNodeStartStreamResponse(StreamResponse): extras: dict = {} metadata: dict = {} inputs: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -379,6 +419,7 @@ class IterationNodeNextStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -387,6 +428,8 @@ class IterationNodeNextStreamResponse(StreamResponse): created_at: int pre_iteration_output: Optional[Any] = None extras: dict = {} + parallel_id: Optional[str] = 
None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -402,14 +445,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str title: str outputs: Optional[dict] = None created_at: int - extras: dict = None - inputs: dict = None + extras: Optional[dict] = None + inputs: Optional[dict] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float @@ -417,6 +461,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): execution_metadata: Optional[dict] = None finished_at: int steps: int + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str @@ -432,7 +478,9 @@ class TextChunkStreamResponse(StreamResponse): """ Data entity """ + text: str + from_variable_selector: Optional[list[str]] = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -447,6 +495,7 @@ class TextReplaceStreamResponse(StreamResponse): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_REPLACE @@ -457,6 +506,7 @@ class PingStreamResponse(StreamResponse): """ PingStreamResponse entity """ + event: StreamEvent = StreamEvent.PING @@ -464,6 +514,7 @@ class AppStreamResponse(BaseModel): """ AppStreamResponse entity """ + stream_response: StreamResponse @@ -471,6 +522,7 @@ class ChatbotAppStreamResponse(AppStreamResponse): """ ChatbotAppStreamResponse entity """ + conversation_id: str message_id: str created_at: int @@ -480,6 +532,7 @@ class CompletionAppStreamResponse(AppStreamResponse): """ CompletionAppStreamResponse entity """ + message_id: str created_at: int @@ -488,13 +541,15 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ - workflow_run_id: str + + workflow_run_id: Optional[str] = None class AppBlockingResponse(BaseModel): """ AppBlockingResponse entity """ + task_id: str def to_dict(self) -> dict: @@ -510,6 +565,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str mode: str conversation_id: str @@ -530,6 +586,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str mode: str message_id: str @@ -549,6 +606,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str workflow_id: str status: str @@ -562,25 +620,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): workflow_run_id: str data: Data - - -class WorkflowIterationState(BaseModel): - """ - WorkflowIterationState entity - """ - - class Data(BaseModel): - """ - Data entity - """ - parent_iteration_id: Optional[str] = None - iteration_id: str - current_index: int - iteration_steps_boundary: list[int] = None - node_execution_id: str - started_at: float - inputs: dict = None - total_tokens: int = 0 - node_data: BaseNodeData - - current_iterations: dict[str, Data] = None diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 19ff94de5e..2e37a126c3 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -13,11 +13,9 @@ logger = logging.getLogger(__name__) class AnnotationReplyFeature: - def query(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query( + self, 
app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -27,8 +25,9 @@ class AnnotationReplyFeature: :param invoke_from: invoke from :return: """ - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_record.id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + ) if not annotation_setting: return None @@ -41,55 +40,50 @@ class AnnotationReplyFeature: embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" ) dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) documents = vector.search_by_vector( - query=query, - top_k=1, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) if documents: - annotation_id = documents[0].metadata['annotation_id'] - score = documents[0].metadata['score'] + annotation_id = documents[0].metadata["annotation_id"] + score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: - from_source = 'api' + from_source = "api" else: - from_source = 'console' + from_source = "console" # insert annotation history - AppAnnotationService.add_annotation_history(annotation.id, - app_record.id, - annotation.question, - annotation.content, - query, - user_id, - message.id, - from_source, - score) + AppAnnotationService.add_annotation_history( + annotation.id, + app_record.id, + annotation.question, + annotation.content, + query, + user_id, + message.id, + from_source, + score, + ) return annotation except Exception as e: - logger.warning(f'Query annotation failed, exception: {str(e)}.') + logger.warning(f"Query annotation failed, exception: {str(e)}.") return None return None diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index b8f3e0e1f6..ba14b61201 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -8,8 +8,9 @@ logger = logging.getLogger(__name__) class HostingModerationFeature: - def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list[PromptMessage]) -> bool: + def check( + self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage] + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -23,9 +24,6 @@ class HostingModerationFeature: if isinstance(prompt_message.content, str): text += 
prompt_message.content + "\n" - moderation_result = moderation.check_moderation( - model_config, - text - ) + moderation_result = moderation.check_moderation(model_config, text) return moderation_result diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f11e8021f0..227182f5ab 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict = {} - def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): + def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance @@ -27,13 +27,13 @@ class RateLimit: def __init__(self, client_id: str, max_active_requests: int): self.max_active_requests = max_active_requests - if hasattr(self, 'initialized'): + if hasattr(self, "initialized"): return self.initialized = True self.client_id = client_id self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id) self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id) - self.last_recalculate_time = float('-inf') + self.last_recalculate_time = float("-inf") self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): @@ -46,7 +46,7 @@ class RateLimit: pipe.execute() else: with redis_client.pipeline() as pipe: - self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8')) + self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) # flush max active requests (in-transit request list) @@ -54,8 +54,11 @@ class RateLimit: return request_details = redis_client.hgetall(self.active_requests_key) redis_client.expire(self.active_requests_key, timedelta(days=1)) - timeout_requests = [k for k, v in request_details.items() if - time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME] + timeout_requests = [ + k + for k, v in request_details.items() + if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME + ] if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) @@ -69,8 +72,10 @@ class RateLimit: active_requests_count = redis_client.hlen(self.active_requests_key) if active_requests_count >= self.max_active_requests: - raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum " - "concurrent requests allowed is {}.".format(self.max_active_requests)) + raise AppInvokeQuotaExceededError( + "Too many requests. Please try again later. 
The current maximum " + "concurrent requests allowed is {}.".format(self.max_active_requests) + ) redis_client.hset(self.active_requests_key, request_id, str(time.time())) return request_id @@ -116,5 +121,5 @@ class RateLimitGenerator: if not self.closed: self.closed = True self.rate_limit.exit(self.request_id) - if self.generator is not None and hasattr(self.generator, 'close'): + if self.generator is not None and hasattr(self.generator, "close"): self.generator.close() diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 7de06dfb96..652ef243b4 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -25,25 +25,25 @@ from .variables import ( ) __all__ = [ - 'IntegerVariable', - 'FloatVariable', - 'ObjectVariable', - 'SecretVariable', - 'StringVariable', - 'ArrayAnyVariable', - 'Variable', - 'SegmentType', - 'SegmentGroup', - 'Segment', - 'NoneSegment', - 'NoneVariable', - 'IntegerSegment', - 'FloatSegment', - 'ObjectSegment', - 'ArrayAnySegment', - 'StringSegment', - 'ArrayStringVariable', - 'ArrayNumberVariable', - 'ArrayObjectVariable', - 'ArraySegment', + "IntegerVariable", + "FloatVariable", + "ObjectVariable", + "SecretVariable", + "StringVariable", + "ArrayAnyVariable", + "Variable", + "SegmentType", + "SegmentGroup", + "Segment", + "NoneSegment", + "NoneVariable", + "IntegerSegment", + "FloatSegment", + "ObjectSegment", + "ArrayAnySegment", + "StringSegment", + "ArrayStringVariable", + "ArrayNumberVariable", + "ArrayObjectVariable", + "ArraySegment", ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index e6e9ce9774..40a69ed4eb 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -28,12 +28,12 @@ from .variables import ( def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if (value_type := mapping.get('value_type')) is None: - raise VariableError('missing value type') - if not mapping.get('name'): - raise VariableError('missing name') - if (value := mapping.get('value')) is None: - raise VariableError('missing value') + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") + if not mapping.get("name"): + raise VariableError("missing name") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.NUMBER if isinstance(value, float): result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f'invalid number value {value}') + raise VariableError(f"invalid number value {value}") case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): @@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) case _: - raise VariableError(f'not supported value type {value_type}') + raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + raise VariableError(f"variable size {result.size} exceeds limit 
{dify_config.MAX_VARIABLE_SIZE}") return result @@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment: return ObjectSegment(value=value) if isinstance(value, list): return ArrayAnySegment(value=value) - raise ValueError(f'not supported value {value}') + raise ValueError(f"not supported value {value}") diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py index de6c796652..3c4d7046f4 100644 --- a/api/core/app/segments/parser.py +++ b/api/core/app/segments/parser.py @@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool from . import SegmentGroup, factory -VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") def convert_template(*, template: str, variable_pool: VariablePool): parts = re.split(VARIABLE_PATTERN, template) segments = [] for part in filter(lambda x: x, parts): - if '.' in part and (value := variable_pool.get(part.split('.'))): + if "." in part and (value := variable_pool.get(part.split("."))): segments.append(value) else: segments.append(factory.build_segment(part)) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py index b4ff09b6d3..b363255b2c 100644 --- a/api/core/app/segments/segment_group.py +++ b/api/core/app/segments/segment_group.py @@ -8,15 +8,15 @@ class SegmentGroup(Segment): @property def text(self): - return ''.join([segment.text for segment in self.value]) + return "".join([segment.text for segment in self.value]) @property def log(self): - return ''.join([segment.log for segment in self.value]) + return "".join([segment.log for segment in self.value]) @property def markdown(self): - return ''.join([segment.markdown for segment in self.value]) + return "".join([segment.markdown for segment in self.value]) def to_object(self): return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 5c713cac67..b26b3c8291 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -14,13 +14,14 @@ class Segment(BaseModel): value_type: SegmentType value: Any - @field_validator('value_type') + @field_validator("value_type") + @classmethod def validate_value_type(cls, value): """ This validator checks if the provided value is equal to the default value of the 'value_type' field. If the value is different, a ValueError is raised. 
""" - if value != cls.model_fields['value_type'].default: + if value != cls.model_fields["value_type"].default: raise ValueError("Cannot modify 'value_type'") return value @@ -50,15 +51,15 @@ class NoneSegment(Segment): @property def text(self) -> str: - return 'null' + return "null" @property def log(self) -> str: - return 'null' + return "null" @property def markdown(self) -> str: - return 'null' + return "null" class StringSegment(Segment): @@ -76,24 +77,21 @@ class IntegerSegment(Segment): value: int - - - class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT value: Mapping[str, Any] @property def text(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False) + return json.dumps(self.model_dump()["value"], ensure_ascii=False) @property def log(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) @property def markdown(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) class ArraySegment(Segment): @@ -101,11 +99,11 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, 'to_markdown'): + if hasattr(item, "to_markdown"): items.append(item.to_markdown()) else: items.append(str(item)) - return '\n'.join(items) + return "\n".join(items) class ArrayAnySegment(ArraySegment): @@ -126,4 +124,3 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] - diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index cdd2b0b4b0..9cf0856df5 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -2,14 +2,14 @@ from enum import Enum class SegmentType(str, Enum): - NONE = 'none' - NUMBER = 'number' - STRING = 'string' - SECRET = 'secret' - ARRAY_ANY = 'array[any]' - ARRAY_STRING = 'array[string]' - ARRAY_NUMBER = 'array[number]' - ARRAY_OBJECT = 'array[object]' - OBJECT = 'object' + NONE = "none" + NUMBER = "number" + STRING = "string" + SECRET = "secret" + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + OBJECT = "object" - GROUP = 'group' + GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index 8fef707fcf..f0e403ab8d 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -23,11 +23,11 @@ class Variable(Segment): """ id: str = Field( - default='', + default="", description="Unique identity for variable. 
It's only used by environment variables now.", ) name: str - description: str = Field(default='', description='Description of the variable.') + description: str = Field(default="", description="Description of the variable.") class StringVariable(StringSegment, Variable): @@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index a3c1fb5824..49f58af12c 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline: _task_state: TaskState _application_generate_entity: AppGenerateEntity - def __init__(self, application_generate_entity: AppGenerateEntity, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: AppGenerateEntity, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -61,35 +64,39 @@ class BasedGenerateTaskPipeline: e = event.error if isinstance(e, InvokeAuthorizationError): - err = InvokeAuthorizationError('Incorrect API key provided') + err = InvokeAuthorizationError("Incorrect API key provided") elif isinstance(e, InvokeError) or isinstance(e, ValueError): err = e else: - err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) if message: - message = db.session.query(Message).filter(Message.id == message.id).first() - err_desc = self._error_to_desc(err) - message.status = 'error' - message.error = err_desc + refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.commit() + if refetch_message: + err_desc = self._error_to_desc(err) + refetch_message.status = "error" + refetch_message.error = err_desc + + db.session.commit() return err - def _error_to_desc(cls, e: Exception) -> str: + def _error_to_desc(self, e: Exception) -> str: """ Error to desc. :param e: exception :return: """ if isinstance(e, QuotaExceededError): - return ("Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.") + return ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) - message = getattr(e, 'description', str(e)) + message = getattr(e, "description", str(e)) if not message: - message = 'Internal Server Error, please contact support.' + message = "Internal Server Error, please contact support." 
return message @@ -99,10 +106,7 @@ class BasedGenerateTaskPipeline: :param e: exception :return: """ - return ErrorStreamResponse( - task_id=self._application_generate_entity.task_id, - err=e - ) + return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) def _ping_stream_response(self) -> PingStreamResponse: """ @@ -123,11 +127,8 @@ class BasedGenerateTaskPipeline: return OutputModeration( tenant_id=app_config.tenant_id, app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager + rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), + queue_manager=self._queue_manager, ) def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: @@ -141,8 +142,7 @@ class BasedGenerateTaskPipeline: self._output_moderation_handler.stop_thread() completion = self._output_moderation_handler.moderation_completion( - completion=completion, - public_event=False + completion=completion, public_event=False ) self._output_moderation_handler = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 8d91a507a9..659503301e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: EasyUITaskState - _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ] - def __init__(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool) -> None: + _task_state: EasyUITaskState + _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + + def __init__( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), - usage=LLMUsage.empty_usage() + usage=LLMUsage.empty_usage(), ) ) self._conversation_name_generate_thread = None def process( - self, + self, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Process generate task pipeline. 
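A note on the `process()` signature reformatted above: the pipeline builds one stream generator and either returns it directly or, in blocking mode, drains it down to the terminal event, as the `_to_blocking_response` hunk below shows. A minimal self-contained sketch of that control flow, with stand-in event types in place of the real task entities:

from collections.abc import Iterator
from dataclasses import dataclass


@dataclass
class ErrorEvent:  # stand-in for ErrorStreamResponse
    err: Exception


@dataclass
class MessageEndEvent:  # stand-in for MessageEndStreamResponse
    answer: str


def process(events: Iterator, stream: bool):
    if stream:
        # Streaming mode: hand the generator straight to the caller.
        return events
    # Blocking mode: drain the stream down to the terminal event.
    for event in events:
        if isinstance(event, ErrorEvent):
            raise event.err
        if isinstance(event, MessageEndEvent):
            return event.answer
    raise RuntimeError("Queue listening stopped unexpectedly.")


# Usage: blocking callers get the final answer, streaming callers iterate.
answer = process(iter([MessageEndEvent(answer="hi")]), stream=False)
assert answer == "hi"

The real method additionally folds usage and metadata into the blocking payload; only the generator-draining pattern is shown here.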
@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[ - ChatbotAppBlockingResponse, - CompletionAppBlockingResponse - ]: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: """ Process blocking response. :return: @@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = { - 'usage': jsonable_encoder(self._task_state.llm_result.usage) - } + extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: response = ChatbotAppBlockingResponse( @@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: """ To stream response. 
:return: @@ -198,37 +191,41 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan yield CompletionAppStreamResponse( message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') - if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): - publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) + text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + if ( + text_to_speech_dict + and text_to_speech_dict.get("autoPlay") == "enabled" + and text_to_speech_dict.get("enabled") + ): + publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id) + audio_response = self._listen_audio_msg(publisher, task_id) if audio_response: yield audio_response else: @@ -240,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: if publisher is None: break - audio = publisher.checkAndGetAudio() + audio = publisher.check_and_get_audio() if audio is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan break else: start_listener_time = time.time() - yield MessageAudioStreamResponse(audio=audio.audio, - task_id=task_id) - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
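The TTS changes above rename `_listenAudioMsg` / `checkAndGetAudio` to snake_case while keeping the same interleaving behavior: pending audio chunks are flushed before each text event, and after the text stream ends a drain loop polls until `TTS_AUTO_PLAY_TIMEOUT`, sleeping roughly 20 ms between empty polls. A rough sketch of that loop, assuming a duck-typed publisher exposing `check_and_get_audio()`; the timeout value here is a stand-in, not the real constant:

import time
from collections.abc import Generator, Iterable
from dataclasses import dataclass

TTS_AUTO_PLAY_TIMEOUT = 30  # seconds; stand-in for the real constant


@dataclass
class AudioTrunk:  # mirrors the field names used in the pipeline
    status: str
    audio: bytes


def interleave_audio(events: Iterable, publisher) -> Generator:
    # Flush any buffered audio before each text event.
    for event in events:
        while publisher is not None:
            audio = publisher.check_and_get_audio()
            if audio is None or audio.status == "finish":
                break
            yield ("audio", audio.audio)
        yield ("text", event)

    # Drain remaining audio; the timer resets whenever a chunk arrives.
    start = time.time()
    while publisher is not None and (time.time() - start) < TTS_AUTO_PLAY_TIMEOUT:
        audio = publisher.check_and_get_audio()
        if audio is None:
            time.sleep(0.02)  # ~20 ms between empty polls, as in the pipeline
        elif audio.status == "finish":
            break
        else:
            start = time.time()
            yield ("audio", audio.audio)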
@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message( - self, trace_manager: Optional[TraceQueueManager] = None - ) -> None: + def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( - self._model_config.mode, - self._task_state.llm_result.prompt_messages + self._model_config.mode, self._task_state.llm_result.prompt_messages ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ - if llm_result.message.content else '' + self._message.answer = ( + PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + if llm_result.message.content + else "" + ) self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, - conversation_id=self._conversation.id, - message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id ) ) @@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [ - AppMode.AGENT_CHAT, - AppMode.CHAT - ] and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] + and self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model = model_config.model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_instance.get_llm_num_tokens( - self._task_state.llm_result.prompt_messages - ) + prompt_tokens = 
model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_instance.get_llm_num_tokens( - [self._task_state.llm_result.message] - ) + completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message]) credentials = model_config.credentials @@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens + model, credentials, prompt_tokens, completion_tokens ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan Message end to stream response. :return: """ - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan :return: """ return AgentMessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer + task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: @@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan :return: """ agent_thought: MessageAgentThought = ( - db.session.query(MessageAgentThought) - .filter(MessageAgentThought.id == event.agent_thought_id) - .first() + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) db.session.close() @@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan tool=agent_thought.tool, tool_labels=agent_thought.tool_labels, tool_input=agent_thought.tool_input, - message_files=agent_thought.files + message_files=agent_thought.files, ) return None @@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content), + ), ) - ), PublishFrom.TASK_PIPELINE + ), + PublishFrom.TASK_PIPELINE, ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git 
a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 76c50809cf..5872e00740 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, @@ -16,11 +15,11 @@ from core.app.entities.queue_entities import ( QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager @@ -31,12 +30,9 @@ from services.annotation_service import AppAnnotationService class MessageCycleManage: _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] - _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + _task_state: Union[EasyUITaskState, WorkflowTaskState] def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: """ @@ -45,17 +41,23 @@ class MessageCycleManage: :param query: query :return: thread """ + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + return None + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) if auto_generate_conversation_name and is_first_message: # start generate thread - thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'conversation_id': conversation.id, - 'query': query - }) + thread = Thread( + target=self._generate_conversation_name_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "conversation_id": conversation.id, + "query": query, + }, + ) thread.start() @@ -63,17 +65,13 @@ class MessageCycleManage: return None - def _generate_conversation_name_worker(self, - flask_app: Flask, - conversation_id: str, - query: str): + def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + + if not conversation: + return if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app @@ -100,12 +98,9 @@ class MessageCycleManage: annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } + self._task_state.metadata["annotation_reply"] = { + "id": annotation.id, + 
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, } return annotation @@ -119,28 +114,7 @@ class MessageCycleManage: :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata['retriever_resources'] = event.retriever_resources - - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata + self._task_state.metadata["retriever_resources"] = event.retriever_resources def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ @@ -148,27 +122,23 @@ class MessageCycleManage: :param event: event :return: """ - message_file: MessageFile = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) + message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() if message_file: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url to local file - if message_file.url.startswith('http'): + if message_file.url.startswith("http"): url = message_file.url else: url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) @@ -177,13 +147,15 @@ class MessageCycleManage: task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or 'user', - url=url + belongs_to=message_file.belongs_to or "user", + url=url, ) return None - def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse: + def _message_to_stream_response( + self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + ) -> MessageStreamResponse: """ Message to stream response. 
:param answer: answer @@ -193,7 +165,8 @@ class MessageCycleManage: return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, - answer=answer + answer=answer, + from_variable_selector=from_variable_selector, ) def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: @@ -202,7 +175,4 @@ class MessageCycleManage: :param answer: answer :return: """ - return MessageReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - answer=answer - ) + return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4935c43ac4..a030d5dcbf 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,33 +1,41 @@ import json import time from datetime import datetime, timezone -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, ) from core.app.entities.task_entities import ( - NodeExecutionInfo, + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, + ParallelBranchFinishedStreamResponse, + ParallelBranchStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, + WorkflowTaskState, ) -from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -41,54 +49,56 @@ from models.workflow import ( WorkflowRunStatus, WorkflowRunTriggeredFrom, ) -from services.workflow_service import WorkflowService -class WorkflowCycleManage(WorkflowIterationCycleManage): - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, 
files - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 +class WorkflowCycleManage: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: WorkflowTaskState + _workflow_system_variables: dict[SystemVariableKey, Any] + + def _handle_workflow_run_start(self) -> WorkflowRun: + max_sequence = ( + db.session.query(db.func.max(WorkflowRun.sequence_number)) + .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) + .filter(WorkflowRun.app_id == self._workflow.app_id) + .scalar() + or 0 + ) new_sequence_number = max_sequence + 1 - inputs = {**user_inputs} - for key, value in (system_inputs or {}).items(): - if key.value == 'conversation': + inputs = {**self._application_generate_entity.inputs} + for key, value in (self._workflow_system_variables or {}).items(): + if key.value == "conversation": continue - inputs[f'sys.{key.value}'] = value - inputs = WorkflowEngineManager.handle_special_values(inputs) + inputs[f"sys.{key.value}"] = value + + inputs = WorkflowEntry.handle_special_values(inputs) + + triggered_from = ( + WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN + ) # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps(inputs), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id + workflow_run = WorkflowRun() + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING.value + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value ) + workflow_run.created_by = self._user.id db.session.add(workflow_run) db.session.commit() @@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_run - def _workflow_run_success( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_success( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, outputs: Optional[str] = None, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run success :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param outputs: outputs :param conversation_id: conversation id :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value workflow_run.outputs 
= outputs - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) - db.session.close() if trace_manager: trace_manager.add_trace_task( @@ -135,34 +149,64 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): ) ) + db.session.close() + return workflow_run - def _workflow_run_failed( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_failed( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, error: str, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run failed :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param status: status :param error: error message :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = status.value workflow_run.error = error - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() + + running_workflow_node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + ) + .all() + ) + + for workflow_node_execution in running_workflow_node_executions: + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = ( + workflow_node_execution.finished_at - workflow_node_execution.created_at + ).total_seconds() + db.session.commit() + db.session.refresh(workflow_run) db.session.close() @@ -178,39 +222,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_run - def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run: workflow run - :param node_id: node id - :param node_type: node type - :param node_title: node title - :param node_run_index: run index - :param predecessor_node_id: predecessor node id if exists - :return: - """ + def _handle_node_execution_start( + self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + ) -> WorkflowNodeExecution: # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - 
workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.add(workflow_node_execution) db.session.commit() @@ -219,33 +250,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: """ Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param inputs: inputs - :param process_data: process data - :param outputs: outputs - :param execution_metadata: execution metadata + :param event: queue node succeeded event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) 
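# A minimal sketch of the two elapsed-time conventions this refactor uses:
# workflow runs keep a perf_counter delta (time.perf_counter() - start_at),
# while node executions now subtract two naive-UTC datetimes, as in the lines
# that follow. The values below are hypothetical stand-ins for event.start_at
# and the recorded finish time.
from datetime import datetime, timedelta, timezone

start_at = datetime.now(timezone.utc).replace(tzinfo=None)  # stand-in for event.start_at
finished_at = start_at + timedelta(seconds=1.5)             # stand-in for the recorded finish time
elapsed_time = (finished_at - start_at).total_seconds()     # 1.5, same pattern as the handler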
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -253,33 +277,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None - ) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: """ Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message + :param event: queue node failed event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.error = event.error workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -287,8 +302,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_start_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + ################################################# + # to stream responses # + ################################################# + + def _workflow_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowStartStreamResponse: """ Workflow start to stream response. :param task_id: task id @@ -302,13 +322,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, - created_at=int(workflow_run.created_at.timestamp()) - ) + inputs=workflow_run.inputs_dict or {}, + created_at=int(workflow_run.created_at.timestamp()), + ), ) - def _workflow_finish_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + def _workflow_finish_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowFinishStreamResponse: """ Workflow finish to stream response. 
:param task_id: task id @@ -348,14 +369,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), + ), ) - def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution) \ - -> NodeStartStreamResponse: + def _workflow_node_start_to_stream_response( + self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + ) -> Optional[NodeStartStreamResponse]: """ Workflow node start to stream response. :param event: queue node started event @@ -363,6 +383,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + response = NodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -374,29 +397,42 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs_dict, - created_at=int(workflow_node_execution.created_at.timestamp()) - ) + created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), ) # extras logic if event.node_type == NodeType.TOOL: node_data = cast(ToolNodeData, event.node_data) - response.data.extras['icon'] = ToolManager.get_tool_icon( + response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, - provider_id=node_data.provider_id + provider_id=node_data.provider_id, ) return response - def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ - -> NodeFinishStreamResponse: + def _workflow_node_finish_to_stream_response( + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. 
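# A hedged sketch of the Optional contract the node start/finish builders now
# follow: ITERATION and LOOP container nodes yield None, so a caller (outside
# this excerpt) is expected to filter before streaming. Names are illustrative.
from typing import Optional

def drop_container_events(events: list[Optional[str]]) -> list[str]:
    # None marks an ITERATION/LOOP node that produces no per-node stream event
    return [event for event in events if event is not None]

assert drop_container_events(["node_started", None, "node_finished"]) == ["node_started", "node_finished"]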
+ :param event: queue node succeeded or failed event :param task_id: task id :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -416,181 +452,153 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): execution_metadata=workflow_node_execution.execution_metadata_dict, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), ) - def _handle_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs=self._workflow_system_variables + def _workflow_parallel_branch_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: + """ + Workflow parallel branch start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run started event + :return: + """ + return ParallelBranchStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchStartStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + created_at=int(time.time()), + ), ) - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id + def _workflow_parallel_branch_finished_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> ParallelBranchFinishedStreamResponse: + """ + Workflow parallel branch finished to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run succeeded or failed event + :return: + """ + return ParallelBranchFinishedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchFinishedStreamResponse.Data( + parallel_id=event.parallel_id, + 
parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", + error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, + created_at=int(time.time()), + ), ) - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=event.node_type, - start_at=time.perf_counter() + def _workflow_iteration_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: + """ + Workflow iteration start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration start event + :return: + """ + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), ) - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info + def _workflow_iteration_next_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + ) -> IterationNodeNextStreamResponse: + """ + Workflow iteration next to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration next event + :return: + """ + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), + ) - self._task_state.total_steps += 1 - - db.session.close() - - return workflow_node_execution - - def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - - execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None - - if self._iteration_state and self._iteration_state.current_iterations: - if not execution_metadata: - execution_metadata = {} - current_iteration_data = None - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if data.parent_iteration_id == None: - current_iteration_data = data - break - - if current_iteration_data: - execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id - execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index - - if isinstance(event, QueueNodeSucceededEvent): 
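# The removed _handle_node_finished resolved the execution record through the
# in-memory _task_state.ran_node_execution_infos map; its replacements refetch
# it from the database by the event's node_execution_id instead (see
# _refetch_workflow_node_execution near the end of this file). A minimal sketch
# of that lookup, assuming the db.session and WorkflowNodeExecution names from
# this diff:
workflow_node_execution = (
    db.session.query(WorkflowNodeExecution)
    .filter(WorkflowNodeExecution.node_execution_id == event.node_execution_id)
    .first()
)
if not workflow_node_execution:
    raise Exception(f"Workflow node execution not found: {event.node_execution_id}")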
- workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, + def _workflow_iteration_completed_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + ) -> IterationNodeCompletedStreamResponse: + """ + Workflow iteration completed to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration completed event + :return: + """ + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, outputs=event.outputs, - execution_metadata=execution_metadata - ) - - if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - - if self._iteration_state: - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=execution_metadata - ) - - db.session.close() - - return workflow_node_execution - - def _handle_workflow_finished( - self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None - ) -> Optional[WorkflowRun]: - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - if not workflow_run: - return None - - if conversation_id is None: - conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id') - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.', - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - latest_node_execution_info = self._task_state.latest_node_execution_info - if latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() - if (workflow_node_execution - and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=latest_node_execution_info.start_at, - error='Workflow stopped.' 
- ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), + total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), + ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: """ @@ -641,9 +649,45 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() return None + + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Refetch workflow run + :param workflow_run_id: workflow run id + :return: + """ + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + + if not workflow_run: + raise Exception(f"Workflow run not found: {workflow_run_id}") + + return workflow_run + + def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + """ + Refetch workflow node execution + :param node_execution_id: workflow node execution id + :return: + """ + workflow_node_execution = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, + WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, + WorkflowNodeExecution.workflow_id == self._workflow.id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.node_execution_id == node_execution_id, + ) + .first() + ) + + if not workflow_node_execution: + raise Exception(f"Workflow node execution not found: {node_execution_id}") + + return workflow_node_execution diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index bd98c82720..e69de29bb2 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -1,16 +0,0 @@ -from typing import Any, Union - -from core.app.entities.app_invoke_entities 
import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariableKey -from models.account import Account -from models.model import EndUser -from models.workflow import Workflow - - -class WorkflowCycleStateManager: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariableKey, Any] diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py deleted file mode 100644 index aff1870714..0000000000 --- a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py +++ /dev/null @@ -1,290 +0,0 @@ -import json -import time -from collections.abc import Generator -from datetime import datetime, timezone -from typing import Optional, Union - -from core.app.entities.queue_entities import ( - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, -) -from core.app.entities.task_entities import ( - IterationNodeCompletedStreamResponse, - IterationNodeNextStreamResponse, - IterationNodeStartStreamResponse, - NodeExecutionInfo, - WorkflowIterationState, -) -from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.workflow_engine_manager import WorkflowEngineManager -from extensions.ext_database import db -from models.workflow import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, -) - - -class WorkflowIterationCycleManage(WorkflowCycleStateManager): - _iteration_state: WorkflowIterationState = None - - def _init_iteration_state(self) -> WorkflowIterationState: - if not self._iteration_state: - self._iteration_state = WorkflowIterationState( - current_iterations={} - ) - - def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ - -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: - """ - Handle iteration to stream response - :param task_id: task id - :param event: iteration event - :return: - """ - if isinstance(event, QueueIterationStartEvent): - return IterationNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs, - metadata=event.metadata - ) - ) - elif isinstance(event, QueueIterationNextEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - index=event.index, - pre_iteration_output=event.output, - created_at=int(time.time()), - extras={} - ) - ) - elif isinstance(event, QueueIterationCompletedEvent): - current_iteration = 
self._iteration_state.current_iterations[event.node_id] - - return IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - error=None, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) - - def _init_iteration_execution_from_workflow_run(self, - workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - inputs=json.dumps(inputs) if inputs else None, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - execution_metadata=json.dumps({ - 'started_run_index': node_run_index + 1, - 'current_index': 0, - 'steps_boundary': [], - }), - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: - if isinstance(event, QueueIterationStartEvent): - return self._handle_iteration_started(event) - elif isinstance(event, QueueIterationNextEvent): - return self._handle_iteration_next(event) - elif isinstance(event, QueueIterationCompletedEvent): - return self._handle_iteration_completed(event) - - def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: - self._init_iteration_state() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_iteration_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=NodeType.ITERATION, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( - parent_iteration_id=None, - iteration_id=event.node_id, - current_index=0, - 
iteration_steps_boundary=[], - node_execution_id=workflow_node_execution.id, - started_at=time.perf_counter(), - inputs=event.inputs, - total_tokens=0, - node_data=event.node_data - ) - - db.session.close() - - return workflow_node_execution - - def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: - if event.node_id not in self._iteration_state.current_iterations: - return - current_iteration = self._iteration_state.current_iterations[event.node_id] - current_iteration.current_index = event.index - current_iteration.iteration_steps_boundary.append(event.node_run_index) - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['current_index'] = event.index - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - db.session.close() - - def _handle_iteration_completed(self, event: QueueIterationCompletedEvent): - if event.node_id not in self._iteration_state.current_iterations: - return - - current_iteration = self._iteration_state.current_iterations[event.node_id] - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - # remove current iteration - self._iteration_state.current_iterations.pop(event.node_id, None) - - # set latest node execution info - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.latest_node_execution_info = latest_node_execution_info - - db.session.close() - - def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: - """ - Handle iteration exception - """ - if not self._iteration_state or not self._iteration_state.current_iterations: - return - - for node_id, current_iteration in self._iteration_state.current_iterations.items(): - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = 
time.perf_counter() - current_iteration.started_at - - db.session.commit() - db.session.close() - - yield IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=node_id, - node_id=node_id, - node_type=NodeType.ITERATION.value, - title=current_iteration.node_data.title, - outputs={}, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 5789965747..99e992fd89 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = { "red": "31;1", } + def get_colored_text(text: str, color: str) -> str: """Get colored text.""" color_str = _TEXT_COLOR_MAPPING[color] return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) if file: file.flush() # ensure all printed content are written to file + class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = '' + + color: Optional[str] = "" current_loop: int = 1 def __init__(self, color: Optional[str] = None) -> None: super().__init__() """Initialize callback handler.""" # use a specific color is not specified - self.color = color or 'green' + self.color = color or "green" self.current_loop = 1 def on_tool_start( @@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel): tool_outputs: Sequence[ToolInvokeMessage], message_id: Optional[str] = None, timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> None: """If not the final action, print out observation.""" print_text("\n[on_tool_end]\n", color=self.color) @@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel): ) ) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") - def on_agent_start( - self, thought: str - ) -> None: + def on_agent_start(self, thought: str) -> None: """Run on agent start.""" if thought: - print_text("\n[on_agent_start] \nCurrent Loop: " + \ - str(self.current_loop) + \ - "\nThought: " + thought + "\n", color=self.color) + print_text( + "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n", + color=self.color, + ) else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", 
color=self.color) - def on_agent_finish( - self, color: Optional[str] = None, **kwargs: Any - ) -> None: + def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: """Run on agent end.""" print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) @@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel): @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" @property def ignore_chat_model(self) -> bool: """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8e1f496b22..6d5393ce5c 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,4 +1,3 @@ - from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: AppQueueManager, - app_id: str, - message_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__( + self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom + ) -> None: self._queue_manager = queue_manager self._app_id = app_id self._message_id = message_id @@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=self._app_id, - created_by_role=('account' - if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), - created_by=self._user_id + created_by_role=( + "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + ), + created_by=self._user_id, ) db.session.add(dataset_query) @@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler: """Handle tool end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() @@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler: for item in resource: dataset_retriever_resource = DatasetRetrieverResource( message_id=self._message_id, - position=item.get('position'), - dataset_id=item.get('dataset_id'), - dataset_name=item.get('dataset_name'), - 
document_id=item.get('document_id'), - document_name=item.get('document_name'), - data_source_type=item.get('data_source_type'), - segment_id=item.get('segment_id'), - score=item.get('score') if 'score' in item else None, - hit_count=item.get('hit_count') if 'hit_count' else None, - word_count=item.get('word_count') if 'word_count' in item else None, - segment_position=item.get('segment_position') if 'segment_position' in item else None, - index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, - content=item.get('content'), - retriever_from=item.get('retriever_from'), - created_by=self._user_id + position=item.get("position"), + dataset_id=item.get("dataset_id"), + dataset_name=item.get("dataset_name"), + document_id=item.get("document_id"), + document_name=item.get("document_name"), + data_source_type=item.get("data_source_type"), + segment_id=item.get("segment_id"), + score=item.get("score") if "score" in item else None, + hit_count=item.get("hit_count") if "hit_count" in item else None, + word_count=item.get("word_count") if "word_count" in item else None, + segment_position=item.get("segment_position") if "segment_position" in item else None, + index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, + content=item.get("content"), + retriever_from=item.get("retriever_from"), + created_by=self._user_id, ) db.session.add(dataset_retriever_resource) db.session.commit() self._queue_manager.publish( - QueueRetrieverResourcesEvent(retriever_resources=resource), - PublishFrom.APPLICATION_MANAGER + QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 84bab7e1a3..8ac12f72f2 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): - """Callback Handler that prints to std out.""" \ No newline at end of file + """Callback Handler that prints to std out.""" diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index b7e0cc0c2b..4cc793b0d7 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider).first() + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider + ) + .first() + ) if embedding: text_embeddings[i] = embedding.get_embedding() else: @@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings): embedding_queue_embeddings = [] try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, - self._model_instance.credentials) - max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 + model_schema = model_type_instance.get_model_schema( + 
self._model_instance.model, self._model_instance.credentials
+            )
+            max_chunks = (
+                model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
+                else 1
+            )
             for i in range(0, len(embedding_queue_texts), max_chunks):
-                batch_texts = embedding_queue_texts[i:i + max_chunks]
+                batch_texts = embedding_queue_texts[i : i + max_chunks]
 
-                embedding_result = self._model_instance.invoke_text_embedding(
-                    texts=batch_texts,
-                    user=self._user
-                )
+                embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
 
                 for vector in embedding_result.embeddings:
                     try:
@@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
                     except IntegrityError:
                         db.session.rollback()
         except Exception as e:
-            logging.exception('Failed transform embedding: ', e)
+            logging.exception("Failed to transform embedding: %s", e)
 
         cache_embeddings = []
         try:
             for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
                 text_embeddings[i] = embedding
                 hash = helper.generate_text_hash(texts[i])
                 if hash not in cache_embeddings:
-                    embedding_cache = Embedding(model_name=self._model_instance.model,
-                                                hash=hash,
-                                                provider_name=self._model_instance.provider)
+                    embedding_cache = Embedding(
+                        model_name=self._model_instance.model,
+                        hash=hash,
+                        provider_name=self._model_instance.provider,
+                    )
                     embedding_cache.set_embedding(embedding)
                     db.session.add(embedding_cache)
                     cache_embeddings.append(hash)
@@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
             db.session.rollback()
         except Exception as ex:
             db.session.rollback()
-            logger.error('Failed to embed documents: ', ex)
+            logger.error("Failed to embed documents: %s", ex)
             raise ex
 
         return text_embeddings
@@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
+        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
             return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
         try:
-            embedding_result = self._model_instance.invoke_text_embedding(
-                texts=[text],
-                user=self._user
-            )
+            embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
 
             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
@@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
         except IntegrityError:
             db.session.rollback()
         except:
-            logging.exception('Failed to add embedding to redis')
+            logging.exception("Failed to add embedding to redis")
 
         return embedding_results
diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py
index 0cdf8670c4..656bf4aa72 100644
--- a/api/core/entities/agent_entities.py
+++ b/api/core/entities/agent_entities.py
@@ -2,7 +2,7 @@ from enum import Enum
 
 
 class PlanningStrategy(Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
+    ROUTER = "router"
+    REACT_ROUTER = "react_router"
+    REACT = "react"
+    FUNCTION_CALL = "function_call"
diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py
index 370aeee463..10bc9f6ed7 100644
--- a/api/core/entities/message_entities.py
+++ b/api/core/entities/message_entities.py
@@
-5,7 +5,7 @@ from pydantic import BaseModel class PromptMessageFileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel): class ImagePromptMessageFile(PromptMessageFile): class DETAIL(enum.Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageFileType = PromptMessageFileType.IMAGE detail: DETAIL = DETAIL.LOW diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 22a21ecf93..9ed5528e43 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -12,6 +12,7 @@ class ModelStatus(Enum): """ Enum class for model status. """ + ACTIVE = "active" NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" @@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel): """ Simple provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_large=provider_entity.icon_large, - supported_model_types=provider_entity.supported_model_types + supported_model_types=provider_entity.supported_model_types, ) @@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel): """ Model class for model response. """ + status: ModelStatus load_balancing_enabled: bool = False @@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. """ + provider: SimpleModelProviderEntity @@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel): """ Default model provider entity. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: DefaultModelProviderEntity diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 778ef2e1ac..4797b69b85 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel): """ Model class for provider configuration. 
""" + tenant_id: str provider: ProviderEntity preferred_provider_type: ProviderType @@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel): original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any(len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations) - and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): + if ( + any( + len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations + ) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods + ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: @@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel): if self.model_settings: # check if model is disabled by admin for model_setting in self.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: if not model_setting.enabled: - raise ValueError(f'Model {model} is disabled.') + raise ValueError(f"Model {model} is disabled.") if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] @@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel): copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: - if (restrict_model.model_type == model_type - and restrict_model.model == model - and restrict_model.base_model_name): - copy_credentials['base_model_name'] = restrict_model.base_model_name + if ( + restrict_model.model_type == model_type + and restrict_model.model == model + and restrict_model.base_model_name + ): + copy_credentials["base_model_name"] = restrict_model.base_model_name return copy_credentials else: @@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel): current_quota_type = self.system_configuration.current_quota_type current_quota_configuration = next( - (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), - None + (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) - return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ - SystemConfigurationStatus.QUOTA_EXCEEDED + return ( + SystemConfigurationStatus.ACTIVE + if current_quota_configuration.is_valid + else SystemConfigurationStatus.QUOTA_EXCEEDED + ) def is_custom_configuration_available(self) -> bool: """ Check custom configuration available. 
:return: """ - return (self.custom_configuration.provider is not None - or len(self.custom_configuration.models) > 0) + return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: """ @@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel): return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [], ) def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: @@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [] ) if provider_record: @@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel): # fix origin data if provider_record.encrypted_config: if not provider_record.encrypted_config.startswith("{"): - original_credentials = { - "openai_api_key": provider_record.encrypted_config - } + original_credentials = {"openai_api_key": provider_record.encrypted_config} else: original_credentials = json.loads(provider_record.encrypted_config) else: @@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel): credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, - credentials=credentials + provider=self.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, provider_type=ProviderType.CUSTOM.value, encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER ) provider_model_credentials_cache.delete() @@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # delete provider if provider_record: @@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel): 
provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: + def get_custom_model_credentials( + self, model_type: ModelType, model: str, obfuscated: bool = False + ) -> Optional[dict]: """ Get custom model credentials. @@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel): return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [], ) return None - def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> tuple[ProviderModel, dict]: + def custom_model_credentials_validate( + self, model_type: ModelType, model: str, credentials: dict + ) -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [] ) if provider_model_record: try: - original_credentials = json.loads( - provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + original_credentials = ( + json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + ) except JSONDecodeError: original_credentials = {} @@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel): credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, - model_type=model_type, - model=model, - credentials=credentials + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) for key, value in credentials.items(): @@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel): model_name=model, model_type=model_type.to_origin_model_type(), encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_model_record) db.session.commit() @@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -409,13 +424,16 @@ class 
ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # delete provider model if provider_model_record: @@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = True @@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=True + enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = False @@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=False + enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - return db.session.query(ProviderModelSetting) \ + return ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + 
ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + load_balancing_config_count = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).count() + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .count() + ) if load_balancing_config_count <= 1: - raise ValueError('Model load balancing configuration must be more than 1.') + raise ValueError("Model load balancing requires more than one configuration.") - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = True @@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=True + load_balancing_enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = False @@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=False + load_balancing_enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel): return # get preferred provider - preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + preferred_model_provider = ( + 
db.session.query(TenantPreferredModelProvider) .filter( - TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider - ).first() + TenantPreferredModelProvider.tenant_id == self.tenant_id, + TenantPreferredModelProvider.provider_name == self.provider.provider, + ) + .first() + ) if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value @@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value + preferred_provider_type=provider_type.value, ) db.session.add(preferred_model_provider) @@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel): :return: """ # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - credential_form_schemas - ) + credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() @@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel): return copy_credentials - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: + def get_provider_model( + self, model_type: ModelType, model: str, only_active: bool = False + ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type @@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel): return None - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: + def get_provider_models( + self, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type @@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel): if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) if only_active: @@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel): # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_system_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_system_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get system provider models. 
@@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel): if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ - ConfigurateMethod.CUSTOMIZABLE_MODEL]: + ConfigurateMethod.CUSTOMIZABLE_MODEL + ]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() if restrict_model.base_model_name: - copy_credentials['base_model_name'] = restrict_model.base_model_name + copy_credentials["base_model_name"] = restrict_model.base_model_name try: - custom_model_schema = ( - provider_instance.get_model_instance(restrict_model.model_type) - .get_customizable_model_schema_from_credentials( - restrict_model.model, - copy_credentials - ) - ) + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel): continue status = ModelStatus.ACTIVE - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel): return provider_models - def _get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_custom_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get custom provider models. 
@@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel): deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel): continue try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) + custom_model_schema = provider_instance.get_model_instance( + model_configuration.model_type + ).get_customizable_model_schema_from_credentials( + model_configuration.model, model_configuration.credentials ) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE load_balancing_enabled = False - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel): deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ + tenant_id: str configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: + def get_models( + self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel): """ Provider model bundle. """ + configuration: ProviderConfiguration provider_instance: ModelProvider model_type_instance: AIModel # pydantic configs - model_config = ConfigDict(arbitrary_types_allowed=True, - protected_namespaces=()) + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0d5b0a1b2c..44725623dc 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -8,18 +8,19 @@ from models.provider import ProviderQuotaType class QuotaUnit(Enum): - TIMES = 'times' - TOKENS = 'tokens' - CREDITS = 'credits' + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" class SystemConfigurationStatus(Enum): """ Enum class for system configuration status. 
""" - ACTIVE = 'active' - QUOTA_EXCEEDED = 'quota-exceeded' - UNSUPPORTED = 'unsupported' + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" class RestrictModel(BaseModel): @@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel): """ Model class for provider quota configuration. """ + quota_type: ProviderQuotaType quota_unit: QuotaUnit quota_limit: int @@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel): """ Model class for provider system configuration. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. """ + credentials: dict @@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel): """ Model class for provider custom model configuration. """ + model: str model_type: ModelType credentials: dict @@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. """ + provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] @@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel): """ Class for model load balancing configuration. """ + id: str name: str credentials: dict @@ -93,6 +100,7 @@ class ModelSettings(BaseModel): """ Model class for model settings. """ + model: str model_type: ModelType enabled: bool = True diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 53323a2eeb..3b186476eb 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -3,6 +3,7 @@ from typing import Optional class LLMError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -11,6 +12,7 @@ class LLMError(Exception): class LLMBadRequestError(LLMError): """Raised when the LLM returns bad request.""" + description = "Bad Request" @@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception): """ Custom exception raised when the provider token is not initialized. """ + description = "Provider Token Not Init" def __init__(self, *args, **kwargs): @@ -28,6 +31,7 @@ class QuotaExceededError(Exception): """ Custom exception raised when the quota for a provider has been exceeded. """ + description = "Quota Exceeded" @@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception): """ Custom exception raised when the quota for an app has been exceeded. 
""" + description = "App Invoke Quota Exceeded" @@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception): """ Custom exception raised when the model not support """ + description = "Model Currently Not Support" class InvokeRateLimitError(Exception): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 4db7a99973..38cebb6b6b 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -20,10 +20,7 @@ class APIBasedExtensionRequestor: :param params: the request params :return: the response json """ - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer {}".format(self.api_key) - } + headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} url = self.api_endpoint @@ -32,20 +29,17 @@ class APIBasedExtensionRequestor: proxies = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxies = { - 'http': dify_config.SSRF_PROXY_HTTP_URL, - 'https': dify_config.SSRF_PROXY_HTTPS_URL, + "http": dify_config.SSRF_PROXY_HTTP_URL, + "https": dify_config.SSRF_PROXY_HTTPS_URL, } response = requests.request( - method='POST', + method="POST", url=url, - json={ - 'point': point.value, - 'params': params - }, + json={"point": point.value, "params": params}, headers=headers, timeout=self.timeout, - proxies=proxies + proxies=proxies, ) except requests.exceptions.Timeout: raise ValueError("request timeout") @@ -53,9 +47,8 @@ class APIBasedExtensionRequestor: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}, content: {}".format( - response.status_code, - response.text[:100] - )) + raise ValueError( + "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) + ) return response.json() diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 0296126d8b..f1a49c4921 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -11,8 +11,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map class ExtensionModule(enum.Enum): - MODERATION = 'moderation' - EXTERNAL_DATA_TOOL = 'external_data_tool' + MODERATION = "moderation" + EXTERNAL_DATA_TOOL = "external_data_tool" class ModuleExtension(BaseModel): @@ -41,12 +41,12 @@ class Extensible: position_map = {} # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") current_dir_path = os.path.dirname(current_path) # traverse subdirectories for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith('__'): + if subdir_name.startswith("__"): continue subdir_path = os.path.join(current_dir_path, subdir_name) @@ -58,21 +58,21 @@ class Extensible: # in the front-end page and business logic, there are special treatments. 
builtin = False position = None - if '__builtin__' in file_names: + if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, '__builtin__') + builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): - with open(builtin_file_path, encoding='utf-8') as f: + with open(builtin_file_path, encoding="utf-8") as f: position = int(f.read().strip()) - position_map[extension_name] = position + position_map[extension_name] = position - if (extension_name + '.py') not in file_names: + if (extension_name + ".py") not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") continue # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + '.py') + py_path = os.path.join(subdir_path, extension_name + ".py") spec = importlib.util.spec_from_file_location(extension_name, py_path) if not spec or not spec.loader: raise Exception(f"Failed to load module {extension_name} from {py_path}") @@ -91,25 +91,29 @@ class Extensible: json_data = {} if not builtin: - if 'schema.json' not in file_names: + if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, 'schema.json') + json_path = os.path.join(subdir_path, "schema.json") json_data = {} if os.path.exists(json_path): - with open(json_path, encoding='utf-8') as f: + with open(json_path, encoding="utf-8") as f: json_data = json.load(f) - extensions.append(ModuleExtension( - extension_class=extension_class, - name=extension_name, - label=json_data.get('label'), - form_schema=json_data.get('form_schema'), - builtin=builtin, - position=position - )) + extensions.append( + ModuleExtension( + extension_class=extension_class, + name=extension_name, + label=json_data.get("label"), + form_schema=json_data.get("form_schema"), + builtin=builtin, + position=position, + ) + ) - sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) + sorted_extensions = sort_to_dict_by_position_map( + position_map=position_map, data=extensions, name_func=lambda x: x.name + ) return sorted_extensions diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 29e892c58a..3da170455e 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -6,10 +6,7 @@ from core.moderation.base import Moderation class Extension: __module_extensions: dict[str, dict[str, ModuleExtension]] = {} - module_classes = { - ExtensionModule.MODERATION: Moderation, - ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool - } + module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool} def init(self): for module, module_class in self.module_classes.items(): diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 58c82502ea..54ec97a493 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool): raise ValueError("api_based_extension_id is required") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + 
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool): api_based_extension_id = self.config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == self.tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(self.variable)) + raise ValueError( + "[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid".format(self.variable) + ) # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=self.tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key) try: # request api - requestor = APIBasedExtensionRequestor( - api_endpoint=api_based_extension.api_endpoint, - api_key=api_key - ) + requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key) except Exception as e: - raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( - self.variable, - e - )) + raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e)) - response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ - 'app_id': self.app_id, - 'tool_variable': self.variable, - 'inputs': inputs, - 'query': query - }) + response_json = requestor.request( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query}, + ) - if 'result' not in response_json: - raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" - .format(self.variable)) + if "result" not in response_json: + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result not found in response".format( + self.variable + ) + ) - if not isinstance(response_json['result'], str): - raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string" - .format(self.variable)) + if not isinstance(response_json["result"], str): + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable) + ) - return response_json['result'] + return response_json["result"] diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 8601cb34e7..84b94e117f 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -12,11 +12,14 @@ logger = logging.getLogger(__name__) class ExternalDataFetch: - def fetch(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fetch( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, 
+ query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -38,7 +41,7 @@ class ExternalDataFetch: app_id, tool, inputs, - query + query, ) futures[future] = tool @@ -50,12 +53,15 @@ class ExternalDataFetch: inputs.update(results) return inputs - def _query_external_data_tool(self, flask_app: Flask, - tenant_id: str, - app_id: str, - external_data_tool: ExternalDataVariableEntity, - inputs: dict, - query: str) -> tuple[Optional[str], Optional[str]]: + def _query_external_data_tool( + self, + flask_app: Flask, + tenant_id: str, + app_id: str, + external_data_tool: ExternalDataVariableEntity, + inputs: dict, + query: str, + ) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app @@ -72,17 +78,10 @@ class ExternalDataFetch: tool_config = external_data_tool.config external_data_tool_factory = ExternalDataToolFactory( - name=tool_type, - tenant_id=tenant_id, - app_id=app_id, - variable=tool_variable, - config=tool_config + name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config ) # query external data tool - result = external_data_tool_factory.query( - inputs=inputs, - query=query - ) + result = external_data_tool_factory.query(inputs=inputs, query=query) return tool_variable, result diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 979f243af6..2872109859 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( - tenant_id=tenant_id, - app_id=app_id, - variable=variable, - config=config + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 3959f4b4a0..5c4e694025 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel): """ File Upload Entity. 
""" + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -28,9 +29,9 @@ class FileType(enum.Enum): class FileTransferMethod(enum.Enum): - REMOTE_URL = 'remote_url' - LOCAL_FILE = 'local_file' - TOOL_FILE = 'tool_file' + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" @staticmethod def value_of(value): @@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum): return member raise ValueError(f"No matching enum found for value '{value}'") + class FileBelongsTo(enum.Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" @staticmethod def value_of(value): @@ -65,16 +67,16 @@ class FileVar(BaseModel): def to_dict(self) -> dict: return { - '__variant': self.__class__.__name__, - 'tenant_id': self.tenant_id, - 'type': self.type.value, - 'transfer_method': self.transfer_method.value, - 'url': self.preview_url, - 'remote_url': self.url, - 'related_id': self.related_id, - 'filename': self.filename, - 'extension': self.extension, - 'mime_type': self.mime_type, + "__variant": self.__class__.__name__, + "tenant_id": self.tenant_id, + "type": self.type.value, + "transfer_method": self.transfer_method.value, + "url": self.preview_url, + "remote_url": self.url, + "related_id": self.related_id, + "filename": self.filename, + "extension": self.extension, + "mime_type": self.mime_type, } def to_markdown(self) -> str: @@ -86,7 +88,7 @@ class FileVar(BaseModel): if self.type == FileType.IMAGE: text = f'![{self.filename or ""}]({preview_url})' else: - text = f'[{self.filename or preview_url}]({preview_url})' + text = f"[{self.filename or preview_url}]({preview_url})" return text @@ -115,28 +117,29 @@ class FileVar(BaseModel): return ImagePromptMessageContent( data=self.data, detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW + if image_config.get("detail") == "high" + else ImagePromptMessageContent.DETAIL.LOW, ) def _get_data(self, force_url: bool = False) -> Optional[str]: from models.model import UploadFile + if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == self.related_id, - UploadFile.tenant_id == self.tenant_id - ).first()) - - return UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=force_url + upload_file = ( + db.session.query(UploadFile) + .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) + .first() ) + + return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) elif self.transfer_method == FileTransferMethod.TOOL_FILE: extension = self.extension # add sign url - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=extension + ) return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 085ff07cfd..8feaabedbb 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS class MessageFileParser: - def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - 
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, - user: Union[Account, EndUser]) -> list[FileVar]: + def validate_and_transform_files_arg( + self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] + ) -> list[FileVar]: """ validate and transform files arg @@ -30,22 +30,22 @@ class MessageFileParser: """ for file in files: if not isinstance(file, dict): - raise ValueError('Invalid file format, must be dict') - if not file.get('type'): - raise ValueError('Missing file type') - FileType.value_of(file.get('type')) - if not file.get('transfer_method'): - raise ValueError('Missing file transfer method') - FileTransferMethod.value_of(file.get('transfer_method')) - if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: - if not file.get('url'): - raise ValueError('Missing file url') - if not file.get('url').startswith('http'): - raise ValueError('Invalid file url') - if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): - raise ValueError('Missing file upload_file_id') - if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): - raise ValueError('Missing file tool_file_id') + raise ValueError("Invalid file format, must be dict") + if not file.get("type"): + raise ValueError("Missing file type") + FileType.value_of(file.get("type")) + if not file.get("transfer_method"): + raise ValueError("Missing file transfer method") + FileTransferMethod.value_of(file.get("transfer_method")) + if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: + if not file.get("url"): + raise ValueError("Missing file url") + if not file.get("url").startswith("http"): + raise ValueError("Invalid file url") + if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): + raise ValueError("Missing file upload_file_id") + if file.get("transfer_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): + raise ValueError("Missing file tool_file_id") # transform files to file objs type_file_objs = self._to_file_objs(files, file_extra_config) @@ -62,17 +62,17 @@ class MessageFileParser: continue # Validate number of files - if len(files) > image_config['number_limits']: + if len(files) > image_config["number_limits"]: raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") for file_obj in file_objs: # Validate transfer method - if file_obj.transfer_method.value not in image_config['transfer_methods']: - raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') + if file_obj.transfer_method.value not in image_config["transfer_methods"]: + raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") # Validate file type if file_obj.type != FileType.IMAGE: - raise ValueError(f'Invalid file type: {file_obj.type}') + raise ValueError(f"Invalid file type: {file_obj.type}") if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: # check remote url valid and is image @@ -81,18 +81,21 @@ class MessageFileParser: raise ValueError(error) elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: # get upload file from upload_file_id - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == 
('account' if isinstance(user, Account) else 'end_user'), - UploadFile.extension.in_(IMAGE_EXTENSIONS) - ).first()) + upload_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == file_obj.related_id, + UploadFile.tenant_id == self.tenant_id, + UploadFile.created_by == user.id, + UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + UploadFile.extension.in_(IMAGE_EXTENSIONS), + ) + .first() + ) # check upload file is belong to tenant and user if not upload_file: - raise ValueError('Invalid upload file') + raise ValueError("Invalid upload file") new_files.append(file_obj) @@ -113,8 +116,9 @@ class MessageFileParser: # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: + def _to_file_objs( + self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig + ) -> dict[FileType, list[FileVar]]: """ transform files to file objs @@ -152,23 +156,23 @@ class MessageFileParser: :return: """ if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) if transfer_method != FileTransferMethod.TOOL_FILE: return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config + url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config, ) return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, url=None, - related_id=file.get('tool_file_id'), - extra_config=file_extra_config + related_id=file.get("tool_file_id"), + extra_config=file_extra_config, ) else: return FileVar( @@ -178,7 +182,7 @@ class MessageFileParser: transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, related_id=file.upload_file_id or None, - extra_config=file_extra_config + extra_config=file_extra_config, ) def _check_image_remote_url(self, url): @@ -190,17 +194,17 @@ class MessageFileParser: def is_s3_presigned_url(url): try: parsed_url = urlparse(url) - if 'amazonaws.com' not in parsed_url.netloc: + if "amazonaws.com" not in parsed_url.netloc: return False query_params = parse_qs(parsed_url.query) - required_params = ['Signature', 'Expires'] + required_params = ["Signature", "Expires"] for param in required_params: if param not in query_params: return False - if not query_params['Expires'][0].isdigit(): + if not query_params["Expires"][0].isdigit(): return False - signature = query_params['Signature'][0] - if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): + signature = query_params["Signature"][0] + if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): return False return True except Exception: diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index ea8605ac57..1efaf5529d 100644 --- a/api/core/file/tool_file_parser.py +++ 
b/api/core/file/tool_file_parser.py @@ -1,8 +1,7 @@ -tool_file_manager = { - 'manager': None -} +tool_file_manager = {"manager": None} + class ToolFileParser: @staticmethod - def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file + def get_tool_file_manager() -> "ToolFileManager": + return tool_file_manager["manager"] diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 737a11e426..a8c1fd4d02 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -9,7 +9,7 @@ from typing import Optional from configs import dify_config from extensions.ext_storage import storage -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) @@ -22,18 +22,18 @@ class UploadFileParser: if upload_file.extension not in IMAGE_EXTENSIONS: return None - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: return cls.get_signed_temp_image_url(upload_file.id) else: # get image file base64 try: data = storage.load(upload_file.key) except FileNotFoundError: - logging.error(f'File not found: {upload_file.key}') + logging.error(f"File not found: {upload_file.key}") return None - encoded_string = base64.b64encode(data).decode('utf-8') - return f'data:{upload_file.mime_type};base64,{encoded_string}' + encoded_string = base64.b64encode(data).decode("utf-8") + return f"data:{upload_file.mime_type};base64,{encoded_string}" @classmethod def get_signed_temp_image_url(cls, upload_file_id) -> str: @@ -44,7 +44,7 @@ class UploadFileParser: :return: """ base_url = dify_config.FILES_URL - image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' + image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4662ebb47a..7ee6e63817 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) -class CodeExecutionException(Exception): + +class CodeExecutionError(Exception): pass + class CodeExecutionResponse(BaseModel): class Data(BaseModel): stdout: Optional[str] = None @@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel): class CodeLanguage(str, Enum): - PYTHON3 = 'python3' - JINJA2 = 'jinja2' - JAVASCRIPT = 'javascript' + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" class CodeExecutor: @@ -45,71 +47,73 @@ class CodeExecutor: } code_language_to_running_language = { - CodeLanguage.JAVASCRIPT: 'nodejs', + CodeLanguage.JAVASCRIPT: "nodejs", CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, } - supported_dependencies_languages: set[CodeLanguage] = { - CodeLanguage.PYTHON3 - } + supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3} @classmethod - def execute_code(cls, - language: CodeLanguage, - preload: str, - code: str) -> str: + def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = 
URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" - headers = { - 'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY - } + headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} data = { - 'language': cls.code_language_to_running_language.get(language), - 'code': code, - 'preload': preload, - 'enable_network': True + "language": cls.code_language_to_running_language.get(language), + "code": code, + "preload": preload, + "enable_network": True, } try: - response = post(str(url), json=data, headers=headers, - timeout=Timeout( - connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, - read=dify_config.CODE_EXECUTION_READ_TIMEOUT, - write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, - pool=None)) + response = post( + str(url), + json=data, + headers=headers, + timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None, + ), + ) if response.status_code == 503: - raise CodeExecutionException('Code execution service is unavailable') + raise CodeExecutionError("Code execution service is unavailable") elif response.status_code != 200: - raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running') - except CodeExecutionException as e: + raise Exception( + f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" + ) + except CodeExecutionError as e: raise e except Exception as e: - raise CodeExecutionException('Failed to execute code, which is likely a network issue,' - ' please check if the sandbox service is running.' - f' ( Error: {str(e)} )') + raise CodeExecutionError( + "Failed to execute code, which is likely a network issue," + " please check if the sandbox service is running." + f" ( Error: {str(e)} )" + ) try: response = response.json() except: - raise CodeExecutionException('Failed to parse response') + raise CodeExecutionError("Failed to parse response") - if (code := response.get('code')) != 0: - raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response.get('message')}") response = CodeExecutionResponse(**response) if response.data.error: - raise CodeExecutionException(response.data.error) + raise CodeExecutionError(response.data.error) - return response.data.stdout or '' + return response.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: @@ -122,13 +126,13 @@ class CodeExecutor: """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f'Unsupported language {language}') + raise CodeExecutionError(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) try: response = cls.execute_code(language, preload, runner) - except CodeExecutionException as e: + except CodeExecutionError as e: raise e return template_transformer.transform_response(response) diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 3f099b7ac5..e233a596b9 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel): return { "type": "code", "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], + "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], "code_language": cls.get_language(), "code": cls.get_default_code(), - "outputs": { - "result": { - "type": "string", - "children": None - } - } - } + "outputs": {"result": {"type": "string", "children": None}}, + }, } diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index a157fcc6d1..ae324b83a9 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider): result: arg1 + arg2 } } - """) + """ + ) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index a4d2551972..d67a0903aa 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer): var output_json = JSON.stringify(output_obj) var result = `<<RESULT>>${{output_json}}<<RESULT>>` console.log(result) - """) + """ + ) return runner_script diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index f1e5da584c..db2eb5ebb6 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -10,8 +10,6 @@ class Jinja2Formatter: :param inputs: inputs :return: """ - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=inputs - ) + result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - return result['result'] + return result["result"] diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index b8cb29600e..63d58edbc7 
100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): :param response: response :return: """ - return { - 'result': cls.extract_result_str_from_response(response) - } + return {"result": cls.extract_result_str_from_response(response)} @classmethod def get_runner_script(cls) -> str: diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 923724b49d..9cca8af7c6 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider): return { "result": arg1 + arg2, } - """) + """ + ) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index cf66558b65..6f016f27bc 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,9 +5,9 @@ from base64 import b64encode class TemplateTransformer(ABC): - _code_placeholder: str = '{{code}}' - _inputs_placeholder: str = '{{inputs}}' - _result_tag: str = '<<RESULT>>' + _code_placeholder: str = "{{code}}" + _inputs_placeholder: str = "{{inputs}}" + _result_tag: str = "<<RESULT>>" @classmethod def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: @@ -24,9 +24,9 @@ @classmethod def extract_result_str_from_response(cls, response: str) -> str: - result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL) + result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: - raise ValueError('Failed to parse result') + raise ValueError("Failed to parse result") result = result.group(1) return result @@ -50,7 +50,7 @@ @classmethod def serialize_inputs(cls, inputs: dict) -> str: inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() - input_base64_encoded = b64encode(inputs_json_str).decode('utf-8') + input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded @classmethod @@ -67,4 +67,4 @@ """ Get preload script """ - return '' + return "" diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 5e5deb86b4..96341a1b78 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -8,14 +8,15 @@ def obfuscated_token(token: str): if not token: return token if len(token) <= 8: - return '*' * 20 - return token[:6] + '*' * 12 + token[-2:] + return "*" * 20 + return token[:6] + "*" * 12 + token[-2:] def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): - raise ValueError(f'Tenant with id {tenant_id} not found') + raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 29cb4acc7d..5e274f8916 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -25,7 +25,7 @@ class ProviderCredentialsCache: cached_provider_credentials = redis_client.get(self.cache_key) 
if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 20feae8554..b880590de2 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -12,19 +12,20 @@ logger = logging.getLogger(__name__) def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config - if (moderation_config and moderation_config.enabled is True - and 'openai' in hosting_configuration.provider_map - and hosting_configuration.provider_map['openai'].enabled is True + if ( + moderation_config + and moderation_config.enabled is True + and "openai" in hosting_configuration.provider_map + and hosting_configuration.provider_map["openai"].enabled is True ): using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type provider_name = model_config.provider - if using_provider_type == ProviderType.SYSTEM \ - and provider_name in moderation_config.providers: - hosting_openai_config = hosting_configuration.provider_map['openai'] + if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: + hosting_openai_config = hosting_configuration.provider_map["openai"] # 2000 text per chunk length = 2000 - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] if len(text_chunks) == 0: return True @@ -34,15 +35,13 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() moderation_result = model_type_instance.invoke( - model='text-moderation-stable', - credentials=hosting_openai_config.credentials, - text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk ) if moderation_result is True: return True except Exception as ex: logger.exception(ex) - raise InvokeBadRequestError('Rate limit exceeded, please try again later.') + raise InvokeBadRequestError("Rate limit exceeded, please try again later.") return False diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 2000577a40..e6e1491548 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -37,8 +37,9 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] """ Get all the subclasses of the parent type from the module """ - classes = [x for _, x in vars(mod).items() - if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)] + classes = [ + x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type) + ] return classes @@ -56,6 +57,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") case _: - raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') \ No newline at end of file + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") diff --git 
a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 8cf184ac44..3efdc8aa47 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -73,13 +73,13 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) def is_filtered( - include_set: set[str], - exclude_set: set[str], - data: Any, - name_func: Callable[[Any], str], + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], ) -> bool: """ - Chcek if the object should be filtered out. + Check if the object should be filtered out. Overall logic: exclude > include > pin :param include_set: the set of names to be included :param exclude_set: the set of names to be excluded @@ -102,9 +102,9 @@ def is_filtered( def sort_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> list[Any]: """ Sort the objects by the position map. @@ -117,13 +117,13 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) def sort_to_dict_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> OrderedDict[str, Any]: """ Sort the objects into a ordered dict by the position map. diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 14ca8e943c..4e6d58904e 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,31 +1,34 @@ """ Proxy requests to avoid SSRF """ + import logging import os import time import httpx -SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') -SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') -SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') -SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) +SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") +SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") +SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -proxies = { - 'http://': SSRF_PROXY_HTTP_URL, - 'https://': SSRF_PROXY_HTTPS_URL -} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +proxies = ( + {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} + if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL + else None +) BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects - + retries = 0 while retries <= max_retries: try: @@ -52,24 +55,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('GET', url, max_retries=max_retries, **kwargs) + return make_request("GET", url, max_retries=max_retries, **kwargs) def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('POST', url, max_retries=max_retries, **kwargs) + return make_request("POST", url, max_retries=max_retries, 
**kwargs) def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PUT', url, max_retries=max_retries, **kwargs) + return make_request("PUT", url, max_retries=max_retries, **kwargs) def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PATCH', url, max_retries=max_retries, **kwargs) + return make_request("PATCH", url, max_retries=max_retries, **kwargs) def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('DELETE', url, max_retries=max_retries, **kwargs) + return make_request("DELETE", url, max_retries=max_retries, **kwargs) def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('HEAD', url, max_retries=max_retries, **kwargs) + return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index a6f486e81d..4c3b736186 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -9,14 +9,11 @@ from extensions.ext_redis import redis_client class ToolParameterCacheType(Enum): PARAMETER = "tool_parameter" + class ToolParameterCache: - def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, - cache_type: ToolParameterCacheType, - identity_id: str - ): + def __init__( + self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str + ): self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}" def get(self) -> Optional[dict]: @@ -28,7 +25,7 @@ class ToolParameterCache: cached_tool_parameter = redis_client.get(self.cache_key) if cached_tool_parameter: try: - cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = cached_tool_parameter.decode("utf-8") cached_tool_parameter = json.loads(cached_tool_parameter) except JSONDecodeError: return None @@ -52,4 +49,4 @@ class ToolParameterCache: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 6c5d3b8fb6..94b02cf985 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" + class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" @@ -22,7 +23,7 @@ class ToolProviderCredentialsCache: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None @@ -46,4 +47,4 @@ class ToolProviderCredentialsCache: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index ddcd751286..eeeccc2349 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -46,7 +46,7 @@ class 
HostingConfiguration: def init_app(self, app: Flask) -> None: config = app.config - if config.get('EDITION') != 'CLOUD': + if config.get("EDITION") != "CLOUD": return self.provider_map["azure_openai"] = self.init_azure_openai(config) @@ -65,7 +65,7 @@ class HostingConfiguration: credentials = { "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), - "base_model_name": "gpt-35-turbo" + "base_model_name": "gpt-35-turbo", } quotas = [] @@ -77,26 +77,45 @@ class HostingConfiguration: RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM), RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM), RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), + RestrictModel( + model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM + ), RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), - RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), - ] + RestrictModel( + model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM + ), + RestrictModel( + model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM + ), + RestrictModel( + model="text-embedding-ada-002", + base_model_name="text-embedding-ada-002", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-small", + base_model_name="text-embedding-3-small", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-large", + base_model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], ) quotas.append(trial_quota) - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -110,17 +129,12 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) trial_models = 
self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=trial_models - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") - paid_quota = PaidHostingQuota( - restrict_models=paid_models - ) + paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: @@ -134,12 +148,7 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -153,9 +162,7 @@ class HostingConfiguration: if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): @@ -170,12 +177,7 @@ class HostingConfiguration: if app_config.get("HOSTED_ANTHROPIC_API_BASE"): credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -192,7 +194,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -210,7 +212,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -228,7 +230,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -238,21 +240,19 @@ class HostingConfiguration: @staticmethod def init_moderation_config(app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") \ - and app_config.get("HOSTED_MODERATION_PROVIDERS"): + if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( - enabled=True, - providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',') + enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") ) - return HostedModerationConfig( - enabled=False - ) + return HostedModerationConfig(enabled=False) @staticmethod def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: models_str = app_config.get(env_var) models_list = models_str.split(",") if models_str else [] - return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if - model_name.strip()] - + return [ + RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) + for 
model_name in models_list + if model_name.strip() + ] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 062666ac6a..eeb1dbfda0 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -16,9 +16,7 @@ from configs import dify_config from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType, PriceType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -41,7 +39,6 @@ from services.feature_service import FeatureService class IndexingRunner: - def __init__(self): self.storage = storage self.model_manager = ModelManager() @@ -51,25 +48,26 @@ class IndexingRunner: for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) @@ -78,20 +76,20 @@ class IndexingRunner: index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, - documents=documents + documents=documents, ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: - logging.warning('Document deleted, document id: {}'.format(dataset_document.id)) + logging.warning("Document deleted, document id: {}".format(dataset_document.id)) except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = 
datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -100,26 +98,25 @@ class IndexingRunner: """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() for document_segment in document_segments: db.session.delete(document_segment) db.session.commit() # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -127,28 +124,26 @@ class IndexingRunner: text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) # load self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -157,17 +152,14 @@ class IndexingRunner: """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() documents = [] @@ -182,42 +174,48 @@ class IndexingRunner: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) documents.append(document) # 
build index # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate( + self, + tenant_id: str, + extract_settings: list[ExtractSetting], + tmp_processing_rule: dict, + doc_form: str = None, + doc_language: str = "English", + dataset_id: str = None, + indexing_technique: str = "economy", + ) -> dict: """ Estimate the indexing for the document. 
""" @@ -231,18 +229,16 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': + raise ValueError("Dataset not found.") + if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -250,16 +246,13 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - tokens = 0 preview_texts = [] total_segments = 0 - total_price = 0 - currency = 'USD' index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() all_text_docs = [] @@ -268,8 +261,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) + mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter @@ -277,152 +269,118 @@ class IndexingRunner: # split to documents documents = self._split_to_documents_for_estimate( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule + text_docs=text_docs, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) - if indexing_technique == 'high_quality' or embedding_model_instance: - tokens += embedding_model_instance.get_text_embedding_num_tokens( - texts=[self.filter_string(document.page_content)] - ) - - if doc_form and doc_form == 'qa_model': - model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM - ) - - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + if doc_form and doc_form == "qa_model": if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) - document_qa_list = self.format_split_text(response) - price_info = model_type_instance.get_price( - model=model_instance.model, - credentials=model_instance.credentials, - price_type=PriceType.INPUT, - tokens=total_segments * 2000, + response = LLMGenerator.generate_qa_document( + current_user.current_tenant_id, preview_texts[0], doc_language ) - return { - "total_segments": total_segments * 20, - "tokens": total_segments * 2000, - "total_price": '{:f}'.format(price_info.total_amount), - "currency": price_info.currency, - "qa_preview": document_qa_list, - "preview": preview_texts - } - if embedding_model_instance: - embedding_model_type_instance = cast(TextEmbeddingModel, 
embedding_model_instance.model_type_instance) - embedding_price_info = embedding_model_type_instance.get_price( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - price_type=PriceType.INPUT, - tokens=tokens - ) - total_price = '{:f}'.format(embedding_price_info.total_amount) - currency = embedding_price_info.currency - return { - "total_segments": total_segments, - "tokens": tokens, - "total_price": total_price, - "currency": currency, - "preview": preview_texts - } + document_qa_list = self.format_split_text(response) - def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ - -> list[Document]: + return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} + return {"total_segments": total_segments, "preview": preview_texts} + + def _extract( + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + ) -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: return [] data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == 'upload_file': - if not data_source_info or 'upload_file_id' not in data_source_info: + if dataset_document.data_source_type == "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + ) if file_detail: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=dataset_document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'notion_import': - if (not data_source_info or 'notion_workspace_id' not in data_source_info - or 'notion_page_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], "document": dataset_document, - "tenant_id": dataset_document.tenant_id + "tenant_id": dataset_document.tenant_id, }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'website_crawl': - if (not data_source_info or 'provider' not in data_source_info - or 'url' not in data_source_info or 'job_id' not in data_source_info): 
+ text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): raise ValueError("no website import info found") extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], "tenant_id": dataset_document.tenant_id, - "url": data_source_info['url'], - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata['document_id'] = dataset_document.id - text_doc.metadata['dataset_id'] = dataset_document.dataset_id + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @staticmethod def filter_string(text): - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) return text @staticmethod - def _get_splitter(processing_rule: DatasetProcessRule, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -436,10 +394,10 @@ class IndexingRunner: separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") - if segmentation.get('chunk_overlap'): - chunk_overlap = segmentation['chunk_overlap'] + if segmentation.get("chunk_overlap"): + chunk_overlap = segmentation["chunk_overlap"] else: chunk_overlap = 0 @@ -448,22 +406,27 @@ class IndexingRunner: chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". 
", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter - def _step_split(self, text_docs: list[Document], splitter: TextSplitter, - dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> list[Document]: + def _step_split( + self, + text_docs: list[Document], + splitter: TextSplitter, + dataset: Dataset, + dataset_document: DatasetDocument, + processing_rule: DatasetProcessRule, + ) -> list[Document]: """ Split the text documents into documents and save them to the document segment. """ @@ -473,14 +436,12 @@ class IndexingRunner: processing_rule=processing_rule, tenant_id=dataset.tenant_id, document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language + document_language=dataset_document.doc_language, ) # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -494,7 +455,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -502,15 +463,21 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) return documents - def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> list[Document]: + def _split_to_documents( + self, + text_docs: list[Document], + splitter: TextSplitter, + processing_rule: DatasetProcessRule, + tenant_id: str, + document_form: str, + document_language: str, + ) -> list[Document]: """ Split the text documents into nodes. 
""" @@ -525,13 +492,12 @@ class IndexingRunner: documents = splitter.split_documents([text_doc]) split_documents = [] for document_node in documents: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): page_content = page_content[1:] @@ -543,15 +509,21 @@ class IndexingRunner: split_documents.append(document_node) all_documents.extend(split_documents) # processing qa document - if document_form == 'qa_model': + if document_form == "qa_model": for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents, - 'document_language': document_language}) + document_format_thread = threading.Thread( + target=self.format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": tenant_id, + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": document_language, + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -570,12 +542,14 @@ class IndexingRunner: document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.model_copy()) + qa_document = Document( + page_content=result["question"], metadata=document_node.metadata.model_copy() + ) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -583,8 +557,9 @@ class IndexingRunner: all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> list[Document]: + def _split_to_documents_for_estimate( + self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule + ) -> list[Document]: """ Split the text documents into nodes. 
""" @@ -604,8 +579,8 @@ class IndexingRunner: doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) - document.metadata['doc_id'] = doc_id - document.metadata['doc_hash'] = hash + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -623,23 +598,23 @@ class IndexingRunner: else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} - if 'pre_processing_rules' in rules: + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text @@ -648,27 +623,26 @@ class IndexingRunner: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] - def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, - dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + dataset_document: DatasetDocument, + documents: list[Document], + ) -> None: """ insert index and update document/segment status to completed """ embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # chunk nodes by chunk size @@ -677,18 +651,27 @@ class IndexingRunner: chunk_size = 10 # create keyword index - create_keyword_thread = threading.Thread(target=self._process_keyword_index, - args=(current_app._get_current_object(), - dataset.id, dataset_document.id, documents)) + create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + ) create_keyword_thread.start() - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i:i + chunk_size] - futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), 
index_processor, - chunk_documents, dataset, - dataset_document, embedding_model_instance)) + chunk_documents = documents[i : i + chunk_size] + futures.append( + executor.submit( + self._process_chunk, + current_app._get_current_object(), + index_processor, + chunk_documents, + dataset, + dataset_document, + embedding_model_instance, + ) + ) for future in futures: tokens += future.result() @@ -705,7 +688,7 @@ class IndexingRunner: DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, - } + }, ) @staticmethod @@ -716,23 +699,26 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != 'high_quality': - document_ids = [document.metadata['doc_id'] for document in documents] + if dataset.indexing_technique != "high_quality": + document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() - def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance): + def _process_chunk( + self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + ): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -740,26 +726,26 @@ class IndexingRunner: tokens = 0 if embedding_model_instance: tokens += sum( - embedding_model_instance.get_text_embedding_num_tokens( - [document.page_content] - ) + embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) for document in chunk_documents ) # load index index_processor.load(dataset, chunk_documents, with_keywords=False) - document_ids = [document.metadata['doc_id'] for document in chunk_documents] + document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() @@ -767,27 +753,26 @@ class IndexingRunner: @staticmethod def _check_document_paused_status(document_id: str): - indexing_cache_key = 'document_{}_is_paused'.format(document_id) + indexing_cache_key = "document_{}_is_paused".format(document_id) result = 
redis_client.get(indexing_cache_key) if result: - raise DocumentIsPausedException() + raise DocumentIsPausedError() @staticmethod - def _update_document_index_status(document_id: str, after_indexing_status: str, - extra_update_params: Optional[dict] = None) -> None: + def _update_document_index_status( + document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + ) -> None: """ Update the document indexing status. """ count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: - raise DocumentIsPausedException() + raise DocumentIsPausedError() document = DatasetDocument.query.filter_by(id=document_id).first() if not document: - raise DocumentIsDeletedPausedException() + raise DocumentIsDeletedPausedError() - update_params = { - DatasetDocument.indexing_status: after_indexing_status - } + update_params = {DatasetDocument.indexing_status: after_indexing_status} if extra_update_params: update_params.update(extra_update_params) @@ -817,7 +802,7 @@ class IndexingRunner: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index @@ -825,17 +810,23 @@ class IndexingRunner: index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) - def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, - text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]: + def _transform( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + text_docs: list[Document], + doc_language: str, + process_rule: dict, + ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -843,18 +834,20 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) - documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, - process_rule=process_rule, tenant_id=dataset.tenant_id, - doc_language=doc_language) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=process_rule, + tenant_id=dataset.tenant_id, + doc_language=doc_language, + ) return documents def _load_segments(self, dataset, dataset_document, documents): # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -868,7 +861,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -876,15 +869,15 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: 
datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) pass -class DocumentIsPausedException(Exception): +class DocumentIsPausedError(Exception): pass -class DocumentIsDeletedPausedException(Exception): +class DocumentIsDeletedPausedError(Exception): pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8c13b4a45c..78a6d6e683 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -43,21 +43,16 @@ class LLMGenerator: with measure_time() as timer: response = model_instance.invoke_llm( - prompt_messages=prompts, - model_parameters={ - "max_tokens": 100, - "temperature": 1 - }, - stream=False + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False ) answer = response.message.content - cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL) + cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) result_dict = json.loads(cleaned_answer) - answer = result_dict['Your Output'] + answer = result_dict["Your Output"] name = answer.strip() if len(name) > 75: - name = name[:75] + '...' + name = name[:75] + "..." # get tracing instance trace_manager = TraceQueueManager(app_id=app_id) @@ -79,14 +74,9 @@ class LLMGenerator: output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - template="{{histories}}\n{{format_instructions}}\nquestions:\n" - ) + prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") - prompt = prompt_template.format({ - "histories": histories, - "format_instructions": format_instructions - }) + prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: model_manager = ModelManager() @@ -101,12 +91,7 @@ class LLMGenerator: try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - "max_tokens": 256, - "temperature": 0 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False ) questions = output_parser.parse(response.message.content) @@ -119,32 +104,24 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: + def generate_rule_config( + cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 + ) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" - rule_config = { - "prompt": "", - "variables": [], - "opening_statement": "", - "error": "" - } - model_parameters = { - "max_tokens": rule_config_max_tokens, - "temperature": 0.01 - } + rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} + model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} if no_variable: - prompt_template = PromptTemplateParser( - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE - ) + prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate)] @@ -158,13 +135,11 @@ class LLMGenerator: try: 
response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) rule_config["prompt"] = response.message.content - + except InvokeError as e: error = str(e) error_step = "generate rule config" @@ -179,24 +154,18 @@ class LLMGenerator: # get rule config prompt, parameter and statement prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - prompt_generate - ) + prompt_template = PromptTemplateParser(prompt_generate) - parameter_template = PromptTemplateParser( - parameter_generate - ) + parameter_template = PromptTemplateParser(parameter_generate) - statement_template = PromptTemplateParser( - statement_generate - ) + statement_template = PromptTemplateParser(statement_generate) # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] @@ -213,9 +182,7 @@ class LLMGenerator: try: # the first step to generate the task prompt prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) except InvokeError as e: error = str(e) @@ -230,7 +197,7 @@ class LLMGenerator: inputs={ "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] @@ -240,15 +207,13 @@ class LLMGenerator: "TASK_DESCRIPTION": instruction, "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False ) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) except InvokeError as e: @@ -257,9 +222,7 @@ class LLMGenerator: try: statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False ) rule_config["opening_statement"] = statement_content.message.content except InvokeError as e: @@ -284,18 +247,10 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [ - SystemPromptMessage(content=prompt), - UserPromptMessage(content=query) - ] + prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - 'temperature': 0.01, - "max_tokens": 2000 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False ) answer = response.message.content diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 6a60f8de80..1e743f1757 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ 
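Note: the renames above and just below (DocumentIsPausedException to DocumentIsPausedError, OutputParserException to OutputParserError) follow the PEP 8 convention that exception names carry an Error suffix. The old symbols are removed rather than aliased, so any call site catching them must be updated. A minimal sketch of an updated caller; the module path and run() signature are assumptions inferred from the IndexingRunner hunks, not confirmed by this diff:

    # Hypothetical call site, updated for the rename. Module path assumed
    # to be core.indexing_runner; only the exception name changed.
    from core.indexing_runner import DocumentIsPausedError, IndexingRunner

    def index_documents(runner: IndexingRunner, dataset_documents: list) -> None:
        try:
            runner.run(dataset_documents)
        except DocumentIsPausedError:
            # A paused document is expected control flow, not a failure.
            pass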
b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserException(Exception): +class OutputParserError(Exception): pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index 8856f0c685..0c7683b16d 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,6 +1,6 @@ from typing import Any -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import ( RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -10,9 +10,12 @@ from libs.json_in_md_parser import parse_and_check_json_markdown class RuleConfigGeneratorOutputParser: - def get_format_instructions(self) -> tuple[str, str, str]: - return RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE + return ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) def parse(self, text: str) -> Any: try: @@ -21,16 +24,9 @@ class RuleConfigGeneratorOutputParser: if not isinstance(parsed["prompt"], str): raise ValueError("Expected 'prompt' to be a string.") if not isinstance(parsed["variables"], list): - raise ValueError( - "Expected 'variables' to be a list." - ) + raise ValueError("Expected 'variables' to be a list.") if not isinstance(parsed["opening_statement"], str): - raise ValueError( - "Expected 'opening_statement' to be a str." - ) + raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException( - f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}" - ) - + raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 3f046c68fc..182aeed98f 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -6,7 +6,6 @@ from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCT class SuggestedQuestionsAfterAnswerOutputParser: - def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -15,7 +14,7 @@ class SuggestedQuestionsAfterAnswerOutputParser: if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) else: - json_obj= [] + json_obj = [] print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 87361b385a..7ab257872f 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -66,19 +66,19 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "and keeping each question under 20 characters.\n" "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" "The output must be an array in JSON format following the specified schema:\n" - "[\"question1\",\"question2\",\"question3\"]\n" + 
'["question1","question2","question3"]\n' ) GENERATOR_QA_PROMPT = ( - ' The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step.' - 'Step 1: Understand and summarize the main content of this text.\n' - 'Step 2: What key information or concepts are mentioned in this text?\n' - 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' - 'Step 4: Generate questions and answers based on these key information and concepts.\n' - ' The questions should be clear and detailed, and the answers should be detailed and complete. ' - 'You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n' - ' Use the following format: Q1:\nA1:\nQ2:\nA2:...\n' - '' + " The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step." + "Step 1: Understand and summarize the main content of this text.\n" + "Step 2: What key information or concepts are mentioned in this text?\n" + "Step 3: Decompose or combine multiple pieces of information and concepts.\n" + "Step 4: Generate questions and answers based on these key information and concepts.\n" + " The questions should be clear and detailed, and the answers should be detailed and complete. " + "You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n" + " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "" ) WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ @@ -87,7 +87,7 @@ Here is a task description for which I would like you to create a high-quality p {{TASK_DESCRIPTION}} Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include: -- Do not inlcude or section and variables in the prompt, assume user will add them at their own will. +- Do not include or section and variables in the prompt, assume user will add them at their own will. - Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag. - Relevant examples if needed to clarify the task further, demarcated with tags. Do not include variables in the prompt. Give three pairs of input and output examples. - Include other relevant sections demarcated with appropriate XML tags like , . diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index b33d4dd7cb..54b1d8212b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -21,8 +21,9 @@ class TokenBufferMemory: self.conversation = conversation self.model_instance = model_instance - def get_history_prompt_messages(self, max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> list[PromptMessage]: + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> list[PromptMessage]: """ Get history prompt messages. 
:param max_token_limit: max token limit @@ -31,16 +32,11 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id - ).filter( - Message.conversation_id == self.conversation.id, - Message.answer != '' - ).order_by(Message.created_at.desc()) + query = ( + db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id) + .filter(Message.conversation_id == self.conversation.id, Message.answer != "") + .order_by(Message.created_at.desc()) + ) if message_limit and message_limit > 0: message_limit = message_limit if message_limit <= 500 else 500 @@ -50,10 +46,7 @@ class TokenBufferMemory: messages = query.limit(message_limit).all() messages = list(reversed(messages)) - message_file_parser = MessageFileParser( - tenant_id=app_record.tenant_id, - app_id=app_record.id - ) + message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() @@ -63,20 +56,17 @@ class TokenBufferMemory: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: - workflow_run = (db.session.query(WorkflowRun) - .filter(WorkflowRun.id == message.workflow_run_id).first()) + workflow_run = ( + db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() + ) if workflow_run: file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, - is_vision=False + workflow_run.workflow.features_dict, is_vision=False ) if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] @@ -97,24 +87,23 @@ class TokenBufferMemory: return [] # prune the chat message if it exceeds the max token limit - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: pruned_memory = [] - while curr_message_tokens > max_token_limit and len(prompt_messages)>1: + while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> str: + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ) -> str: """ Get history prompt text. 
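The pruning loop reformatted in this file encodes a simple policy: drop the oldest messages first until the running token count fits max_token_limit, but never drop the last remaining message. A stdlib-only sketch of the same policy, with a toy word counter standing in for model_instance.get_llm_num_tokens:

    def count_tokens(messages: list[str]) -> int:
        # Toy stand-in for the model tokenizer: one token per word.
        return sum(len(m.split()) for m in messages)

    def prune_history(messages: list[str], max_token_limit: int) -> list[str]:
        messages = list(messages)
        while count_tokens(messages) > max_token_limit and len(messages) > 1:
            messages.pop(0)  # the oldest message is dropped first
        return messages

    history = ["hello there", "how are you today", "fine thanks"]
    assert prune_history(history, max_token_limit=6) == ["how are you today", "fine thanks"]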
:param human_prefix: human prefix @@ -123,10 +112,7 @@ class TokenBufferMemory: :param message_limit: message limit :return: """ - prompt_messages = self.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit - ) + prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) string_messages = [] for m in prompt_messages: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7b1a7ada5b..990efd36c6 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,6 +1,6 @@ import logging import os -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Sequence from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -41,7 +41,7 @@ class ModelInstance: configuration=provider_model_bundle.configuration, model_type=provider_model_bundle.model_type_instance.model_type, model=model, - credentials=self.credentials + credentials=self.credentials, ) @staticmethod @@ -54,10 +54,7 @@ class ModelInstance: """ configuration = provider_model_bundle.configuration model_type = provider_model_bundle.model_type_instance.model_type - credentials = configuration.get_current_credentials( - model_type=model_type, - model=model - ) + credentials = configuration.get_current_credentials(model_type=model_type, model=model) if credentials is None: raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") @@ -65,10 +62,9 @@ class ModelInstance: return credentials @staticmethod - def _get_load_balancing_manager(configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict) -> Optional["LBModelManager"]: + def _get_load_balancing_manager( + configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict + ) -> Optional["LBModelManager"]: """ Get load balancing model credentials :param configuration: provider configuration @@ -81,8 +77,7 @@ class ModelInstance: current_model_setting = None # check if model is disabled by admin for model_setting in configuration.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: current_model_setting = model_setting break @@ -95,17 +90,23 @@ class ModelInstance: model_type=model_type, model=model, load_balancing_configs=current_model_setting.load_balancing_configs, - managed_credentials=credentials if configuration.custom_configuration.provider else None + managed_credentials=credentials if configuration.custom_configuration.provider else None, ) return lb_model_manager return None - def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -132,11 +133,12 @@ class ModelInstance: stop=stop, stream=stream, user=user, - 
callbacks=callbacks + callbacks=callbacks, ) - def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_llm_num_tokens( + self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Get number of tokens for llm @@ -153,11 +155,10 @@ class ModelInstance: model=self.model, credentials=self.credentials, prompt_messages=prompt_messages, - tools=tools + tools=tools, ) - def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult: """ Invoke large language model @@ -174,7 +175,7 @@ class ModelInstance: model=self.model, credentials=self.credentials, texts=texts, - user=user + user=user, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -192,13 +193,17 @@ class ModelInstance: function=self.model_type_instance.get_num_tokens, model=self.model, credentials=self.credentials, - texts=texts + texts=texts, ) - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke_rerank( + self, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -221,11 +226,10 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user + user=user, ) - def invoke_moderation(self, text: str, user: Optional[str] = None) \ - -> bool: + def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -242,11 +246,10 @@ class ModelInstance: model=self.model, credentials=self.credentials, text=text, - user=user + user=user, ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -263,11 +266,10 @@ class ModelInstance: model=self.model, credentials=self.credentials, file=file, - user=user + user=user, ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \ - -> str: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str: """ Invoke large language tts model @@ -288,7 +290,7 @@ class ModelInstance: content_text=content_text, user=user, tenant_id=tenant_id, - voice=voice + voice=voice, ) def _round_robin_invoke(self, function: Callable, *args, **kwargs): @@ -312,8 +314,8 @@ class ModelInstance: raise last_exception try: - if 'credentials' in kwargs: - del kwargs['credentials'] + if "credentials" in kwargs: + del kwargs["credentials"] return function(*args, **kwargs, credentials=lb_config.credentials) except InvokeRateLimitError as e: # expire in 60 seconds @@ -340,9 +342,7 @@ class ModelInstance: self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( - model=self.model, - credentials=self.credentials, - language=language + model=self.model, credentials=self.credentials, language=language ) @@ -363,9 +363,7 @@ class ModelManager: return self.get_default_model_instance(tenant_id, model_type) provider_model_bundle = self._provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=provider, 
- model_type=model_type + tenant_id=tenant_id, provider=provider, model_type=model_type ) return ModelInstance(provider_model_bundle, model) @@ -386,10 +384,7 @@ class ModelManager: :param model_type: model type :return: """ - default_model_entity = self._provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type - ) + default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type) if not default_model_entity: raise ProviderTokenNotInitError(f"Default model not found for {model_type}") @@ -398,17 +393,20 @@ class ModelManager: tenant_id=tenant_id, provider=default_model_entity.provider.provider, model_type=model_type, - model=default_model_entity.model + model=default_model_entity.model, ) class LBModelManager: - def __init__(self, tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None) -> None: + def __init__( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + load_balancing_configs: list[ModelLoadBalancingConfiguration], + managed_credentials: Optional[dict] = None, + ) -> None: """ Load balancing model manager :param tenant_id: tenant_id @@ -439,10 +437,7 @@ class LBModelManager: :return: """ cache_key = "model_lb_index:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model + self._tenant_id, self._provider, self._model_type.value, self._model ) cooldown_load_balancing_configs = [] @@ -473,10 +468,12 @@ class LBModelManager: continue - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): - logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" - f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" - f"model_type: {self._model_type.value}\nmodel: {self._model}") + if bool(os.environ.get("DEBUG", "False").lower() == "true"): + logger.info( + f"Model LB\nid: {config.id}\nname:{config.name}\n" + f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" + f"model_type: {self._model_type.value}\nmodel: {self._model}" + ) return config @@ -490,14 +487,10 @@ class LBModelManager: :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - redis_client.setex(cooldown_cache_key, expire, 'true') + redis_client.setex(cooldown_cache_key, expire, "true") def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: """ @@ -506,11 +499,7 @@ class LBModelManager: :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) res = redis_client.exists(cooldown_cache_key) @@ -518,11 +507,9 @@ class LBModelManager: return res @staticmethod - def get_config_in_cooldown_and_ttl(tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - config_id: str) -> tuple[bool, int]: + def get_config_in_cooldown_and_ttl( + tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str + ) -> tuple[bool, int]: """ Get model load balancing config is in cooldown and ttl :param tenant_id: workspace id @@ -533,11 +520,7 @@ class LBModelManager: :return: """ cooldown_cache_key = 
"model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - tenant_id, - provider, - model_type.value, - model, - config_id + tenant_id, provider, model_type.value, model, config_id ) ttl = redis_client.ttl(cooldown_cache_key) diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index bba004a32a..92da53c9a4 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -18,12 +18,21 @@ class Callback: Base class for callbacks. Only for LLM. """ + raise_error: bool = False - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -39,10 +48,19 @@ class Callback: """ raise NotImplementedError() - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -59,10 +77,19 @@ class Callback: """ raise NotImplementedError() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -79,10 +106,19 @@ class Callback: """ raise NotImplementedError() - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -99,9 +135,7 @@ class Callback: """ raise NotImplementedError() - def print_text( - self, text: str, color: Optional[str] = None, end: str = "" - ) 
-> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 0406853b88..3b6b825244 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -10,11 +10,20 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) + class LoggingCallback(Callback): - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -28,40 +37,49 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_before_invoke]\n", color='blue') - self.print_text(f"Model: {model}\n", color='blue') - self.print_text("Parameters:\n", color='blue') + self.print_text("\n[on_llm_before_invoke]\n", color="blue") + self.print_text(f"Model: {model}\n", color="blue") + self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color='blue') + self.print_text(f"\t{key}: {value}\n", color="blue") if stop: - self.print_text(f"\tstop: {stop}\n", color='blue') + self.print_text(f"\tstop: {stop}\n", color="blue") if tools: - self.print_text("\tTools:\n", color='blue') + self.print_text("\tTools:\n", color="blue") for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color='blue') + self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color='blue') + self.print_text(f"Stream: {stream}\n", color="blue") if user: - self.print_text(f"User: {user}\n", color='blue') + self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color='blue') + self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue') - self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue') + self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") + self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") if stream: self.print_text("\n[on_llm_new_chunk]") - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: 
dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -79,10 +97,19 @@ class LoggingCallback(Callback): sys.stdout.write(chunk.delta.message.content) sys.stdout.flush() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -97,24 +124,33 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_after_invoke]\n", color='yellow') - self.print_text(f"Content: {result.message.content}\n", color='yellow') + self.print_text("\n[on_llm_after_invoke]\n", color="yellow") + self.print_text(f"Content: {result.message.content}\n", color="yellow") if result.message.tool_calls: - self.print_text("Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color='yellow') - self.print_text(f"\t{tool_call.function.name}\n", color='yellow') - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow') + self.print_text(f"\t{tool_call.id}\n", color="yellow") + self.print_text(f"\t{tool_call.function.name}\n", color="yellow") + self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - self.print_text(f"Model: {result.model}\n", color='yellow') - self.print_text(f"Usage: {result.usage}\n", color='yellow') - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow') + self.print_text(f"Model: {result.model}\n", color="yellow") + self.print_text(f"Usage: {result.usage}\n", color="yellow") + self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -129,5 +165,5 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_invoke_error]\n", color='red') + self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/core/model_runtime/docs/en_US/schema.md b/api/core/model_runtime/docs/en_US/schema.md index 67f4e0879d..f819a4dbdc 100644 --- a/api/core/model_runtime/docs/en_US/schema.md +++ 
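For reference, the callback hooks reformatted above are consumed through the callbacks parameter of ModelInstance.invoke_llm shown earlier in this diff; attaching a LoggingCallback is enough to trace prompts, streamed chunks, and usage during development. A hedged sketch, assuming model_instance is a ModelInstance already obtained from ModelManager with valid credentials:

    # Assumed setup: model_instance comes from
    # ModelManager().get_model_instance(...); not shown here.
    from core.model_runtime.callbacks.logging_callback import LoggingCallback
    from core.model_runtime.entities.message_entities import UserPromptMessage

    result = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.01, "max_tokens": 16},
        stream=False,
        callbacks=[LoggingCallback()],  # prints the [on_llm_*] trace blocks
    )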
b/api/core/model_runtime/docs/en_US/schema.md @@ -52,7 +52,7 @@ - `mode` (string) voice model.(available for model type `tts`) - `name` (string) voice model display name.(available for model type `tts`) - `language` (string) the voice model supports languages.(available for model type `tts`) - - `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`) + - `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`) - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`) - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`) - `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`) @@ -150,7 +150,7 @@ - `input` (float) Input price, i.e., Prompt price - `output` (float) Output price, i.e., returned content price -- `unit` (float) Pricing unit, e.g., if the price is meausred in 1M tokens, the corresponding token amount for the unit price is `0.000001`. +- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`. - `currency` (string) Currency unit ### ProviderCredentialSchema diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 175c13cfdc..659ad59bd6 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None en_US: str diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index d2076bf74a..e94be6f918 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -2,107 +2,123 @@ from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { - 'label': { - 'en_US': 'Temperature', - 'zh_Hans': '温度', + "label": { + "en_US": "Temperature", + "zh_Hans": "温度", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.', - 'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。', + "type": "float", + "help": { + "en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. 
Higher temperature results in more random completions.", + "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_P: { - 'label': { - 'en_US': 'Top P', - 'zh_Hans': 'Top P', + "label": { + "en_US": "Top P", + "zh_Hans": "Top P", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', - 'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。', + "type": "float", + "help": { + "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.", + "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。", }, - 'required': False, - 'default': 1.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 1.0, + "min": 0.0, + "max": 1.0, + "precision": 2, + }, + DefaultParameterName.TOP_K: { + "label": { + "en_US": "Top K", + "zh_Hans": "Top K", + }, + "type": "int", + "help": { + "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", + "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", + }, + "required": False, + "default": 50, + "min": 1, + "max": 100, + "precision": 0, }, DefaultParameterName.PRESENCE_PENALTY: { - 'label': { - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', + "label": { + "en_US": "Presence Penalty", + "zh_Hans": "存在惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens already in the text.', - 'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens already in the text.", + "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.FREQUENCY_PENALTY: { - 'label': { - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', + "label": { + "en_US": "Frequency Penalty", + "zh_Hans": "频率惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.', - 'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", + "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.MAX_TOKENS: { - 'label': { - 'en_US': 'Max Tokens', - 'zh_Hans': '最大标记', + "label": { + "en_US": "Max Tokens", + "zh_Hans": "最大标记", }, - 'type': 'int', - 'help': { - 'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.', - 'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。', + "type": "int", + "help": { + "en_US": "Specifies the upper limit on the length of generated results. 
If the generated results are truncated, you can increase this parameter.", + "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", }, - 'required': False, - 'default': 64, - 'min': 1, - 'max': 2048, - 'precision': 0, + "required": False, + "default": 64, + "min": 1, + "max": 2048, + "precision": 0, }, DefaultParameterName.RESPONSE_FORMAT: { - 'label': { - 'en_US': 'Response Format', - 'zh_Hans': '回复格式', + "label": { + "en_US": "Response Format", + "zh_Hans": "回复格式", }, - 'type': 'string', - 'help': { - 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', - 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + "type": "string", + "help": { + "en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.", + "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等", }, - 'required': False, - 'options': ['JSON', 'XML'], + "required": False, + "options": ["JSON", "XML"], }, DefaultParameterName.JSON_SCHEMA: { - 'label': { - 'en_US': 'JSON Schema', + "label": { + "en_US": "JSON Schema", }, - 'type': 'text', - 'help': { - 'en_US': 'Set a response json schema will ensure LLM to adhere it.', - 'zh_Hans': '设置返回的json schema,llm将按照它返回', + "type": "text", + "help": { + "en_US": "Set a response json schema will ensure LLM to adhere it.", + "zh_Hans": "设置返回的json schema,llm将按照它返回", }, - 'required': False, + "required": False, }, } diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index b5bd9e267a..52b590f66a 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -12,11 +12,12 @@ class LLMMode(Enum): """ Enum class for large language model mode. """ + COMPLETION = "completion" CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'LLMMode': + def value_of(cls, value: str) -> "LLMMode": """ Get value of given mode. @@ -26,13 +27,14 @@ class LLMMode(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class LLMUsage(ModelUsage): """ Model class for llm usage. """ + prompt_tokens: int prompt_unit_price: Decimal prompt_price_unit: Decimal @@ -50,24 +52,59 @@ class LLMUsage(ModelUsage): def empty_usage(cls): return cls( prompt_tokens=0, - prompt_unit_price=Decimal('0.0'), - prompt_price_unit=Decimal('0.0'), - prompt_price=Decimal('0.0'), + prompt_unit_price=Decimal("0.0"), + prompt_price_unit=Decimal("0.0"), + prompt_price=Decimal("0.0"), completion_tokens=0, - completion_unit_price=Decimal('0.0'), - completion_price_unit=Decimal('0.0'), - completion_price=Decimal('0.0'), + completion_unit_price=Decimal("0.0"), + completion_price_unit=Decimal("0.0"), + completion_price=Decimal("0.0"), total_tokens=0, - total_price=Decimal('0.0'), - currency='USD', - latency=0.0 + total_price=Decimal("0.0"), + currency="USD", + latency=0.0, ) + def plus(self, other: "LLMUsage") -> "LLMUsage": + """ + Add two LLMUsage instances together. 
+ + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + if self.total_tokens == 0: + return other + else: + return LLMUsage( + prompt_tokens=self.prompt_tokens + other.prompt_tokens, + prompt_unit_price=other.prompt_unit_price, + prompt_price_unit=other.prompt_price_unit, + prompt_price=self.prompt_price + other.prompt_price, + completion_tokens=self.completion_tokens + other.completion_tokens, + completion_unit_price=other.completion_unit_price, + completion_price_unit=other.completion_price_unit, + completion_price=self.completion_price + other.completion_price, + total_tokens=self.total_tokens + other.total_tokens, + total_price=self.total_price + other.total_price, + currency=other.currency, + latency=self.latency + other.latency, + ) + + def __add__(self, other: "LLMUsage") -> "LLMUsage": + """ + Overload the + operator to add two LLMUsage instances. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + return self.plus(other) + class LLMResult(BaseModel): """ Model class for llm result. """ + model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage @@ -79,6 +116,7 @@ class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. """ + index: int message: AssistantPromptMessage usage: Optional[LLMUsage] = None @@ -89,6 +127,7 @@ class LLMResultChunk(BaseModel): """ Model class for llm result chunk. """ + model: str prompt_messages: list[PromptMessage] system_fingerprint: Optional[str] = None @@ -99,4 +138,5 @@ class NumTokensResult(PriceInfo): """ Model class for number of tokens result. """ + tokens: int diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e8e6963b56..e51bb18deb 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -9,13 +9,14 @@ class PromptMessageRole(Enum): """ Enum class for prompt message. """ + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" @classmethod - def value_of(cls, value: str) -> 'PromptMessageRole': + def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. @@ -25,13 +26,14 @@ class PromptMessageRole(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt message type value {value}') + raise ValueError(f"invalid prompt message type value {value}") class PromptMessageTool(BaseModel): """ Model class for prompt message tool. """ + name: str description: str parameters: dict @@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel): """ Model class for prompt message function. """ - type: str = 'function' + + type: str = "function" function: PromptMessageTool @@ -49,14 +52,16 @@ class PromptMessageContentType(Enum): """ Enum class for prompt message content type. """ - TEXT = 'text' - IMAGE = 'image' + + TEXT = "text" + IMAGE = "image" class PromptMessageContent(BaseModel): """ Model class for prompt message content. """ + type: PromptMessageContentType data: str @@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent): """ Model class for text prompt message content. """ + type: PromptMessageContentType = PromptMessageContentType.TEXT @@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content. 
""" + class DETAIL(Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageContentType = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ + role: PromptMessageRole content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None @@ -101,6 +109,7 @@ class UserPromptMessage(PromptMessage): """ Model class for user prompt message. """ + role: PromptMessageRole = PromptMessageRole.USER @@ -108,14 +117,17 @@ class AssistantPromptMessage(PromptMessage): """ Model class for assistant prompt message. """ + class ToolCall(BaseModel): """ Model class for assistant prompt message tool call. """ + class ToolCallFunction(BaseModel): """ Model class for assistant prompt message tool call function. """ + name: str arguments: str @@ -123,7 +135,7 @@ class AssistantPromptMessage(PromptMessage): type: str function: ToolCallFunction - @field_validator('id', mode='before') + @field_validator("id", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -145,10 +157,12 @@ class AssistantPromptMessage(PromptMessage): return True + class SystemPromptMessage(PromptMessage): """ Model class for system prompt message. """ + role: PromptMessageRole = PromptMessageRole.SYSTEM @@ -156,6 +170,7 @@ class ToolPromptMessage(PromptMessage): """ Model class for tool prompt message. """ + role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index c257ce63d2..d898ef1490 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -11,6 +11,7 @@ class ModelType(Enum): """ Enum class for model type. 
""" + LLM = "llm" TEXT_EMBEDDING = "text-embedding" RERANK = "rerank" @@ -26,22 +27,22 @@ class ModelType(Enum): :return: model type """ - if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value: + if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: return cls.LLM - elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: return cls.TEXT_EMBEDDING - elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: + elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: return cls.RERANK - elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: return cls.SPEECH2TEXT - elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: + elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: return cls.TTS - elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: - raise ValueError(f'invalid origin model type {origin_model_type}') + raise ValueError(f"invalid origin model type {origin_model_type}") def to_origin_model_type(self) -> str: """ @@ -50,26 +51,28 @@ class ModelType(Enum): :return: origin model type """ if self == self.LLM: - return 'text-generation' + return "text-generation" elif self == self.TEXT_EMBEDDING: - return 'embeddings' + return "embeddings" elif self == self.RERANK: - return 'reranking' + return "reranking" elif self == self.SPEECH2TEXT: - return 'speech2text' + return "speech2text" elif self == self.TTS: - return 'tts' + return "tts" elif self == self.MODERATION: - return 'moderation' + return "moderation" elif self == self.TEXT2IMG: - return 'text2img' + return "text2img" else: - raise ValueError(f'invalid model type {self}') + raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -78,6 +81,7 @@ class ModelFeature(Enum): """ Enum class for llm feature. """ + TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" @@ -85,12 +89,14 @@ class ModelFeature(Enum): STREAM_TOOL_CALL = "stream-tool-call" -class DefaultParameterName(Enum): +class DefaultParameterName(str, Enum): """ Enum class for parameter template variable. """ + TEMPERATURE = "temperature" TOP_P = "top_p" + TOP_K = "top_k" PRESENCE_PENALTY = "presence_penalty" FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" @@ -98,7 +104,7 @@ class DefaultParameterName(Enum): JSON_SCHEMA = "json_schema" @classmethod - def value_of(cls, value: Any) -> 'DefaultParameterName': + def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. @@ -108,13 +114,14 @@ class DefaultParameterName(Enum): for name in cls: if name.value == value: return name - raise ValueError(f'invalid parameter name {value}') + raise ValueError(f"invalid parameter name {value}") class ParameterType(Enum): """ Enum class for parameter type. """ + FLOAT = "float" INT = "int" STRING = "string" @@ -126,6 +133,7 @@ class ModelPropertyKey(Enum): """ Enum class for model property key. 
""" + MODE = "mode" CONTEXT_SIZE = "context_size" MAX_CHUNKS = "max_chunks" @@ -143,6 +151,7 @@ class ProviderModel(BaseModel): """ Model class for provider model. """ + model: str label: I18nObject model_type: ModelType @@ -157,6 +166,7 @@ class ParameterRule(BaseModel): """ Model class for parameter rule. """ + name: str use_template: Optional[str] = None label: I18nObject @@ -174,6 +184,7 @@ class PriceConfig(BaseModel): """ Model class for pricing info. """ + input: Decimal output: Optional[Decimal] = None unit: Decimal @@ -184,6 +195,7 @@ class AIModelEntity(ProviderModel): """ Model class for AI model. """ + parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None @@ -196,6 +208,7 @@ class PriceType(Enum): """ Enum class for price type. """ + INPUT = "input" OUTPUT = "output" @@ -204,6 +217,7 @@ class PriceInfo(BaseModel): """ Model class for price info. """ + unit_price: Decimal unit: Decimal total_amount: Decimal diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index f88f89d588..bfe861a97f 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -12,6 +12,7 @@ class ConfigurateMethod(Enum): """ Enum class for configurate method of provider model. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -20,6 +21,7 @@ class FormType(Enum): """ Enum class for form type. """ + TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" SELECT = "select" @@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel): """ Model class for form show on. """ + variable: str value: str @@ -39,6 +42,7 @@ class FormOption(BaseModel): """ Model class for form option. """ + label: I18nObject value: str show_on: list[FormShowOnObject] = [] @@ -46,15 +50,14 @@ class FormOption(BaseModel): def __init__(self, **data): super().__init__(**data) if not self.label: - self.label = I18nObject( - en_US=self.value - ) + self.label = I18nObject(en_US=self.value) class CredentialFormSchema(BaseModel): """ Model class for credential form schema. """ + variable: str label: I18nObject type: FormType @@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel): """ Model class for provider credential schema. """ + credential_form_schemas: list[CredentialFormSchema] @@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel): """ Model class for model credential schema. """ + model: FieldModelSchema credential_form_schemas: list[CredentialFormSchema] @@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel): """ Simple model class for provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel): """ Model class for provider help. """ + title: I18nObject url: I18nObject @@ -110,6 +117,7 @@ class ProviderEntity(BaseModel): """ Model class for provider. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -138,7 +146,7 @@ class ProviderEntity(BaseModel): icon_small=self.icon_small, icon_large=self.icon_large, supported_model_types=self.supported_model_types, - models=self.models + models=self.models, ) @@ -146,5 +154,6 @@ class ProviderConfig(BaseModel): """ Model class for provider config. 
""" + provider: str credentials: dict diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/core/model_runtime/entities/rerank_entities.py index d51efd2b3b..99709e1bcd 100644 --- a/api/core/model_runtime/entities/rerank_entities.py +++ b/api/core/model_runtime/entities/rerank_entities.py @@ -5,6 +5,7 @@ class RerankDocument(BaseModel): """ Model class for rerank document. """ + index: int text: str score: float @@ -14,5 +15,6 @@ class RerankResult(BaseModel): """ Model class for rerank result. """ + model: str docs: list[RerankDocument] diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 7be3def379..846b89d658 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. """ + tokens: int total_tokens: int unit_price: Decimal @@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel): """ Model class for text embedding result. """ + model: str embeddings: list[list[float]] usage: EmbeddingUsage - diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 0513cfaf67..edfb19c7d0 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -3,6 +3,7 @@ from typing import Optional class InvokeError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -14,24 +15,29 @@ class InvokeError(Exception): class InvokeConnectionError(InvokeError): """Raised when the Invoke returns connection error.""" + description = "Connection Error" class InvokeServerUnavailableError(InvokeError): """Raised when the Invoke returns server unavailable error.""" + description = "Server Unavailable Error" class InvokeRateLimitError(InvokeError): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" class InvokeAuthorizationError(InvokeError): """Raised when the Invoke returns authorization error.""" + description = "Incorrect model credentials provided, please check and try again. " class InvokeBadRequestError(InvokeError): """Raised when the Invoke returns bad request.""" + description = "Bad Request Error" diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 8db79a52bb..7fcd2133f9 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception): """ Credentials validate failed error """ + pass diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 716bb63566..09d2d7e54d 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -66,12 +66,14 @@ class AIModel(ABC): :param error: model invoke error :return: unified error """ - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. 
") + return invoke_error( + description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. " + ) return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}") @@ -115,7 +117,7 @@ class AIModel(ABC): if not price_config: raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) return PriceInfo( unit_price=unit_price, @@ -136,24 +138,26 @@ class AIModel(ABC): model_schemas = [] # get module name - model_type = self.__class__.__module__.split('.')[-1] + model_type = self.__class__.__module__.split(".")[-1] # get provider name - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] # get the path of current classes current_path = os.path.abspath(__file__) # get parent path of the current path - provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type) + provider_model_type_path = os.path.join( + os.path.dirname(os.path.dirname(current_path)), provider_name, model_type + ) # get all yaml files path under provider_model_type_path that do not start with __ model_schema_yaml_paths = [ os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) - if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') + if not model_schema_yaml.startswith("__") + and not model_schema_yaml.startswith("_") and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and model_schema_yaml.endswith(".yaml") ] # get _position.yaml file path @@ -165,10 +169,10 @@ class AIModel(ABC): yaml_data = load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] - for parameter_rule in yaml_data.get('parameter_rules', []): - if 'use_template' in parameter_rule: + for parameter_rule in yaml_data.get("parameter_rules", []): + if "use_template" in parameter_rule: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template']) + default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"]) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) copy_default_parameter_rule = default_parameter_rule.copy() copy_default_parameter_rule.update(parameter_rule) @@ -176,31 +180,26 @@ class AIModel(ABC): except ValueError: pass - if 'label' not in parameter_rule: - parameter_rule['label'] = { - 'zh_Hans': parameter_rule['name'], - 'en_US': parameter_rule['name'] - } + if "label" not in parameter_rule: + parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]} new_parameter_rules.append(parameter_rule) - yaml_data['parameter_rules'] = new_parameter_rules + yaml_data["parameter_rules"] = new_parameter_rules - if 'label' not in yaml_data: - yaml_data['label'] = { - 'zh_Hans': yaml_data['model'], - 'en_US': yaml_data['model'] - } + if "label" not in yaml_data: + yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]} - yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value + yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value try: # yaml_data to entity model_schema = 
AIModelEntity(**yaml_data) except Exception as e: model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml") - raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:' - f' {str(e)}') + raise Exception( + f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" f" {str(e)}" + ) # cache model schema model_schemas.append(model_schema) @@ -235,7 +234,9 @@ class AIModel(ABC): return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials( + self, model: str, credentials: Mapping + ) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -261,19 +262,19 @@ class AIModel(ABC): try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and 'max' in default_parameter_rule: - parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min and 'min' in default_parameter_rule: - parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.default and 'default' in default_parameter_rule: - parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision and 'precision' in default_parameter_rule: - parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required and 'required' in default_parameter_rule: - parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help and 'help' in default_parameter_rule: + if not parameter_rule.max and "max" in default_parameter_rule: + parameter_rule.max = default_parameter_rule["max"] + if not parameter_rule.min and "min" in default_parameter_rule: + parameter_rule.min = default_parameter_rule["min"] + if not parameter_rule.default and "default" in default_parameter_rule: + parameter_rule.default = default_parameter_rule["default"] + if not parameter_rule.precision and "precision" in default_parameter_rule: + parameter_rule.precision = default_parameter_rule["precision"] + if not parameter_rule.required and "required" in default_parameter_rule: + parameter_rule.required = default_parameter_rule["required"] + if not parameter_rule.help and "help" in default_parameter_rule: parameter_rule.help = I18nObject( - en_US=default_parameter_rule['help']['en_US'], + en_US=default_parameter_rule["help"]["en_US"], ) if ( parameter_rule.help diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index cfc8942c79..5c39186e65 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel): """ Model class for large language model. 
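The `use_template` branch reformatted above resolves a shared default rule, copies it, and lets the YAML entry override individual keys (`copy_default_parameter_rule.update(parameter_rule)`). The merge itself is an ordinary shallow copy plus `update`; a sketch with hypothetical template values (the real defaults live in Dify's parameter-rule templates, not here):

    # hypothetical default rule for a "temperature" template
    default_rule = {
        "label": {"en_US": "Temperature"},
        "type": "float",
        "min": 0.0,
        "max": 2.0,
        "default": 1.0,
        "required": False,
    }

    # a YAML-side rule that opts into the template but narrows the range
    yaml_rule = {"name": "temperature", "use_template": "temperature", "max": 1.0}

    merged = dict(default_rule)   # shallow copy, like copy_default_parameter_rule
    merged.update(yaml_rule)      # YAML keys win over template defaults

    assert merged["max"] == 1.0 and merged["min"] == 0.0
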
""" + model_type: ModelType = ModelType.LLM # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -69,7 +77,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -82,7 +90,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) try: @@ -96,7 +104,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -111,7 +119,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) raise self._transform_invoke_error(e) @@ -127,7 +135,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) elif isinstance(result, LLMResult): self._trigger_after_invoke_callbacks( @@ -140,15 +148,23 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) return result - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -183,7 +199,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) model_parameters.pop("response_format") @@ -195,15 +211,16 @@ if you are not sure about the structure. 
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", str(prompt_messages[0].content)) + content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content)) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.") + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last text message @@ -216,9 +233,7 @@ if you are not sure about the structure. break else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) response = self._invoke( model=model, @@ -228,33 +243,30 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): first_chunk = next(response) + def new_generator(): yield first_chunk yield from response if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"): return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) else: return self._code_block_mode_stream_processor( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], - input_generator: Generator[LLMResultChunk, None, None] - ) -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor( + self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote @@ -303,16 +315,13 @@ if you are not sure about the structure. prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, - input_generator: Generator[LLMResultChunk, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor_with_backtick( + self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote. This version skips the language identifier that follows the opening triple backticks. @@ -378,18 +387,23 @@ if you are not sure about the structure. 
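`new_generator` in the hunk above exists because the wrapper needs to inspect the first streamed chunk (does the reply open with a backtick?) without losing it, so the inspected chunk is re-yielded ahead of the rest. The standard-library way to express the same stitch, as a sketch over plain strings:

    import itertools
    from collections.abc import Generator, Iterator

    def stream() -> Generator[str, None, None]:
        yield "`"
        yield "``json"
        yield '{"a": 1}'

    chunks: Iterator[str] = iter(stream())
    first = next(chunks)                      # consume one chunk to inspect it
    rest = itertools.chain([first], chunks)   # then stitch it back in front

    assert first.startswith("`")
    assert "".join(rest) == '```json{"a": 1}'
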
prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: """ Invoke result generator @@ -397,9 +411,7 @@ if you are not sure about the structure. :return: result generator """ callbacks = callbacks or [] - prompt_message = AssistantPromptMessage( - content="" - ) + prompt_message = AssistantPromptMessage(content="") usage = None system_fingerprint = None real_model = model @@ -418,7 +430,7 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) prompt_message.content += chunk.delta.message.content @@ -438,7 +450,7 @@ if you are not sure about the structure. prompt_messages=prompt_messages, message=prompt_message, usage=usage if usage else LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint + system_fingerprint=system_fingerprint, ), credentials=credentials, prompt_messages=prompt_messages, @@ -447,15 +459,21 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) @abstractmethod - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -472,8 +490,13 @@ if you are not sure about the structure. raise NotImplementedError @abstractmethod - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -519,7 +542,9 @@ if you are not sure about the structure. return mode - def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -539,10 +564,7 @@ if you are not sure about the structure. 
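`_invoke_result_generator` above re-yields each chunk to the caller while concatenating the deltas, so the after-invoke callbacks can receive one assembled result when the stream is exhausted (and the invoke-error callbacks fire if it raises mid-stream). The shape of that relay-and-collect pattern, reduced to strings:

    from collections.abc import Generator

    def relay_and_collect(chunks: Generator[str, None, None]) -> Generator[str, None, None]:
        collected = ""
        for chunk in chunks:
            yield chunk          # the caller sees the stream unchanged
            collected += chunk   # while the full message is assembled here
        # the real code triggers after-invoke callbacks with the assembled result
        print(f"final message: {collected!r}")

    for piece in relay_and_collect(c for c in ["Hel", "lo"]):
        pass  # consume the stream as a caller would
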
# get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -558,16 +580,23 @@ if you are not sure about the structure. total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_before_invoke_callbacks( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger before invoke callbacks @@ -593,7 +622,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -601,11 +630,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") - def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_new_chunk_callbacks( + self, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger new chunk callbacks @@ -632,7 +669,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -640,11 +677,19 @@ if you are not sure about the structure. 
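`_calc_response_usage` above prices the prompt and completion sides separately and sums them; as in the `quantize` call earlier in this diff, amounts stay in `Decimal` with half-up rounding to seven places to avoid float drift. A condensed sketch with hypothetical prices in the shape of the YAML pricing blocks:

    import decimal
    from decimal import Decimal

    def token_cost(tokens: int, unit_price: Decimal, unit: Decimal) -> Decimal:
        # e.g. unit_price is "per 1M tokens" when unit == Decimal("0.000001")
        amount = tokens * unit_price * unit
        return amount.quantize(Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)

    prompt_cost = token_cost(1200, Decimal("3.00"), Decimal("0.000001"))
    completion_cost = token_cost(300, Decimal("15.00"), Decimal("0.000001"))
    assert prompt_cost + completion_cost == Decimal("0.0081000")
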
else: logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") - def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_after_invoke_callbacks( + self, + model: str, + result: LLMResult, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger after invoke callbacks @@ -672,7 +717,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -680,11 +725,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") - def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_invoke_error_callbacks( + self, + model: str, + ex: Exception, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger invoke error callbacks @@ -712,7 +765,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -758,11 +811,13 @@ if you are not sure about the structure. # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.FLOAT: if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") @@ -775,16 +830,19 @@ if you are not sure about the structure. else: if parameter_value != round(parameter_value, parameter_rule.precision): raise ValueError( - f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.") + f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places." 
+ ) # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.BOOLEAN: if not isinstance(parameter_value, bool): raise ValueError(f"Model Parameter {parameter_name} should be bool.") diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 780460a3f7..4374093de4 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -29,32 +29,32 @@ class ModelProvider(ABC): def get_provider_schema(self) -> ProviderEntity: """ Get provider schema - + :return: provider schema """ if self.provider_schema: return self.provider_schema - + # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) - + # read provider schema from yaml file - yaml_path = os.path.join(current_path, f'{provider_name}.yaml') + yaml_path = os.path.join(current_path, f"{provider_name}.yaml") yaml_data = load_yaml_file(yaml_path) - + try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) except Exception as e: - raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}') + raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}") # cache schema self.provider_schema = provider_schema - + return provider_schema def models(self, model_type: ModelType) -> list[AIModelEntity]: @@ -92,15 +92,15 @@ class ModelProvider(ABC): # get the path of the model type classes base_path = os.path.abspath(__file__) - model_type_name = model_type.value.replace('-', '_') + model_type_name = model_type.value.replace("-", "_") model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name) - model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py') + model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py") if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path): - raise Exception(f'Invalid model type {model_type} for provider {provider_name}') + raise Exception(f"Invalid model type {model_type} for provider {provider_name}") # Dynamic loading {model_type_name}.py file and find the subclass of AIModel - parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) + parent_module = ".".join(self.__class__.__module__.split(".")[:-1]) mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 2b17f292c5..d04414ccb8 100644 --- 
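`get_model_instance` above resolves `{model_type_name}.py` under the provider package and loads it at runtime via `import_module_from_source`, Dify's own helper. A sketch of what a path-based import like that looks like with the standard library (the real helper may differ in details):

    import importlib.util
    from types import ModuleType

    def import_module_from_source(module_name: str, py_file_path: str) -> ModuleType:
        # sketch of a path-based import, not Dify's actual implementation
        spec = importlib.util.spec_from_file_location(module_name, py_file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load {module_name} from {py_file_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
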
a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -12,14 +12,13 @@ class ModerationModel(AIModel): """ Model class for moderation model. """ + model_type: ModelType = ModelType.MODERATION # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -37,9 +36,7 @@ class ModerationModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke large language model @@ -50,4 +47,3 @@ class ModerationModel(AIModel): :return: false if text is safe, true otherwise """ raise NotImplementedError - diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 2c86f25180..5fb9604742 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -11,12 +11,19 @@ class RerankModel(AIModel): """ Base Model class for rerank model. """ + model_type: ModelType = ModelType.RERANK - def invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -37,10 +44,16 @@ class RerankModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 4fb11025fe..b6b0b73743 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -12,14 +12,13 @@ class Speech2TextModel(AIModel): """ Model class for speech2text model. 
""" + model_type: ModelType = ModelType.SPEECH2TEXT # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -35,9 +34,7 @@ class Speech2TextModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -59,4 +56,4 @@ class Speech2TextModel(AIModel): current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to the audio file - return os.path.join(current_dir, 'audio.mp3') + return os.path.join(current_dir, "audio.mp3") diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py index e0f1adb1c4..a5810e2f0e 100644 --- a/api/core/model_runtime/model_providers/__base/text2img_model.py +++ b/api/core/model_runtime/model_providers/__base/text2img_model.py @@ -11,14 +11,15 @@ class Text2ImageModel(AIModel): """ Model class for text2img model. """ + model_type: ModelType = ModelType.TEXT2IMG # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model @@ -36,9 +37,9 @@ class Text2ImageModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def _invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 381d2f6cd1..54a4486023 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -13,14 +13,15 @@ class TextEmbeddingModel(AIModel): """ Model class for text embedding model. 
""" + model_type: ModelType = ModelType.TEXT_EMBEDDING # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model @@ -38,9 +39,9 @@ class TextEmbeddingModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6059b3f561..5fe6dda6ad 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -7,27 +7,28 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer _tokenizer = None _lock = Lock() + class GPT2Tokenizer: @staticmethod def _get_num_tokens_by_gpt2(text: str) -> int: """ - use gpt2 tokenizer to get num tokens + use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() tokens = _tokenizer.encode(text, verbose=False) return len(tokens) - + @staticmethod def get_num_tokens(text: str) -> int: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - + @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'gpt2') + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - return _tokenizer \ No newline at end of file + return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 64e85d2c11..70be9322a7 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -13,15 +13,17 @@ logger = logging.getLogger(__name__) class TTSModel(AIModel): """ - Model class for ttstext model. + Model class for TTS model. 
""" + model_type: ModelType = ModelType.TTS # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -35,14 +37,21 @@ class TTSModel(AIModel): :return: translated audio file """ try: - return self._invoke(model=model, credentials=credentials, user=user, - content_text=content_text, voice=voice, tenant_id=tenant_id) + return self._invoke( + model=model, + credentials=credentials, + user=user, + content_text=content_text, + voice=voice, + tenant_id=tenant_id, + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -71,10 +80,13 @@ class TTSModel(AIModel): if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: voices = model_schema.model_properties[ModelPropertyKey.VOICES] if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if - language and language in d.get('language')] + return [ + {"name": d["name"], "value": d["mode"]} + for d in voices + if language and language in d.get("language") + ] else: - return [{'name': d['name'], 'value': d['mode']} for d in voices] + return [{"name": d["name"], "value": d["mode"]} for d in voices] def _get_model_default_voice(self, model: str, credentials: dict) -> any: """ @@ -123,23 +135,23 @@ class TTSModel(AIModel): return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod - def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): + def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): match = re.compile(pattern) tx = match.finditer(org_text) start = 0 result = [] - one_sentence = '' + one_sentence = "" for i in tx: end = i.regs[0][1] tmp = org_text[start:end] if len(one_sentence + tmp) > max_length: result.append(one_sentence) - one_sentence = '' + one_sentence = "" one_sentence += tmp start = end last_sens = org_text[start:] if last_sens: one_sentence += last_sens - if one_sentence != '': + if one_sentence != "": result.append(one_sentence) return result diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py index 00a6bbce3b..5b12f04a3e 100644 --- a/api/core/model_runtime/model_providers/anthropic/anthropic.py +++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py @@ -19,13 +19,10 @@ class AnthropicProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - # Use `claude-instant-1` model for validate, - model_instance.validate_credentials( - model='claude-instant-1.2', - credentials=credentials - ) + # Use `claude-3-opus-20240229` model for validate, + model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") 
raise ex diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml index 929a7f8725..ac69bbf4d2 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml @@ -33,3 +33,4 @@ pricing: output: '5.51' unit: '0.000001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 81be1a06a7..30e9d2e9f2 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -55,11 +55,17 @@ if you are not sure about the structure. class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -76,10 +82,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # invoke model return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -96,41 +109,39 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # transform model parameters from completion api of anthropic to chat api - if 'max_tokens_to_sample' in model_parameters: - model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + if "max_tokens_to_sample" in model_parameters: + model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample") # init model client client = Anthropic(**credentials_kwargs) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user) system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get('max_tokens') > 4096: + if model_parameters.get("max_tokens") > 4096: 
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" if tools: - extra_model_kwargs['tools'] = [ - self._transform_tool_prompt(tool) for tool in tools - ] + extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( model=model, messages=prompt_message_dicts, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: # chat model @@ -140,22 +151,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if model_parameters.get('response_format'): + if model_parameters.get("response_format"): stop = stop or [] # chat model self._transform_chat_json_prompts( @@ -167,24 +186,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict: - return { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.parameters - } + return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters} - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -197,22 +219,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + 
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -228,9 +258,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): tokens = client.count_tokens(prompt) tool_call_inner_prompts_tokens_map = { - 'claude-3-opus-20240229': 395, - 'claude-3-haiku-20240307': 264, - 'claude-3-sonnet-20240229': 159 + "claude-3-opus-20240229": 395, + "claude-3-haiku-20240307": 264, + "claude-3-sonnet-20240229": 159, } if model in tool_call_inner_prompts_tokens_map and tools: @@ -257,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "temperature": 0, "max_tokens": 20, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Union[Message, ToolsBetaMessage], + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm chat response @@ -274,22 +309,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[]) for content in response.content: - if content.type == 'text': + if content.type == "text": assistant_prompt_message.content += content.text - elif content.type == 'tool_use': + elif content.type == "tool_use": tool_call = AssistantPromptMessage.ToolCall( id=content.id, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content.name, - arguments=json.dumps(content.input) - ) + name=content.name, arguments=json.dumps(content.input) + ), ) assistant_prompt_message.tool_calls.append(tool_call) @@ -308,17 +339,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage]) -> 
Generator: + def _handle_chat_generate_stream_response( + self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm chat stream response @@ -327,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -338,24 +366,23 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): for chunk in response: if isinstance(chunk, MessageStartEvent): - if hasattr(chunk, 'content_block'): + if hasattr(chunk, "content_block"): content_block = chunk.content_block if isinstance(content_block, dict): - if content_block.get('type') == 'tool_use': + if content_block.get("type") == "tool_use": tool_call = AssistantPromptMessage.ToolCall( - id=content_block.get('id'), - type='function', + id=content_block.get("id"), + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content_block.get('name'), - arguments='' - ) + name=content_block.get("name"), arguments="" + ), ) tool_calls.append(tool_call) - elif hasattr(chunk, 'delta'): + elif hasattr(chunk, "delta"): delta = chunk.delta if isinstance(delta, dict) and len(tool_calls) > 0: - if delta.get('type') == 'input_json_delta': - tool_calls[-1].function.arguments += delta.get('partial_json', '') + if delta.get("type") == "input_json_delta": + tool_calls[-1].function.arguments += delta.get("partial_json", "") elif chunk.message: return_model = chunk.message.model input_tokens = chunk.message.usage.input_tokens @@ -369,29 +396,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform empty tool call arguments to {} for tool_call in tool_calls: if not tool_call.function.arguments: - tool_call.function.arguments = '{}' + tool_call.function.arguments = "{}" yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text if chunk.delta.text else "" full_assistant_content += chunk_text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=chunk_text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_text) index = chunk.index @@ -401,7 +423,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk.index, message=assistant_prompt_message, - ) + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -412,14 +434,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "api_key": credentials['anthropic_api_key'], + "api_key": credentials["anthropic_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if credentials.get('anthropic_api_url'): - credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') - credentials_kwargs['base_url'] = credentials['anthropic_api_url'] + if credentials.get("anthropic_api_url"): + credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/") + 
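In the streaming handler above, a `tool_use` content block opens a tool call and its JSON arguments then arrive as `input_json_delta` fragments; each `partial_json` piece is appended to the last open call, and calls that never receive a fragment get `"{}"`. The accumulation step as a sketch, with plain dicts standing in for the SDK's typed stream events:

    import json

    # illustrative event shapes only, not the anthropic SDK's event classes
    events = [
        {"type": "content_block", "content_block": {"type": "tool_use", "id": "t1", "name": "get_weather"}},
        {"type": "delta", "delta": {"type": "input_json_delta", "partial_json": '{"city": '}},
        {"type": "delta", "delta": {"type": "input_json_delta", "partial_json": '"Paris"}'}},
    ]

    tool_calls: list[dict] = []
    for event in events:
        if event["type"] == "content_block" and event["content_block"]["type"] == "tool_use":
            tool_calls.append({"id": event["content_block"]["id"],
                               "name": event["content_block"]["name"],
                               "arguments": ""})
        elif event["type"] == "delta" and tool_calls:
            if event["delta"].get("type") == "input_json_delta":
                tool_calls[-1]["arguments"] += event["delta"].get("partial_json", "")

    for call in tool_calls:
        call["arguments"] = call["arguments"] or "{}"  # empty input becomes "{}"

    assert json.loads(tool_calls[0]["arguments"]) == {"city": "Paris"}
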
credentials_kwargs["base_url"] = credentials["anthropic_api_url"] return credentials_kwargs @@ -452,10 +474,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -465,25 +484,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + raise ValueError( + f"Failed to fetch image data from url {message_content.data}, {ex}" + ) else: data_split = message_content.data.split(";base64,") mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) @@ -492,34 +511,28 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): content = [] if message.tool_calls: for tool_call in message.tool_calls: - content.append({ - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments) - }) + content.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + ) if message.content: - content.append({ - "type": "text", - "text": message.content - }) - + content.append({"type": "text", "text": message.content}) + if prompt_message_dicts[-1]["role"] == "assistant": prompt_message_dicts[-1]["content"].extend(content) else: - prompt_message_dicts.append({ - "role": "assistant", - "content": content - }) + prompt_message_dicts.append({"role": "assistant", "content": content}) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [ + {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content} + ], } prompt_message_dicts.append(message_dict) else: @@ -576,16 +589,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. 
""" if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -601,24 +611,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - anthropic.APIConnectionError, - anthropic.APITimeoutError - ], - InvokeServerUnavailableError: [ - anthropic.InternalServerError - ], - InvokeRateLimitError: [ - anthropic.RateLimitError - ], - InvokeAuthorizationError: [ - anthropic.AuthenticationError, - anthropic.PermissionDeniedError - ], + InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError], + InvokeServerUnavailableError: [anthropic.InternalServerError], + InvokeRateLimitError: [anthropic.RateLimitError], + InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError], InvokeBadRequestError: [ anthropic.BadRequestError, anthropic.NotFoundError, anthropic.UnprocessableEntityError, - anthropic.APIError - ] + anthropic.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 31c788d226..32a0269af4 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -15,10 +15,10 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) + api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION) credentials_kwargs = { - "api_key": credentials['openai_api_key'], - "azure_endpoint": credentials['openai_api_base'], + "api_key": credentials["openai_api_key"], + "azure_endpoint": credentials["openai_api_base"], "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, @@ -29,24 +29,14 @@ class _CommonAzureOpenAI: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], - InvokeAuthorizationError: [ - openai.AuthenticationError, - openai.PermissionDeniedError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 984cca3744..c2744691c3 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ 
b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,11 +14,12 @@ from core.model_runtime.entities.model_entities import ( PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2024-02-15-preview' +AZURE_OPENAI_API_VERSION = "2024-02-15-preview" + def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( - name='max_tokens', + name="max_tokens", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], ) rule.default = default @@ -34,11 +35,11 @@ class AzureBaseModel(BaseModel): LLM_BASE_MODELS = [ AzureBaseModel( - base_model_name='gpt-35-turbo', + base_model_name="gpt-35-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -53,51 +54,47 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-16k', + base_model_name="gpt-35-turbo-16k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -112,37 +109,37 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=16385) + _get_max_tokens(default=512, min_val=1, max_val=16385), ], pricing=PriceConfig( input=0.003, output=0.004, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-0125', + base_model_name="gpt-35-turbo-0125", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ 
-157,51 +154,47 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4', + base_model_name="gpt-4", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -216,32 +209,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=8192), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -249,34 +239,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.03, output=0.06, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-32k', + base_model_name="gpt-4-32k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -291,32 +277,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=32768), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -324,34 +307,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.06, output=0.12, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-0125-preview', + base_model_name="gpt-4-0125-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -366,32 +345,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -399,34 +375,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-1106-preview', + base_model_name="gpt-4-1106-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -441,32 +413,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -474,34 +443,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini', + base_model_name="gpt-4o-mini", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -517,32 +482,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -550,34 +512,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini-2024-07-18', + base_model_name="gpt-4o-mini-2024-07-18", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -593,32 +551,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -626,34 +581,40 @@ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object", "json_schema"], + ), + ParameterRule( + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", + help=I18nObject( + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response JSON schema and the LLM will adhere to it.", ), required=False, - options=['text', 'json_object'] ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o', + base_model_name="gpt-4o", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -669,32 +630,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -702,34 +660,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-05-13', + base_model_name="gpt-4o-2024-05-13", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -745,32 +699,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -778,34 +729,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo', + base_model_name="gpt-4o-2024-08-06", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -821,32 +768,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -854,34 +798,40 @@ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object", "json_schema"], + ), + ParameterRule( + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", + help=I18nObject( + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response JSON schema and the LLM will adhere to it.", ), required=False, - options=['text', 'json_object'] ), ], pricing=PriceConfig( - input=0.01, - output=0.03, - unit=0.001, - currency='USD', - ) - ) + input=5.00, + output=15.00, + unit=0.000001, + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo-2024-04-09', + base_model_name="gpt-4-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -897,32 +847,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -930,38 +877,37 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-vision-preview', + base_model_name="gpt-4-turbo-2024-04-09", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ - ModelFeature.VISION + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ @@ -970,32 +916,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -1003,34 +946,94 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-instruct', + base_model_name="gpt-4-vision-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", + ), + model_type=ModelType.LLM, + features=[ModelFeature.VISION], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name="top_p", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name="presence_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name="frequency_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=I18nObject( + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", + ), + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object"], + ), + ], + pricing=PriceConfig( + input=0.01, + output=0.03, + unit=0.001, + currency="USD", + ), + ), + ), + AzureBaseModel( + base_model_name="gpt-35-turbo-instruct", + entity=AIModelEntity( + model="fake-deployment-name", + label=I18nObject( + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1040,19 +1043,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1061,16 +1064,16 @@ LLM_BASE_MODELS = [ input=0.0015, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-davinci-003', + base_model_name="text-davinci-003", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1080,19 +1083,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1101,20 +1104,18 @@ LLM_BASE_MODELS = [ input=0.02, output=0.02, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] EMBEDDING_BASE_MODELS = [ AzureBaseModel( - base_model_name='text-embedding-ada-002', + base_model_name="text-embedding-ada-002", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1124,17 +1125,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.0001, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-small', + base_model_name="text-embedding-3-small", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + 
label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1144,17 +1143,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-large', + base_model_name="text-embedding-3-large", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1164,135 +1161,129 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00013, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] SPEECH2TEXT_BASE_MODELS = [ AzureBaseModel( - base_model_name='whisper-1', + base_model_name="whisper-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={ ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, - ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' - } - ) + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm", + }, + ), ) ] TTS_BASE_MODELS = [ AzureBaseModel( - base_model_name='tts-1', + base_model_name="tts-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", 
"id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='tts-1-hd', + base_model_name="tts-1-hd", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.03, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py index 68977b2266..2e3c6aab05 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class AzureOpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index be4d4651d7..700935b07b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -138,6 +138,12 @@ model_credential_schema: show_on: - variable: 
__model_type value: llm + - label: + en_US: gpt-4o-2024-08-06 + value: gpt-4o-2024-08-06 + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4-turbo value: gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 1911caa952..3b9fb52e24 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,4 +1,5 @@ import copy +import json import logging from collections.abc import Generator, Sequence from typing import Optional, Union, cast @@ -33,16 +34,20 @@ logger = logging.getLogger(__name__) class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - - base_model_name = credentials.get('base_model_name') + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -55,7 +60,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -66,7 +71,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) def get_num_tokens( @@ -74,14 +79,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: - raise ValueError(f'Base Model Name {base_model_name} is invalid') + raise ValueError(f"Base Model Name {base_model_name} is invalid") model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: @@ -91,21 +96,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # text completion model, do not support tool calling content = prompt_messages[0].content assert isinstance(content, str) - return self._num_tokens_from_string(credentials,content) + return self._num_tokens_from_string(credentials, content) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise 
CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise CredentialsValidateFailedError('Base Model Name is required') + raise CredentialsValidateFailedError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: @@ -117,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -126,7 +131,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -136,33 +141,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -171,15 +178,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) def _handle_generate_response( - self, model: 
str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] ): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -208,24 +212,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return result def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] ) -> Generator: - full_text = '' + full_text = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -253,8 +254,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -264,29 +265,41 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) response_format = model_parameters.get("response_format") if response_format: - if response_format == "json_object": - response_format = {"type": "json_object"} + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except json.JSONDecodeError: + raise ValueError(f"Invalid json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: - response_format = {"type": "text"} - - model_parameters["response_format"] = response_format + model_parameters["response_format"] = {"type": response_format} extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in 
tools] # extra_model_kwargs['functions'] = [{ # "name": tool.name, # "description": tool.description, @@ -294,10 +307,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # } for tool in tools] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # chat model messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -315,9 +328,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) def _handle_chat_generate_response( - self, model: str, credentials: dict, response: ChatCompletion, + self, + model: str, + credentials: dict, + response: ChatCompletion, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): assistant_message = response.choices[0].message assistant_message_tool_calls = assistant_message.tool_calls @@ -327,10 +343,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -362,13 +375,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): credentials: dict, response: Stream[ChatCompletionChunk], prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): index = 0 - full_assistant_content = '' + full_assistant_content = "" real_model = model system_fingerprint = None - completion = '' + completion = "" tool_calls = [] for chunk in response: if len(chunk.choices) == 0: @@ -379,7 +392,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if delta.delta is None: continue - # extract tool calls from response self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) @@ -389,15 +401,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" real_model = chunk.model system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content if delta.delta.content else '' + completion += delta.delta.content if delta.delta.content else "" yield LLMResultChunk( model=real_model, @@ -406,17 +417,15 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) - index += 0 + index += 1 # calculate num tokens prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - full_assistant_prompt_message = AssistantPromptMessage( - content=completion - ) 
+ full_assistant_prompt_message = AssistantPromptMessage(content=completion) completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) # transform usage @@ -427,27 +436,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, system_fingerprint=system_fingerprint, delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=''), - finish_reason='stop', - usage=usage - ) + index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage + ), ) @staticmethod - def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None: + def _update_tool_calls( + tool_calls: list[AssistantPromptMessage.ToolCall], + tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], + ) -> None: if tool_calls_response: for response_tool_call in tool_calls_response: if isinstance(response_tool_call, ChatCompletionMessageToolCall): function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) elif isinstance(response_tool_call, ChoiceDeltaToolCall): @@ -456,8 +462,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tool_calls[index].id = response_tool_call.id or tool_calls[index].id tool_calls[index].type = response_tool_call.type or tool_calls[index].type if response_tool_call.function: - tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name - tool_calls[index].function.arguments += response_tool_call.function.arguments or '' + tool_calls[index].function.name = ( + response_tool_call.function.name or tool_calls[index].function.name + ) + tool_calls[index].function.arguments += response_tool_call.function.arguments or "" else: assert response_tool_call.id is not None assert response_tool_call.type is not None @@ -466,13 +474,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): assert response_tool_call.function.arguments is not None function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) @@ -488,19 +493,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = 
cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -518,7 +517,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): "role": "tool", "name": message.name, "content": message.content, - "tool_call_id": message.tool_call_id + "tool_call_id": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -528,10 +527,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, credentials: dict, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: try: - encoding = tiktoken.encoding_for_model(credentials['base_model_name']) + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") @@ -543,14 +543,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens def _num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials['base_model_name'] + model = credentials["base_model_name"] try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -584,10 +583,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -619,40 +618,39 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): @staticmethod def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters['title'])) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters['type'])) - if 'properties' in 
parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters['properties'].items(): + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) + num_tokens += len(encoding.encode(parameters["title"])) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode(parameters["type"])) + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters["properties"].items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index 8aebcb90e4..a2b14cf3db 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -40,7 +38,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -65,10 +63,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): return response.text def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e073bef014..d9cff8ecbb 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -16,19 +16,18 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): - - def _invoke(self, model: str, 
credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - base_model_name = credentials['base_model_name'] + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + base_model_name = credentials["base_model_name"] credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -44,11 +43,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -56,10 +53,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): for i in _iter: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -75,10 +69,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,24 +79,16 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=base_model_name - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: if len(texts) == 0: return 0 try: - enc = tiktoken.encoding_for_model(credentials['base_model_name']) + enc = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: enc = tiktoken.get_encoding("cl100k_base") @@ -118,57 +101,52 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' 
not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - if not self._get_ai_model_entity(credentials['base_model_name'], model): + if not self._get_ai_model_entity(credentials["base_model_name"], model): raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') try: credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity @staticmethod - def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +157,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index f9ddd86f68..bbad726467 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): Model class for Azure OpenAI Text to Speech model.
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -30,13 +31,12 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -50,14 +50,13 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name @@ -75,23 +74,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): yield from future.result().__enter__().iter_bytes(1024) else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api @@ -108,10 +113,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): return response.read() def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def 
_get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None: for ai_model_entity in TTS_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py index 71bd6b5d92..626fc811cf 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class BaichuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class BaichuanProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `baichuan2-turbo` model for validate, - model_instance.validate_credentials( - model='baichuan2-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.yaml b/api/core/model_runtime/model_providers/baichuan/baichuan.yaml index 792126af7f..81e6e36215 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.yaml +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.yaml @@ -27,11 +27,3 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key - - variable: secret_key - label: - en_US: Secret Key - type: secret-input - required: false - placeholder: - zh_Hans: 在此输入您的 Secret Key - en_US: Enter your Secret Key diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml index 04849500dc..8360dd5faf 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml @@ -43,3 +43,4 @@ parameter_rules: zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 en_US: Allow the model to perform external search to enhance the generation results. required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml index c8156c152b..0ce0265cfe 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml @@ -43,3 +43,4 @@ parameter_rules: zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 en_US: Allow the model to perform external search to enhance the generation results. 
required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml index f91329c77a..ccb4ee8b92 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml @@ -4,36 +4,32 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 192000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml index bf72e82296..c6c6c7e9e9 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 128000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 128000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: Specify the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml index 85882519b8..ee8a9ff0d5 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token.
required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 32000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: Specify the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml index f8c6566081..e5e6aeb491 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 32000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: Specify the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 7549b2fb60..bea6777f83 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -4,17 +4,17 @@ import re class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: - return len(re.findall(r'[\u4e00-\u9fa5]', text)) + return len(re.findall(r"[\u4e00-\u9fa5]", text)) @classmethod def count_english_vocabularies(cls, text: str) -> int: # remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
- text = re.sub(r'[^a-zA-Z0-9\s]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s]", "", text) # count the number of words not characters return len(text.split()) - + @classmethod def _get_num_tokens(cls, text: str) -> int: # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return) # https://platform.baichuan-ai.com/docs/text-Embedding - return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) \ No newline at end of file + return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index d7d8b7c91b..39f867118b 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,14 +1,13 @@ -from collections.abc import Generator -from enum import Enum -from hashlib import md5 -from json import dumps, loads -from typing import Any, Union +import json +from collections.abc import Iterator +from typing import Any, Optional, Union from requests import post +from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -16,203 +15,130 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor ) -class BaichuanMessage: - class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - # Baichuan does not have system message - _SYSTEM = 'system' - - role: str = Role.USER.value - content: str - usage: dict[str, int] = None - stop_reason: str = '' - - def to_dict(self) -> dict[str, Any]: - return { - 'role': self.role, - 'content': self.content, - } - - def __init__(self, content: str, role: str = 'user') -> None: - self.content = content - self.role = role - class BaichuanModel: api_key: str - secret_key: str - def __init__(self, api_key: str, secret_key: str = '') -> None: + def __init__(self, api_key: str) -> None: self.api_key = api_key - self.secret_key = secret_key - def _model_mapping(self, model: str) -> str: + @property + def _model_mapping(self) -> dict: return { - 'baichuan2-turbo': 'Baichuan2-Turbo', - 'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k', - 'baichuan2-53b': 'Baichuan2-53B', - 'baichuan3-turbo': 'Baichuan3-Turbo', - 'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k', - 'baichuan4': 'Baichuan4', - }[model] + "baichuan2-turbo": "Baichuan2-Turbo", + "baichuan3-turbo": "Baichuan3-Turbo", + "baichuan3-turbo-128k": "Baichuan3-Turbo-128k", + "baichuan4": "Baichuan4", + } - def _handle_chat_generate_response(self, response) -> BaichuanMessage: - resp = response.json() - choices = resp.get('choices', []) - message = BaichuanMessage(content='', role='assistant') - for choice in choices: - message.content += choice['message']['content'] - message.role = choice['message']['role'] - if choice['finish_reason']: - message.stop_reason = choice['finish_reason'] + @property + def request_headers(self) -> dict[str, Any]: + return { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + } - if 'usage' in resp: - message.usage = { - 'prompt_tokens': resp['usage']['prompt_tokens'], - 'completion_tokens': resp['usage']['completion_tokens'], 
- 'total_tokens': resp['usage']['total_tokens'], - } + def _build_parameters( + self, + model: str, + stream: bool, + messages: list[dict], + parameters: dict[str, Any], + tools: Optional[list[PromptMessageTool]] = None, + ) -> dict[str, Any]: + if model in self._model_mapping: + # the LargeLanguageModel._code_block_mode_wrapper() method removes response_format from parameters, + # so the rule is named res_format instead and is mapped back to response_format here + if parameters.get("res_format") == "json_object": + parameters["response_format"] = {"type": "json_object"} - return message - - def _handle_chat_stream_generate_response(self, response) -> Generator: - for line in response.iter_lines(): - if not line: - continue - line = line.decode('utf-8') - # remove the first `data: ` prefix - if line.startswith('data:'): - line = line[5:].strip() - try: - data = loads(line) - except Exception as e: - if line.strip() == '[DONE]': - return - choices = data.get('choices', []) - # save stop reason temporarily - stop_reason = '' - for choice in choices: - if choice.get('finish_reason'): - stop_reason = choice['finish_reason'] + if tools or parameters.get("with_search_enhance") is True: + parameters["tools"] = [] - if len(choice['delta']['content']) == 0: - continue - yield BaichuanMessage(**choice['delta']) - - # if there is usage, the response is the last one, yield it and return - if 'usage' in data: - message = BaichuanMessage(content='', role='assistant') - message.usage = { - 'prompt_tokens': data['usage']['prompt_tokens'], - 'completion_tokens': data['usage']['completion_tokens'], - 'total_tokens': data['usage']['total_tokens'], - } - message.stop_reason = stop_reason - yield message - - def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any]) \ - -> dict[str, Any]: - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - prompt_messages = [] - for message in messages: - if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value: - # check if the latest message is a user message - if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value: - prompt_messages[-1]['content'] += message.content - else: - prompt_messages.append({ - 'content': message.content, - 'role': BaichuanMessage.Role.USER.value, - }) - elif message.role == BaichuanMessage.Role.ASSISTANT.value: - prompt_messages.append({ - 'content': message.content, - 'role': message.role, - }) - # [baichuan] frequency_penalty must be between 1 and 2 - if 'frequency_penalty' in parameters: - if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2: - parameters['frequency_penalty'] = 1 + # with_search_enhance is deprecated, use web_search instead + if parameters.get("with_search_enhance") is True: + parameters["tools"].append( + { + "type": "web_search", + "web_search": {"enable": True}, + } + ) + if tools: + for tool in tools: + parameters["tools"].append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) # turbo api accepts flat parameters return { - 'model': self._model_mapping(model), - 'stream': stream, - 'messages': prompt_messages, + "model": self._model_mapping.get(model), + "stream": stream, + "messages": messages, **parameters, } else: raise
BadRequestError(f"Unknown model: {model}") - - def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - # there is no secret key for turbo api - return { - 'Content-Type': 'application/json', - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ', - 'Authorization': 'Bearer ' + self.api_key, - } - else: - raise BadRequestError(f"Unknown model: {model}") - - def _calculate_md5(self, input_string): - return md5(input_string.encode('utf-8')).hexdigest() - def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any], timeout: int) \ - -> Union[Generator, BaichuanMessage]: - - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - api_base = 'https://api.baichuan-ai.com/v1/chat/completions' + def generate( + self, + model: str, + stream: bool, + messages: list[dict], + parameters: dict[str, Any], + timeout: int, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[Iterator, dict]: + if model in self._model_mapping: + api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") - - try: - data = self._build_parameters(model, stream, messages, parameters) - headers = self._build_headers(model, data) - except KeyError: - raise InternalServerError(f"Failed to build parameters for model: {model}") + + data = self._build_parameters(model, stream, messages, parameters, tools) try: response = post( url=api_base, - headers=headers, - data=dumps(data), + headers=self.request_headers, + data=json.dumps(data), timeout=timeout, - stream=stream + stream=stream, ) except Exception as e: raise InternalServerError(f"Failed to invoke model: {e}") - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["type"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': - raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': + elif err == "insufficient_quota": + raise InsufficientAccountBalanceError(msg) + elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) - elif 'rate' in err: + elif err == "invalid_request_error": + raise BadRequestError(msg) + elif "rate" in err: raise RateLimitReachedError(msg) - elif 'internal' in err: + elif "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + if stream: - return self._handle_chat_stream_generate_response(response) + return response.iter_lines() else: - return self._handle_chat_generate_response(response) \ No newline at end of file + return response.json() diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 67d76b4a29..309b5cf413
100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): + +class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index edcd3af420..91a14bf100 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,7 +1,12 @@ -from collections.abc import Generator +import json +from collections.abc import Generator, Iterator from typing import cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -21,10 +26,10 @@ from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -32,20 +37,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor ) -class BaichuanLarguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) +class BaichuanLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stream=stream, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return 
self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): @@ -59,10 +84,10 @@ class BaichuanLarguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -84,20 +109,14 @@ class BaichuanLarguageModel(LargeLanguageModel): elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {"role": "user", "content": message.content} + message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): - # copy from core/model_runtime/model_providers/anthropic/llm/llm.py message = cast(ToolPromptMessage, message) - message_dict = { - "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] - } + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} else: raise ValueError(f"Unknown message type {type(message)}") @@ -105,102 +124,152 @@ class BaichuanLarguageModel(LargeLanguageModel): def validate_credentials(self, model: str, credentials: dict) -> None: # ping - instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') - ) + instance = BaichuanModel(api_key=credentials["api_key"]) try: - instance.generate(model=model, stream=False, messages=[ - BaichuanMessage(content='ping', role='user') - ], parameters={ - 'max_tokens': 1, - }, timeout=60) + instance.generate( + model=model, + stream=False, + messages=[{"content": "ping", "role": "user"}], + parameters={ + "max_tokens": 1, + }, + timeout=60, + ) except Exception as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - if tools is not None and len(tools) > 0: - raise InvokeBadRequestError("Baichuan model doesn't support tools") - - instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') - ) - - # convert prompt messages to baichuan messages - messages = [ - BaichuanMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages - ] + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stream: bool = True, + ) -> LLMResult | Generator: + instance = BaichuanModel(api_key=credentials["api_key"]) + messages = [self._convert_prompt_message_to_dict(m) for 
m in prompt_messages] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, - timeout=60) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) return self._handle_chat_generate_response(model, prompt_messages, credentials, response) - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: BaichuanMessage) -> LLMResult: - # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens']) + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: dict, + ) -> LLMResult: + choices = response.get("choices", []) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) + if choices and choices[0]["finish_reason"] == "tool_calls": + for choice in choices: + for tool_call in choice["message"]["tool_calls"]: + tool = AssistantPromptMessage.ToolCall( + id=tool_call.get("id", ""), + type=tool_call.get("type", ""), + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call.get("function", {}).get("name", ""), + arguments=tool_call.get("function", {}).get("arguments", ""), + ), + ) + assistant_message.tool_calls.append(tool) + else: + for choice in choices: + assistant_message.content += choice["message"]["content"] + assistant_message.role = choice["message"]["role"] + + usage = response.get("usage") + if usage: + # transform usage + prompt_tokens = usage["prompt_tokens"] + completion_tokens = usage["completion_tokens"] + else: + # calculate num tokens + prompt_tokens = self._num_tokens_from_messages(prompt_messages) + completion_tokens = self._num_tokens_from_messages([assistant_message]) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=assistant_message, usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[BaichuanMessage, None, None]) -> Generator: - for message in response: - if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens']) + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Iterator, + ) -> Generator: + for line in response: + if not line: + continue + line = line.decode("utf-8") + # remove the first `data: ` prefix + if line.startswith("data:"): + line = line[5:].strip() + try: + data = json.loads(line) + except json.JSONDecodeError: + if line.strip() == "[DONE]": + return + # skip lines that are neither valid JSON nor the [DONE] marker, so `data` is never used unbound + continue + choices = data.get("choices", []) + + stop_reason = "" + for choice in choices: + if choice.get("finish_reason"): + stop_reason = choice["finish_reason"] + + if len(choice["delta"]["content"]) == 0: + continue yield LLMResultChunk(
model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]), + finish_reason=stop_reason or None, ), ) - else: + + # if there is usage, the response is the last one, yield it and return + if "usage" in data: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=data["usage"]["prompt_tokens"], + completion_tokens=data["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content="", tool_calls=[]), + usage=usage, + finish_reason=stop_reason or None, ), ) @@ -215,21 +284,13 @@ class BaichuanLarguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 5ae90d54b5..779dfbb608 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -31,11 +31,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for BaiChuan text embedding model.
""" - api_base: str = 'http://api.baichuan-ai.com/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "http://api.baichuan-ai.com/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -45,28 +46,23 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] - if model != 'baichuan-text-embedding': - raise ValueError('Invalid model name') + api_key = credentials["api_key"] + if model != "baichuan-text-embedding": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - + raise CredentialsValidateFailedError("api_key is required") + # split into chunks of batch size 16 chunks = [] for i in range(0, len(texts), 16): - chunks.append(texts[i:i + 16]) + chunks.append(texts[i : i + 16]) embeddings = [] token_usage = 0 for chunk in chunks: - # embeding chunk - chunk_embeddings, chunk_usage = self.embedding( - model=model, - api_key=api_key, - texts=chunk, - user=user - ) + # embedding chunk + chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user) embeddings.extend(chunk_embeddings) token_usage += chunk_usage @@ -74,17 +70,14 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - - def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> tuple[list[list[float]], int]: + + def embedding( + self, model: str, api_key, texts: list[str], user: Optional[str] = None + ) -> tuple[list[list[float]], int]: """ Embed given texts @@ -95,56 +88,47 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'Baichuan-Text-Embedding', - 'input': texts - } + data = {"model": "Baichuan-Text-Embedding", "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': - raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) - elif err and 'rate' in err: + elif err == "insufficient_quota": + raise InsufficientAccountBalanceError(msg) + elif err == "invalid_authentication": + raise InvalidAuthenticationError(msg) + elif err and "rate" in err: raise RateLimitReachedError(msg) - elif err and 'internal' in err: + elif err and "internal" in err: 
raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - return [ - data['embedding'] for data in embeddings - ], usage['total_tokens'] - + return [data["embedding"] for data in embeddings], usage["total_tokens"] def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -170,32 +154,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -207,10 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -221,7 +194,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py index e99bc52ff8..1cfc1d199c 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class BedrockProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,13 +20,10 @@ class BedrockProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `amazon.titan-text-lite-v1` model by default for validating credentials - model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1') - model_instance.validate_credentials( - model=model_for_validation, - credentials=credentials - ) + model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1") + 
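
The `if/elif` chain above maps Baichuan's error codes onto the provider's exception hierarchy. An equivalent table-driven sketch, using the same classes this file imports from `baichuan_turbo_errors` (the dict form itself is mine, not the patch's code):

from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
    InsufficientAccountBalanceError,
    InternalServerError,
    InvalidAPIKeyError,
    InvalidAuthenticationError,
    RateLimitReachedError,
)

EXACT_CODES = {
    "invalid_api_key": InvalidAPIKeyError,
    "api_key_empty": InvalidAPIKeyError,
    "insufficient_quota": InsufficientAccountBalanceError,
    "invalid_authentication": InvalidAuthenticationError,
}

def raise_for(err: str, msg: str):
    # exact codes first, then the substring buckets, then a catch-all
    if err in EXACT_CODES:
        raise EXACT_CODES[err](msg)
    if err and "rate" in err:
        raise RateLimitReachedError(msg)
    if err and "internal" in err:
        raise InternalServerError(msg)
    raise InternalServerError(f"Unknown error: {err} with message: {msg}")
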
model_instance.validate_credentials(model=model_for_validation, credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml index 53657c08a9..c2d5eb6471 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.00025' output: '0.00125' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml index d083d31e30..f90fa04266 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.015' output: '0.075' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml index 5302231086..dad0d6b6b6 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml index 6995d2bf56..962def8011 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+    - name: response_format
+      use_template: response_format
 pricing:
   input: '0.003'
   output: '0.015'
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml
index 1a3239c85e..70294e4ad3 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml
@@ -45,6 +45,8 @@ parameter_rules:
     help:
       zh_Hans: 对于每个后续标记，仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
       en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+    - name: response_format
+      use_template: response_format
 pricing:
   input: '0.008'
   output: '0.024'
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml
index 0343e3bbec..0a8ea61b6d 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml
@@ -45,6 +45,8 @@ parameter_rules:
     help:
       zh_Hans: 对于每个后续标记，仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
       en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+    - name: response_format
+      use_template: response_format
 pricing:
   input: '0.008'
   output: '0.024'
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 3f7266f600..e07f2a419a 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -20,6 +20,7 @@ from botocore.exceptions import (
 from PIL.Image import Image

 # local import
+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -44,37 +45,85 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

 logger = logging.getLogger(__name__)
+ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+

 class BedrockLargeLanguageModel(LargeLanguageModel):
-    # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
     # TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
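
ANTHROPIC_BLOCK_MODE_PROMPT feeds the `_code_block_mode_wrapper` added further down: the wrapper swaps the template into the system prompt, pre-fills the assistant turn with an opening fence, and registers closing fences as stop sequences so the model is boxed into emitting one structured block. A condensed sketch of that flow (the function name and shortened template are illustrative, not the exact Dify call chain):

BLOCK_PROMPT = "You should always follow the instructions and output a valid {{block}} object.\n{{instructions}}\n"

def wrap_block_mode(instructions: str, response_format: str) -> tuple[str, str, list[str]]:
    system = BLOCK_PROMPT.replace("{{instructions}}", instructions).replace("{{block}}", response_format)
    prefill = f"\n```{response_format}"      # assistant message that opens the fence
    stops = ["```\n", "\n```"]               # generation halts when the fence closes
    return system, prefill, stops

system, prefill, stops = wrap_block_mode("Return the answer as JSON.", "json")
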
- CONVERSE_API_ENABLED_MODEL_INFO=[ - {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} + CONVERSE_API_ENABLED_MODEL_INFO = [ + {"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False}, ] @staticmethod def _find_model_info(model_id): for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO: - if model_id.startswith(model['prefix']): + if model_id.startswith(model["prefix"]): return model logger.info(f"current model id: {model_id} did not support by Converse API") return None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if model_parameters.get("response_format"): + stop = stop or [] + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + response_format = model_parameters.pop("response_format") + format_prompt = SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) + ) + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = format_prompt + else: + 
prompt_messages.insert(0, format_prompt) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -88,17 +137,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - - model_info= BedrockLargeLanguageModel._find_model_info(model) + + model_info = BedrockLargeLanguageModel._find_model_info(model) if model_info: - model_info['model'] = model + model_info["model"] = model # invoke models via boto3 converse API - return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) + return self._generate_with_converse( + model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools + ) # invoke other models via boto3 client return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: + def _generate_with_converse( + self, + model_info: dict, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model with converse API @@ -110,35 +170,39 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client(service_name='bedrock-runtime', - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"]) + bedrock_client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=credentials.get("aws_access_key_id"), + aws_secret_access_key=credentials.get("aws_secret_access_key"), + region_name=credentials["aws_region"], + ) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) parameters = { - 'modelId': model_info['model'], - 'messages': prompt_message_dicts, - 'inferenceConfig': inference_config, - 'additionalModelRequestFields': additional_model_fields, + "modelId": model_info["model"], + "messages": prompt_message_dicts, + "inferenceConfig": inference_config, + "additionalModelRequestFields": additional_model_fields, } - if model_info['support_system_prompts'] and system and len(system) > 0: - parameters['system'] = system + if model_info["support_system_prompts"] and system and len(system) > 0: + parameters["system"] = system - if model_info['support_tool_use'] and tools: - parameters['toolConfig'] = 
self._convert_converse_tool_config(tools=tools) + if model_info["support_tool_use"] and tools: + parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools) try: if stream: response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_stream_response( + model_info["model"], credentials, response, prompt_messages + ) else: response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: @@ -149,8 +213,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - def _handle_converse_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_converse_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -160,36 +226,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: full response chunk generator result """ - response_content = response['output']['message']['content'] + response_content = response["output"]["message"]["content"] # transform assistant message to prompt message - if response['stopReason'] == 'tool_use': + if response["stopReason"] == "tool_use": tool_calls = [] text, tool_use = self._extract_tool_use(response_content) tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=json.dumps(tool_use['input']) - ) + name=tool_use["name"], arguments=json.dumps(tool_use["input"]) + ), ) tool_calls.append(tool_call) - assistant_prompt_message = AssistantPromptMessage( - content=text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls) else: - assistant_prompt_message = AssistantPromptMessage( - content=response_content[0]['text'] - ) + assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"]) # calculate num tokens - if response['usage']: + if response["usage"]: # transform usage - prompt_tokens = response['usage']['inputTokens'] - completion_tokens = response['usage']['outputTokens'] + prompt_tokens = response["usage"]["inputTokens"] + completion_tokens = response["usage"]["outputTokens"] else: # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -206,20 +266,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ) return result - def _extract_tool_use(self, content:dict)-> tuple[str, dict]: + def _extract_tool_use(self, content: dict) -> tuple[str, dict]: tool_use = {} - text = '' + text = "" for item in content: - if 'toolUse' in item: - tool_use = item['toolUse'] - elif 'text' in item: - text = item['text'] + if "toolUse" in item: + tool_use = 
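
`_generate_with_converse` above is a thin assembly layer over boto3's Converse API; stripped of Dify's abstractions, the round trip looks roughly like this (region and model ID are placeholders, and valid AWS credentials in the environment are assumed):

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")
response = client.converse(
    modelId="anthropic.claude-3-haiku-20240307-v1:0",   # placeholder model ID
    messages=[{"role": "user", "content": [{"text": "ping"}]}],
    inferenceConfig={"maxTokens": 32, "temperature": 0.1, "topP": 0.9},
)
print(response["output"]["message"]["content"][0]["text"])
print(response["usage"])   # inputTokens / outputTokens, as read by the handlers above
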
item["toolUse"] + elif "text" in item: + text = item["text"] else: raise ValueError(f"Got unknown item: {item}") return text, tool_use - def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_converse_stream_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -231,7 +296,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -240,87 +305,85 @@ class BedrockLargeLanguageModel(LargeLanguageModel): tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_use = {} - for chunk in response['stream']: - if 'messageStart' in chunk: + for chunk in response["stream"]: + if "messageStart" in chunk: return_model = model - elif 'messageStop' in chunk: - finish_reason = chunk['messageStop']['stopReason'] - elif 'contentBlockStart' in chunk: - tool = chunk['contentBlockStart']['start']['toolUse'] - tool_use['toolUseId'] = tool['toolUseId'] - tool_use['name'] = tool['name'] - elif 'metadata' in chunk: - input_tokens = chunk['metadata']['usage']['inputTokens'] - output_tokens = chunk['metadata']['usage']['outputTokens'] + elif "messageStop" in chunk: + finish_reason = chunk["messageStop"]["stopReason"] + elif "contentBlockStart" in chunk: + tool = chunk["contentBlockStart"]["start"]["toolUse"] + tool_use["toolUseId"] = tool["toolUseId"] + tool_use["name"] = tool["name"] + elif "metadata" in chunk: + input_tokens = chunk["metadata"]["usage"]["inputTokens"] + output_tokens = chunk["metadata"]["usage"]["outputTokens"] usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - elif 'contentBlockDelta' in chunk: - delta = chunk['contentBlockDelta']['delta'] - if 'text' in delta: - chunk_text = delta['text'] if delta['text'] else '' + elif "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + chunk_text = delta["text"] if delta["text"] else "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text if chunk_text else "", ) - index = chunk['contentBlockDelta']['contentBlockIndex'] + index = chunk["contentBlockDelta"]["contentBlockIndex"] yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index+1, + index=index + 1, message=assistant_prompt_message, - ) + ), ) - elif 'toolUse' in delta: - if 'input' not in tool_use: - tool_use['input'] = '' - tool_use['input'] += delta['toolUse']['input'] - elif 'contentBlockStop' in chunk: - if 'input' in tool_use: + elif "toolUse" in delta: + if "input" not in tool_use: + tool_use["input"] = "" + tool_use["input"] += delta["toolUse"]["input"] + elif "contentBlockStop" in chunk: + if "input" in tool_use: tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", 
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                name=tool_use['name'],
-                                arguments=tool_use['input']
-                            )
+                                name=tool_use["name"], arguments=tool_use["input"]
+                            ),
                         )
                         tool_calls.append(tool_call)
                         tool_use = {}
         except Exception as ex:
             raise InvokeError(str(ex))
-
-    def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
+
+    def _convert_converse_api_model_parameters(
+        self, model_parameters: dict, stop: Optional[list[str]] = None
+    ) -> tuple[dict, dict]:
         inference_config = {}
         additional_model_fields = {}
-        if 'max_tokens' in model_parameters:
-            inference_config['maxTokens'] = model_parameters['max_tokens']
+        if "max_tokens" in model_parameters:
+            inference_config["maxTokens"] = model_parameters["max_tokens"]

-        if 'temperature' in model_parameters:
-            inference_config['temperature'] = model_parameters['temperature']
-
-        if 'top_p' in model_parameters:
-            inference_config['topP'] = model_parameters['temperature']
+        if "temperature" in model_parameters:
+            inference_config["temperature"] = model_parameters["temperature"]
+
+        if "top_p" in model_parameters:
+            inference_config["topP"] = model_parameters["top_p"]

         if stop:
-            inference_config['stopSequences'] = stop
-
-        if 'top_k' in model_parameters:
-            additional_model_fields['top_k'] = model_parameters['top_k']
-
+            inference_config["stopSequences"] = stop
+
+        if "top_k" in model_parameters:
+            additional_model_fields["top_k"] = model_parameters["top_k"]
+
         return inference_config, additional_model_fields

     def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
@@ -332,7 +395,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
-                message.content=message.content.strip()
+                message.content = message.content.strip()
                 system.append({"text": message.content})
             else:
                 prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
@@ -349,15 +412,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                     "toolSpec": {
                         "name": tool.name,
                         "description": tool.description,
-                        "inputSchema": {
-                            "json": tool.parameters
-                        }
+                        "inputSchema": {"json": tool.parameters},
                     }
                 }
             )
         tool_config["tools"] = configs
         return tool_config
-
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": [{'text': message.content}]}
+                message_dict = {"role": "user", "content": [{"text": message.content}]}
             else:
                 sub_messages = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
-                        sub_message_dict = {
-                            "text": message_content.data
-                        }
+                        sub_message_dict = {"text": message_content.data}
                         sub_messages.append(sub_message_dict)
                     elif message_content.type == PromptMessageContentType.IMAGE:
                         message_content = cast(ImagePromptMessageContent, message_content)
@@ -384,7 +443,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             image_content = requests.get(url).content
                             with Image.open(io.BytesIO(image_content)) as img:
                                 mime_type = f"image/{img.format.lower()}"
-                            base64_data = base64.b64encode(image_content).decode('utf-8')
+                            base64_data = 
base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -394,16 +453,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): image_content = base64.b64decode(base64_data) if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { - "image": { - "format": mime_type.replace('image/', ''), - "source": { - "bytes": image_content - } - } + "image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}} } sub_messages.append(sub_message_dict) @@ -412,36 +468,46 @@ class BedrockLargeLanguageModel(LargeLanguageModel): message = cast(AssistantPromptMessage, message) if message.tool_calls: message_dict = { - "role": "assistant", "content":[{ - "toolUse": { - "toolUseId": message.tool_calls[0].id, - "name": message.tool_calls[0].function.name, - "input": json.loads(message.tool_calls[0].function.arguments) + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": json.loads(message.tool_calls[0].function.arguments), + } } - }] + ], } else: - message_dict = {"role": "assistant", "content": [{'text': message.content}]} + message_dict = {"role": "assistant", "content": [{"text": message.content}]} elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = [{'text': message.content}] + message_dict = [{"text": message.content}] elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "toolResult": { - "toolUseId": message.tool_call_id, - "content": [{"json": {"text": message.content}}] - } - }] + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"json": {"text": message.content}}], + } + } + ], } else: raise ValueError(f"Got unknown type {message}") return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -451,15 +517,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return:md = genai.GenerativeModel(model) """ - prefix = model.split('.')[0] - model_name = model.split('.')[1] - + prefix = model.split(".")[0] + model_name = model.split(".")[1] + if isinstance(prompt_messages, str): prompt = prompt_messages else: prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name) - return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -482,24 +547,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel): "topP": 0.9, "maxTokens": 32, } - + try: ping_message = UserPromptMessage(content="ping") - self._invoke(model=model, - credentials=credentials, - prompt_messages=[ping_message], - model_parameters=required_params, - stream=False) - + self._invoke( + 
model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters=required_params, + stream=False, + ) + except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_one_message_to_text( + self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Convert a single message to a string. @@ -514,7 +583,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): body = content - if (isinstance(content, list)): + if isinstance(content, list): body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT]) message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}" elif isinstance(message, AssistantPromptMessage): @@ -528,7 +597,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -537,27 +608,31 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. 
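
`_convert_messages_to_prompt` (continuing below) flattens chat history into the tagged text format that the pre-Converse Anthropic, Amazon, and Llama endpoints expect, appending a trailing assistant tag so the model knows to answer. Roughly this shape, with the actual prefixes varying per provider (`\n\nHuman:` / `\n\nAssistant:` shown as the Anthropic-style example):

def to_text_prompt(turns: list[tuple[str, str]]) -> str:
    # turns: (role, text) pairs; a trailing assistant tag is appended so the
    # model completes the final, empty assistant turn
    tagged = {"user": "\n\nHuman:", "assistant": "\n\nAssistant:"}
    parts = [f"{tagged[role]} {text}" for role, text in turns]
    parts.append("\n\nAssistant:")
    return "".join(parts).rstrip()

print(to_text_prompt([("user", "ping")]))
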
""" if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message, model_prefix, model_name) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + def _create_payload( + self, + model: str, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} - model_prefix = model.split('.')[0] - model_name = model.split('.')[1] + model_prefix = model.split(".")[0] + model_name = model.split(".")[1] if model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") @@ -571,21 +646,27 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} - + elif model_prefix == "cohere": - payload = { **model_parameters } + payload = {**model_parameters} payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - + else: raise ValueError(f"Got unknown model prefix {model_prefix}") - + return payload - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -598,18 +679,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) runtime_client = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream) # need workaround for ai21 models which doesn't support streaming @@ -619,18 +698,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): invoke = runtime_client.invoke_model try: - body_jsonstr=json.dumps(payload) - response = invoke( - modelId=model, - contentType="application/json", - accept= "*/*", - body=body_jsonstr - ) + body_jsonstr = json.dumps(payload) + response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = 
ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) @@ -639,15 +713,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -657,7 +731,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) finish_reason = response_body.get("error") @@ -665,25 +739,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel): raise InvokeError(finish_reason) # get output text and calculate num tokens based on model / provider - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - output = response_body.get('completions')[0].get('data').get('text') + output = response_body.get("completions")[0].get("data").get("text") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) + elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens(model, credentials, output if output else "") + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") # construct assistant message from output - assistant_prompt_message = AssistantPromptMessage( - content=output - ) + assistant_prompt_message = AssistantPromptMessage(content=output) # calculate usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -698,8 +770,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -709,65 +782,59 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) - content = response_body.get('completions')[0].get('data').get('text') - finish_reason = 
response_body.get('completions')[0].get('finish_reason')
+            content = response_body.get("completions")[0].get("data").get("text")
+            finish_reason = response_body.get("completions")[0].get("finish_reason")
             prompt_tokens = len(response_body.get("prompt").get("tokens"))
-            completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
+            completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))

             usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

             yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=AssistantPromptMessage(content=content),
-                        finish_reason=finish_reason,
-                        usage=usage
-                    )
-                )
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage
+                ),
+            )
             return
-
-        stream = response.get('body')
+
+        stream = response.get("body")
         if not stream:
-            raise InvokeError('No response body')
-
+            raise InvokeError("No response body")
+
         index = -1
         for event in stream:
-            chunk = event.get('chunk')
-
+            chunk = event.get("chunk")
+
             if not chunk:
                 exception_name = next(iter(event))
                 full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
                 raise self._map_client_to_invoke_error(exception_name, full_ex_msg)

-            payload = json.loads(chunk.get('bytes').decode())
+            payload = json.loads(chunk.get("bytes").decode())

-            model_prefix = model.split('.')[0]
+            model_prefix = model.split(".")[0]
             if model_prefix == "cohere":
                 content_delta = payload.get("text")
                 finish_reason = payload.get("finish_reason")
-
+
             else:
                 raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content = content_delta if content_delta else '',
+                content=content_delta if content_delta else "",
             )
             index += 1
-
+
             if not finish_reason:
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
+                    delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
                 )

             else:
@@ -777,36 +844,33 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
+
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=finish_reason,
-                        usage=usage
-                    )
+                        index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage
+                    ),
                 )
-
+
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
         Map model invoke error to unified error
-        The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
-        The value is the md = genai.GenerativeModel(model)error type thrown by the model,
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
         which needs to be converted into a unified error type for the caller.
-        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
+        :return: Invoke error mapping
         """
         return {
             InvokeConnectionError: [],
             InvokeServerUnavailableError: [],
             InvokeRateLimitError: [],
             InvokeAuthorizationError: [],
-            InvokeBadRequestError: []
+            InvokeBadRequestError: [],
         }
-
+
     def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
         """
         Map client error to invoke error
@@ -822,7 +886,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             return InvokeBadRequestError(error_msg)
         elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
             return InvokeRateLimitError(error_msg)
-        elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
+        elif error_code in [
+            "ModelTimeoutException",
+            "ModelErrorException",
+            "InternalServerException",
+            "ModelNotReadyException",
+        ]:
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)
diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
index 993416cdc8..2d898e3aaa 100644
--- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
@@ -27,12 +27,11 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE

 logger = logging.getLogger(__name__)

+
 class BedrockTextEmbeddingModel(TextEmbeddingModel):
-
-
-    def _invoke(self, model: str, credentials: dict,
-                texts: list[str], user: Optional[str] = None) \
-            -> TextEmbeddingResult:
+    def _invoke(
+        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+    ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

@@ -42,67 +41,56 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         :param user: unique user id
         :return: embeddings result
         """
-        client_config = Config(
-            region_name=credentials["aws_region"]
-        )
+        client_config = Config(region_name=credentials["aws_region"])

         bedrock_runtime = boto3.client(
-            service_name='bedrock-runtime',
+            service_name="bedrock-runtime",
             config=client_config,
             aws_access_key_id=credentials.get("aws_access_key_id"),
-            aws_secret_access_key=credentials.get("aws_secret_access_key")
+            aws_secret_access_key=credentials.get("aws_secret_access_key"),
         )

         embeddings = []
         token_usage = 0
-
-        model_prefix = model.split('.')[0]
-
-        if model_prefix == "amazon" :
+
+        model_prefix = model.split(".")[0]
+
+        if model_prefix == "amazon":
             for text in texts:
                 body = {
-                    "inputText": text,
+                    "inputText": text,
                 }
                 response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-                embeddings.extend([response_body.get('embedding')])
-                token_usage += response_body.get('inputTextTokenCount')
-            logger.warning(f'Total Tokens: {token_usage}')
+                embeddings.extend([response_body.get("embedding")])
+                token_usage += response_body.get("inputTextTokenCount")
+            logger.warning(f"Total Tokens: {token_usage}")
             result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
-                usage=self._calc_response_usage(
-                    model=model,
-                    credentials=credentials,
-                    tokens=token_usage
-                )
+                usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
             )
             return result

-        if model_prefix == "cohere" :
-            input_type = 'search_document' if len(texts) > 1 else 'search_query'
+        if model_prefix == "cohere":
+            input_type = "search_document" if len(texts) > 1 else "search_query"
             for text in texts:
                 body = {
-                    "texts": [text],
-                    "input_type": input_type,
+                    "texts": [text],
+                    "input_type": input_type,
                 }
                 response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-                embeddings.extend(response_body.get('embeddings'))
+                embeddings.extend(response_body.get("embeddings"))
                 token_usage += len(text)
             result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
-                usage=self._calc_response_usage(
-                    model=model,
-                    credentials=credentials,
-                    tokens=token_usage
-                )
+                usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
             )
             return result

-        #others
+        # others
         raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
-
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """
         Get number of tokens for given prompt messages

@@ -125,35 +113,41 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         :param credentials: model credentials
         :return:
         """
-
+
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
         Map model invoke error to unified error
-        The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
-        The value is the md = genai.GenerativeModel(model)error type thrown by the model,
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
         which needs to be converted into a unified error type for the caller.
-        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
+        :return: Invoke error mapping
         """
         return {
             InvokeConnectionError: [],
             InvokeServerUnavailableError: [],
             InvokeRateLimitError: [],
             InvokeAuthorizationError: [],
-            InvokeBadRequestError: []
+            InvokeBadRequestError: [],
         }
-
-    def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+
+    def _create_payload(
+        self,
+        model_prefix: str,
+        texts: list[str],
+        model_parameters: dict,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+    ):
         """
         Create payload for bedrock api call depending on model provider
         """
         payload = {}

         if model_prefix == "amazon":
-            payload['inputText'] = texts
-
+            payload["inputText"] = texts
+            return payload
+
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
         Calculate response usage
@@ -165,10 +159,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         """
         # get input price info
         input_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.INPUT,
-            tokens=tokens
+            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
         )

         # transform usage
@@ -179,7 +170,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
             price_unit=input_price_info.unit,
             total_price=input_price_info.total_amount,
             currency=input_price_info.currency,
-            latency=time.perf_counter() - self.started_at
+            latency=time.perf_counter() - self.started_at,
         )

         return usage

@@ -199,31 +190,37 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
             return InvokeBadRequestError(error_msg)
         elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
             return InvokeRateLimitError(error_msg)
-        elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
+        elif 
error_code in [ + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + ]: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) return InvokeError(error_msg) - - def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): - accept = 'application/json' - content_type = 'application/json' + def _invoke_bedrock_embedding( + self, + model: str, + bedrock_runtime, + body: dict, + ): + accept = "application/json" + content_type = "application/json" try: response = bedrock_runtime.invoke_model( - body=json.dumps(body), - modelId=model, - accept=accept, - contentType=content_type + body=json.dumps(body), modelId=model, accept=accept, contentType=content_type ) - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) return response_body except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) diff --git a/api/core/model_runtime/model_providers/chatglm/chatglm.py b/api/core/model_runtime/model_providers/chatglm/chatglm.py index e9dd5794f3..71d9a15322 100644 --- a/api/core/model_runtime/model_providers/chatglm/chatglm.py +++ b/api/core/model_runtime/model_providers/chatglm/chatglm.py @@ -20,12 +20,9 @@ class ChatGLMProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `chatglm3-6b` model for validate, - model_instance.validate_credentials( - model='chatglm3-6b', - credentials=credentials - ) + model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index e83d08af71..114acc1ec3 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -43,12 +43,19 @@ from core.model_runtime.utils import helper logger = logging.getLogger(__name__) + class ChatGLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -71,11 +78,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] 
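
The embedding `_invoke` above has two request shapes: Amazon Titan takes one `inputText` per call, while Cohere takes a `texts` list plus an `input_type` that the code infers from batch size (several texts are treated as corpus documents, a single text as a query). Side by side, a sketch of the request bodies only:

def titan_body(text: str) -> dict:
    return {"inputText": text}

def cohere_body(texts: list[str]) -> dict:
    # one text is assumed to be a search query; more are corpus documents
    input_type = "search_document" if len(texts) > 1 else "search_query"
    return {"texts": texts, "input_type": input_type}

assert cohere_body(["q"])["input_type"] == "search_query"
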
| None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -96,11 +108,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content="ping"), - ], model_parameters={ - "max_tokens": 16, - }) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ + UserPromptMessage(content="ping"), + ], + model_parameters={ + "max_tokens": 16, + }, + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @@ -124,24 +141,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -163,35 +180,31 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None: if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0: raise InvokeBadRequestError("ChatGLM2 does not support function calling") @@ -212,7 +225,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": 
message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -223,12 +236,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -239,19 +252,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -265,17 +273,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['api_base']) / 'v1') + "base_url": str(URL(credentials["api_base"]) / "v1"), } return client_kwargs - - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> Generator: - - full_response = '' + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -283,9 +294,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: @@ -295,23 +306,25 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, 
credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -320,7 +333,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -335,11 +348,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) full_response += delta.delta.content - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -359,15 +376,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -378,7 +394,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) return response - + def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -395,17 +411,19 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer, As a temporary solution we use GPT2 tokenizer instead. 
""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) - + tokens_per_message = 3 tokens_per_name = 1 num_tokens = 0 @@ -414,10 +432,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "function_call": @@ -452,36 +470,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) diff --git a/api/core/model_runtime/model_providers/cohere/cohere.py b/api/core/model_runtime/model_providers/cohere/cohere.py index cfbcb94d26..8394a45fcf 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.py +++ b/api/core/model_runtime/model_providers/cohere/cohere.py @@ -20,12 +20,9 @@ class CohereProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.RERANK) # Use `rerank-english-v2.0` model for validate, - model_instance.validate_credentials( - model='rerank-english-v2.0', - credentials=credentials - ) + model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 89b04c0279..203ca9c4a0 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): Model class for Cohere large language model. 
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -85,7 +91,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: return self._generate( @@ -95,11 +101,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -136,30 +147,37 @@ class CohereLargeLanguageModel(LargeLanguageModel): self._chat_generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) else: self._generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -173,17 +191,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['end_sequences'] = stop + model_parameters["end_sequences"] = stop if stream: response = client.generate_stream( prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_stream_response(model, credentials, response, prompt_messages) @@ -192,14 +210,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + 
request_options=RequestOptions(max_retries=0), ) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Generation, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -212,9 +230,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): assistant_text = response.generations[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens prompt_tokens = int(response.meta.billed_units.input_tokens) @@ -225,17 +241,18 @@ class CohereLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[GenerateStreamedResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[GenerateStreamedResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -245,7 +262,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, GenerateStreamedResponse_TextGeneration): chunk = cast(GenerateStreamedResponse_TextGeneration, chunk) @@ -255,9 +272,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -267,7 +282,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -277,9 +292,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) completion_tokens = self._num_tokens_from_messages( - model, - credentials, - [AssistantPromptMessage(content=full_assistant_content)] + model, credentials, [AssistantPromptMessage(content=full_assistant_content)] ) # transform usage @@ -290,20 +303,27 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=chunk.finish_reason, - usage=usage - ) + usage=usage, + ), ) break elif isinstance(chunk, GenerateStreamedResponse_StreamError): chunk = cast(GenerateStreamedResponse_StreamError, chunk) raise InvokeBadRequestError(chunk.err) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> 
Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -318,27 +338,28 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['stop_sequences'] = stop + model_parameters["stop_sequences"] = stop if tools: if len(tools) == 1: raise ValueError("Cohere tool call requires at least two tools to be specified.") - model_parameters['tools'] = self._convert_tools(tools) + model_parameters["tools"] = self._convert_tools(tools) - message, chat_histories, tool_results \ - = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories( + prompt_messages + ) if tool_results: - model_parameters['tool_results'] = tool_results + model_parameters["tool_results"] = tool_results # chat model real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") if stream: response = client.chat_stream( @@ -346,7 +367,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) @@ -356,14 +377,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_chat_generate_response( + self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -380,19 +401,15 @@ class CohereLargeLanguageModel(LargeLanguageModel): for cohere_tool_call in response.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls) # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) @@ -403,17 +420,18 @@ class 
CohereLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[StreamedChatResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[StreamedChatResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -423,17 +441,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ - def final_response(full_text: str, - tool_calls: list[AssistantPromptMessage.ToolCall], - index: int, - finish_reason: Optional[str] = None) -> LLMResultChunk: + def final_response( + full_text: str, + tool_calls: list[AssistantPromptMessage.ToolCall], + index: int, + finish_reason: Optional[str] = None, + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - full_assistant_prompt_message = AssistantPromptMessage( - content=full_text, - tool_calls=tool_calls - ) + full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls) completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) # transform usage @@ -444,14 +461,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content='', tool_calls=tool_calls), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) index = 1 - full_assistant_content = '' + full_assistant_content = "" tool_calls = [] for chunk in response: if isinstance(chunk, StreamedChatResponse_TextGeneration): @@ -462,9 +479,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -474,7 +489,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -484,11 +499,10 @@ class CohereLargeLanguageModel(LargeLanguageModel): for cohere_tool_call in chunk.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) elif isinstance(chunk, StreamedChatResponse_StreamEnd): @@ -496,8 +510,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason) index += 1 - def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: + def _convert_prompt_messages_to_message_and_chat_histories( + self, 
prompt_messages: list[PromptMessage] + ) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -510,13 +525,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_message = cast(AssistantPromptMessage, prompt_message) if prompt_message.tool_calls: for tool_call in prompt_message.tool_calls: - latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem( - call=ToolCall( - name=tool_call.function.name, - parameters=json.loads(tool_call.function.arguments) - ), - outputs=[] - )) + latest_tool_call_n_outputs.append( + ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments) + ), + outputs=[], + ) + ) else: cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) if cohere_prompt_message: @@ -529,12 +545,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): if tool_call_n_outputs.call.name == prompt_message.tool_call_id: latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem( call=ToolCall( - name=tool_call_n_outputs.call.name, - parameters=tool_call_n_outputs.call.parameters + name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters ), - outputs=[{ - "result": prompt_message.content - }] + outputs=[{"result": prompt_message.content}], ) break i += 1 @@ -556,7 +569,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): latest_message = chat_histories.pop() message = latest_message.message else: - raise ValueError('Prompt messages is empty') + raise ValueError("Prompt messages is empty") return message, chat_histories, latest_tool_call_n_outputs @@ -569,7 +582,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): chat_message = ChatMessage(role="USER", message=message.content) else: - sub_message_text = '' + sub_message_text = "" for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) @@ -597,8 +610,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): """ cohere_tools = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] parameter_definitions = {} for p_key, p_val in properties.items(): @@ -606,21 +619,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): if p_key in required_properties: required = True - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]" parameter_definitions[p_key] = ToolParameterDefinitionsValue( - description=desc, - type=p_val['type'], - required=required + description=desc, type=p_val["type"], required=required ) cohere_tool = Tool( - name=tool.name, - description=tool.description, - parameter_definitions=parameter_definitions + name=tool.name, description=tool.description, parameter_definitions=parameter_definitions ) cohere_tools.append(cohere_tool) @@ -637,12 +645,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), 
base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return len(response.tokens) @@ -658,30 +663,30 @@ class CohereLargeLanguageModel(LargeLanguageModel): real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") return self._num_tokens_from_string(real_model, credentials, message_str) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - Cohere supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + Cohere supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} - mode = credentials.get('mode') + mode = credentials.get("mode") - if mode == 'chat': - base_model_schema = model_map['command-light-chat'] + if mode == "chat": + base_model_schema = model_map["command-light-chat"] else: - base_model_schema = model_map['command-light'] + base_model_schema = model_map["command-light"] base_model_schema = cast(AIModelEntity, base_model_schema) @@ -691,16 +696,13 @@ class CohereLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity @@ -716,22 +718,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index d2fdb30c6f..aba8fedbc0 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -21,10 +21,16 @@ class 
CohereRerankModel(RerankModel): Model class for Cohere rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -38,20 +44,17 @@ class CohereRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) response = client.rerank( query=query, documents=docs, model=model, top_n=top_n, return_documents=True, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) rerank_documents = [] @@ -70,10 +73,7 @@ class CohereRerankModel(RerankModel): else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -94,7 +94,7 @@ class CohereRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -110,22 +110,16 @@ class CohereRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 0540fb740f..a1c5e98118 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -24,9 +24,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): Model class for Cohere text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -46,14 +46,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - tokenize_response = self._tokenize( - model=model, - credentials=credentials, - text=text - ) + tokenize_response = self._tokenize(model=model, credentials=credentials, text=text) for j in range(0, len(tokenize_response), context_size): - tokens += [tokenize_response[j: j + context_size]] + tokens += [tokenize_response[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -62,9 +58,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=["".join(token) for token in tokens[i: i + max_chunks]] + model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]] ) used_tokens += embedding_used_tokens @@ -80,9 +74,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=[" "] + model=model, credentials=credentials, texts=[" "] ) used_tokens += embedding_used_tokens @@ -92,17 +84,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -116,14 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): if len(texts) == 0: return 0 - full_text = ' '.join(texts) + full_text = " ".join(texts) try: - response = self._tokenize( - model=model, - credentials=credentials, - text=full_text - ) + response = self._tokenize(model=model, credentials=credentials, text=full_text) except Exception as e: raise self._transform_invoke_error(e) @@ -141,14 +121,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): return [] # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model, - offline=False, - request_options=RequestOptions(max_retries=0) - ) + response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0)) return response.token_strings @@ -162,11 +137,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -180,14 +151,14 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: embeddings and used tokens 
""" # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) # call embedding model response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query', - request_options=RequestOptions(max_retries=1) + input_type="search_document" if len(texts) > 1 else "search_query", + request_options=RequestOptions(max_retries=1), ) return response.embeddings, int(response.meta.billed_units.input_tokens) @@ -203,10 +174,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -217,7 +185,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -233,22 +201,16 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py index d61fd4ddc8..10feef8972 100644 --- a/api/core/model_runtime/model_providers/deepseek/deepseek.py +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py @@ -7,9 +7,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) - class DeepSeekProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,12 +20,9 @@ class DeepSeekProvider(ModelProvider): # Use `deepseek-chat` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='deepseek-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py 
b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index bdb3823b60..6d0a3ee262 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -13,12 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -27,10 +32,8 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -48,8 +51,9 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. 
        Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -69,10 +73,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
                # which need to download the image and then get the resolution for calculation,
                # and will increase the request delay
                if isinstance(value, list):
-                    text = ''
+                    text = ""
                    for item in value:
-                        if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]

                    value = text

@@ -103,11 +107,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
-        credentials['mode'] = 'chat'
-        credentials['openai_api_key']=credentials['api_key']
-        if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
-            credentials['openai_api_base']='https://api.deepseek.com'
+        credentials["mode"] = "chat"
+        credentials["openai_api_key"] = credentials["api_key"]
+        if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
+            credentials["openai_api_base"] = "https://api.deepseek.com"
        else:
-            parsed_url = urlparse(credentials['endpoint_url'])
-            credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"
-
+            parsed_url = urlparse(credentials["endpoint_url"])
+            credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
diff --git a/api/core/model_runtime/model_providers/fishaudio/__init__.py b/api/core/model_runtime/model_providers/fishaudio/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg
new file mode 100644
index 0000000000..d6f7723bd5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg
new file mode 100644
index 0000000000..d6f7723bd5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py
new file mode 100644
index 0000000000..3bc4b533e0
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py
@@ -0,0 +1,26 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class FishAudioProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        Validation is delegated to the TTS model's credential check.
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """ + try: + model_instance = self.get_model_instance(ModelType.TTS) + model_instance.validate_credentials(credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml b/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml new file mode 100644 index 0000000000..479eb7fb85 --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml @@ -0,0 +1,76 @@ +provider: fishaudio +label: + en_US: Fish Audio +description: + en_US: Models provided by Fish Audio, currently only support TTS. + zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。 +icon_small: + en_US: fishaudio_s_en.svg +icon_large: + en_US: fishaudio_l_en.svg +background: "#E5E7EB" +help: + title: + en_US: Get your API key from Fish Audio + zh_Hans: 从 Fish Audio 获取你的 API Key + url: + en_US: https://fish.audio/go-api/ +supported_model_types: + - tts +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: api_base + label: + en_US: API URL + type: text-input + required: false + default: https://api.fish.audio + placeholder: + en_US: Enter your API URL + zh_Hans: 在此输入您的 API URL + - variable: use_public_models + label: + en_US: Use Public Models + type: select + required: false + default: "false" + placeholder: + en_US: Toggle to use public models + zh_Hans: 切换以使用公共模型 + options: + - value: "true" + label: + en_US: Allow Public Models + zh_Hans: 使用公共模型 + - value: "false" + label: + en_US: Private Models Only + zh_Hans: 仅使用私有模型 + - variable: latency + label: + en_US: Latency + type: select + required: false + default: "normal" + placeholder: + en_US: Toggle to choice latency + zh_Hans: 切换以调整延迟 + options: + - value: "balanced" + label: + en_US: Low (may affect quality) + zh_Hans: 低延迟 (可能降低质量) + - value: "normal" + label: + en_US: Normal + zh_Hans: 标准 diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/__init__.py b/api/core/model_runtime/model_providers/fishaudio/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py new file mode 100644 index 0000000000..895a7a914c --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py @@ -0,0 +1,158 @@ +from typing import Optional + +import httpx + +from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel + + +class FishAudioText2SpeechModel(TTSModel): + """ + Model class for Fish.audio Text to Speech model. 
+ """ + + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + api_base = credentials.get("api_base", "https://api.fish.audio") + api_key = credentials.get("api_key") + use_public_models = credentials.get("use_public_models", "false") == "true" + + params = { + "self": str(not use_public_models).lower(), + "page_size": "100", + } + + if language is not None: + if "-" in language: + language = language.split("-")[0] + params["language"] = language + + results = httpx.get( + f"{api_base}/model", + headers={"Authorization": f"Bearer {api_key}"}, + params=params, + ) + + results.raise_for_status() + data = results.json() + + return [{"name": i["title"], "value": i["_id"]} for i in data["items"]] + + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> any: + """ + Invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param voice: model timbre + :param content_text: text content to be translated + :param user: unique user id + :return: generator yielding audio chunks + """ + + return self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: + """ + Validate credentials for text2speech model + + :param credentials: model credentials + :param user: unique user id + """ + + try: + self.get_tts_model_voices( + None, + credentials={ + "api_key": credentials["api_key"], + "api_base": credentials["api_base"], + # Disable public models will trigger a 403 error if user is not logged in + "use_public_models": "false", + }, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + """ + Invoke streaming text2speech model + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: ID of the reference audio (if any) + :return: generator yielding audio chunks + """ + + try: + word_limit = self._get_model_word_limit(model, credentials) + if len(content_text) > word_limit: + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + else: + sentences = [content_text.strip()] + + for i in range(len(sentences)): + yield from self._tts_invoke_streaming_sentence( + credentials=credentials, content_text=sentences[i], voice=voice + ) + + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> any: + """ + Invoke streaming text2speech model + + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: ID of the reference audio (if any) + :return: generator yielding audio chunks + """ + api_key = credentials.get("api_key") + api_url = credentials.get("api_base", "https://api.fish.audio") + latency = credentials.get("latency") + + if not api_key: + raise InvokeBadRequestError("API key is required") + + with httpx.stream( + "POST", + api_url + "/v1/tts", + json={"text": content_text, "reference_id": voice, "latency": latency}, + headers={ + "Authorization": f"Bearer {api_key}", + }, + timeout=None, + ) as response: + if response.status_code != 200: + raise 
InvokeBadRequestError(f"Error: {response.status_code} - {response.text}") + yield from response.iter_bytes() + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeBadRequestError: [ + httpx.HTTPStatusError, + ], + } diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml b/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml new file mode 100644 index 0000000000..b4a446a957 --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml @@ -0,0 +1,5 @@ +model: tts-default +model_type: tts +model_properties: + word_limit: 1000 + audio_type: 'mp3' diff --git a/api/core/model_runtime/model_providers/google/google.py b/api/core/model_runtime/model_providers/google/google.py index ba25c74e71..70f56a8337 100644 --- a/api/core/model_runtime/model_providers/google/google.py +++ b/api/core/model_runtime/model_providers/google/google.py @@ -20,12 +20,9 @@ class GoogleProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-pro` model for validate, - model_instance.validate_credentials( - model='gemini-pro', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-pro", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 84241fb6c8..274ff02095 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -49,12 +49,17 @@ if you are not sure about the structure. 
class GoogleLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -70,9 +75,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -85,7 +95,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -95,13 +105,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -117,14 +124,16 @@ class GoogleLargeLanguageModel(LargeLanguageModel): type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -136,20 +145,25 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, 
+ user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -163,14 +177,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel( - model_name=model - ) + google_model = genai.GenerativeModel(model_name=model) history = [] @@ -180,7 +192,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = self._format_message_to_glm_content(last_msg) history.append(content) else: - for msg in prompt_messages: # makes message roles strictly alternating + for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) @@ -194,7 +206,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): google_model._client = new_custom_client - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, @@ -203,13 +215,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel): response = google_model.generate_content( contents=history, - generation_config=genai.types.GenerationConfig( - **config_kwargs - ), + generation_config=genai.types.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, tools=self._convert_tools_to_glm_tool(tools) if tools else None, - request_options={"timeout": 600} + request_options={"timeout": 600}, ) if stream: @@ -217,8 +227,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -229,9 +240,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -250,8 +259,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -264,9 +274,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): index = -1 for chunk in response: for part in chunk.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if 
part.text: assistant_prompt_message.content += part.text @@ -275,36 +283,31 @@ class GoogleLargeLanguageModel(LargeLanguageModel): assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - + if not response._done: - # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -312,8 +315,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=str(chunk.candidates[0].finish_reason), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -328,9 +331,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -353,95 +354,86 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: glm Content representation of message """ if isinstance(message, UserPromptMessage): - glm_content = { - "role": "user", - "parts": [] - } - if (isinstance(message.content, str)): - glm_content['parts'].append(to_part(message.content)) + glm_content = {"role": "user", "parts": []} + if isinstance(message.content, str): + glm_content["parts"].append(to_part(message.content)) else: for c in message.content: if c.type == PromptMessageContentType.TEXT: - glm_content['parts'].append(to_part(c.data)) + glm_content["parts"].append(to_part(c.data)) elif c.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, c) if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, base64_data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] else: # fetch image data from url try: image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}} - glm_content['parts'].append(blob) + blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} + glm_content["parts"].append(blob) return glm_content elif isinstance(message, AssistantPromptMessage): - 
glm_content = {
-                "role": "model",
-                "parts": []
-            }
+            glm_content = {"role": "model", "parts": []}
             if message.content:
-                glm_content['parts'].append(to_part(message.content))
+                glm_content["parts"].append(to_part(message.content))
 
             if message.tool_calls:
-                glm_content["parts"].append(to_part(glm.FunctionCall(
-                    name=message.tool_calls[0].function.name,
-                    args=json.loads(message.tool_calls[0].function.arguments),
-                )))
+                glm_content["parts"].append(
+                    to_part(
+                        glm.FunctionCall(
+                            name=message.tool_calls[0].function.name,
+                            args=json.loads(message.tool_calls[0].function.arguments),
+                        )
+                    )
+                )
             return glm_content
         elif isinstance(message, SystemPromptMessage):
-            return {
-                "role": "user",
-                "parts": [to_part(message.content)]
-            }
+            return {"role": "user", "parts": [to_part(message.content)]}
         elif isinstance(message, ToolPromptMessage):
             return {
                 "role": "function",
-                "parts": [glm.Part(function_response=glm.FunctionResponse(
-                    name=message.name,
-                    response={
-                        "response": message.content
-                    }
-                ))]
+                "parts": [
+                    glm.Part(
+                        function_response=glm.FunctionResponse(
+                            name=message.name, response={"response": message.content}
+                        )
+                    )
+                ],
             }
         else:
             raise ValueError(f"Got unknown type {message}")
-
+
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
         Map model invoke error to unified error
-        The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
-        The value is the md = genai.GenerativeModel(model)error type thrown by the model,
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
         which needs to be converted into a unified error type for the caller.
-        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
+        :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-                exceptions.RetryError
-            ],
+            InvokeConnectionError: [exceptions.RetryError],
             InvokeServerUnavailableError: [
                 exceptions.ServiceUnavailable,
                 exceptions.InternalServerError,
                 exceptions.BadGateway,
                 exceptions.GatewayTimeout,
-                exceptions.DeadlineExceeded
-            ],
-            InvokeRateLimitError: [
-                exceptions.ResourceExhausted,
-                exceptions.TooManyRequests
+                exceptions.DeadlineExceeded,
             ],
+            InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
             InvokeAuthorizationError: [
                 exceptions.Unauthenticated,
                 exceptions.PermissionDenied,
                 exceptions.Unauthenticated,
-                exceptions.Forbidden
+                exceptions.Forbidden,
             ],
             InvokeBadRequestError: [
                 exceptions.BadRequest,
@@ -457,5 +449,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 exceptions.PreconditionFailed,
                 exceptions.RequestRangeNotSatisfiable,
                 exceptions.Cancelled,
-            ]
+            ],
         }
diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py
index b3f37b3967..d0d5ff68f8 100644
--- a/api/core/model_runtime/model_providers/groq/groq.py
+++ b/api/core/model_runtime/model_providers/groq/groq.py
@@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
 logger = logging.getLogger(__name__)
 
-class GroqProvider(ModelProvider):
 
+class GroqProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
@@ -18,12 +18,9 @@ class GroqProvider(ModelProvider):
         try:
             model_instance = self.get_model_instance(ModelType.LLM)
 
-            model_instance.validate_credentials(
-                model='llama3-8b-8192',
-
credentials=credentials - ) + model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py index 915f7a4e1a..352a7b519e 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llm.py +++ b/api/core/model_runtime/model_providers/groq/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,6 +27,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.groq.com/openai/v1' - + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index dd8ae526e6..3c4020b6ee 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError class _CommonHuggingfaceHub: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - HfHubHTTPError, - BadRequestError - ] - } + return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]} diff --git a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py index 15e2a4fed4..54d2a2bf39 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class HuggingfaceHubProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index f43a8aedaf..10c6d553f3 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -29,16 +29,23 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu 
class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] - - if 'baichuan' in model.lower(): + if "baichuan" in model.lower(): stream = False response = client.text_generation( @@ -47,98 +54,100 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel stream=stream, model=model, stop_sequences=stop, - **model_parameters) + **model_parameters, + ) if stream: return self._handle_generate_stream_response(model, credentials, prompt_messages, response) return self._handle_generate_response(model, credentials, prompt_messages, response) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'): - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') - 
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
-                credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'],
-                                                                            model)
+                if "task_type" not in credentials:
+                    raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.")
+            elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
+                credentials["task_type"] = self._get_hosted_model_task_type(
+                    credentials["huggingfacehub_api_token"], model
+                )
 
-            if credentials['task_type'] not in ("text2text-generation", "text-generation"):
-                raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, '
-                                                     'text-generation.')
+            if credentials["task_type"] not in ("text2text-generation", "text-generation"):
+                raise CredentialsValidateFailedError(
+                    "Huggingface Hub Task Type must be one of text2text-generation, " "text-generation."
+                )
 
-            client = InferenceClient(token=credentials['huggingfacehub_api_token'])
+            client = InferenceClient(token=credentials["huggingfacehub_api_token"])
 
-            if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-                model = credentials['huggingfacehub_endpoint_url']
+            if credentials["huggingfacehub_api_type"] == "inference_endpoints":
+                model = credentials["huggingfacehub_endpoint_url"]
 
             try:
-                client.text_generation(
-                    prompt='Who are you?',
-                    stream=True,
-                    model=model)
+                client.text_generation(prompt="Who are you?", stream=True, model=model)
             except BadRequestError as e:
-                raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. '
-                                                     'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.')
+                raise CredentialsValidateFailedError(
+                    "Only available for models served with `text-generation-inference`. "
+                    "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
+ ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: LLMMode.COMPLETION.value - }, - parameter_rules=self._get_customizable_model_parameter_rules() + model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value}, + parameter_rules=self._get_customizable_model_parameter_rules(), ) return entity @staticmethod def _get_customizable_model_parameter_rules() -> list[ParameterRule]: - temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get( - DefaultParameterName.TEMPERATURE).copy() - temperature_rule_dict['name'] = 'temperature' + temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy() + temperature_rule_dict["name"] = "temperature" temperature_rule = ParameterRule(**temperature_rule_dict) temperature_rule.default = 0.5 top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy() - top_p_rule_dict['name'] = 'top_p' + top_p_rule_dict["name"] = "top_p" top_p_rule = ParameterRule(**top_p_rule_dict) top_p_rule.default = 0.5 top_k_rule = ParameterRule( - name='top_k', + name="top_k", label={ - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "en_US": "Top K", + "zh_Hans": "Top K", }, - type='int', + type="int", help={ - 'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.', - 'zh_Hans': '保留的最高概率词汇标记的数量。', + "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", + "zh_Hans": "保留的最高概率词汇标记的数量。", }, required=False, default=2, @@ -148,15 +157,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) max_new_tokens = ParameterRule( - name='max_new_tokens', + name="max_new_tokens", label={ - 'en_US': 'Max New Tokens', - 'zh_Hans': '最大新标记', + "en_US": "Max New Tokens", + "zh_Hans": "最大新标记", }, - type='int', + type="int", help={ - 'en_US': 'Maximum number of generated tokens.', - 'zh_Hans': '生成的标记的最大数量。', + "en_US": "Maximum number of generated tokens.", + "zh_Hans": "生成的标记的最大数量。", }, required=False, default=20, @@ -166,30 +175,30 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) seed = ParameterRule( - name='seed', + name="seed", label={ - 'en_US': 'Random sampling seed', - 'zh_Hans': '随机采样种子', + "en_US": "Random sampling seed", + "zh_Hans": "随机采样种子", }, - type='int', + type="int", help={ - 'en_US': 'Random sampling seed.', - 'zh_Hans': '随机采样种子。', + "en_US": "Random sampling seed.", + "zh_Hans": "随机采样种子。", }, required=False, precision=0, ) repetition_penalty = ParameterRule( - name='repetition_penalty', + name="repetition_penalty", label={ - 'en_US': 'Repetition Penalty', - 'zh_Hans': '重复惩罚', + "en_US": "Repetition Penalty", + "zh_Hans": "重复惩罚", }, - type='float', + type="float", help={ - 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', - 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + "en_US": "The parameter for repetition penalty. 
1.0 means no penalty.", + "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。", }, required=False, precision=1, @@ -197,11 +206,9 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - response: Generator) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator + ) -> Generator: index = -1 for chunk in response: # skip special tokens @@ -210,9 +217,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=chunk.token.text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text) if chunk.details: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -240,15 +245,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ), ) - def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any + ) -> LLMResult: if isinstance(response, str): content = response else: content = response.generated_text - assistant_prompt_message = AssistantPromptMessage( - content=content - ) + assistant_prompt_message = AssistantPromptMessage(content=content) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -270,15 +275,14 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = ("text2text-generation", "text-generation") if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") @@ -287,10 +291,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 0f0c166f3e..cb7a30bbe5 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ 
b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -13,40 +13,30 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' +HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/" class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) execute_model = model - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - execute_model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + execute_model = credentials["huggingfacehub_endpoint_url"] output = client.post( - json={ - "inputs": texts, - "options": { - "wait_for_model": False, - "use_cache": False - } - }, - model=execute_model) + json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model + ) embeddings = json.loads(output.decode()) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - embeddings=self._mean_pooling(embeddings), - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -56,52 +46,48 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingface_namespace' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingface_namespace" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub User Name / Organization Name must be provided." 
+ ) - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") - if credentials['task_type'] != 'feature-extraction': - raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.') + if credentials["task_type"] != "feature-extraction": + raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.") self._check_endpoint_url_model_repository_name(credentials, model) - model = credentials['huggingfacehub_endpoint_url'] + model = credentials["huggingfacehub_endpoint_url"] - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model) else: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - client.feature_extraction(text='hello world', model=model) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + client.feature_extraction(text="hello world", model=model) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 10000, - 'max_chunks': 1 - } + model_properties={"context_size": 10000, "max_chunks": 1}, ) return entity @@ -128,24 +114,20 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = "feature-extraction" if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -156,7 +138,7 @@ class 
HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -166,25 +148,26 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' headers = { - 'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}', - 'Content-Type': 'application/json' + "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', + "Content-Type": "application/json", } response = requests.get(url=url, headers=headers) if response.status_code != 200: - raise ValueError('User Name or Organization Name is invalid.') + raise ValueError("User Name or Organization Name is invalid.") - model_repository_name = '' + model_repository_name = "" for item in response.json().get("items", []): - if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']: + if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]: model_repository_name = item.get("model", {}).get("repository") break if model_repository_name != model_name: raise ValueError( - f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.') + f"Model Name {model_name} is invalid. Please check it on the inference endpoints console." + ) except Exception as e: raise ValueError(str(e)) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py index 9454466250..97d7e28dc6 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class HuggingfaceTeiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 34013426de..c128c35f6d 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -47,29 +47,29 @@ class HuggingfaceTeiRerankModel(RerankModel): """ if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] try: results = TeiHelper.invoke_rerank(server_url, query, docs) rerank_documents = [] - for result in results: + for result in results: rerank_document = RerankDocument( - index=result['index'], - text=result['text'], - score=result['score'], + index=result["index"], + text=result["text"], + score=result["score"], ) - if score_threshold is None or result['score'] >= score_threshold: + if score_threshold is None or result["score"] >= score_threshold: rerank_documents.append(rerank_document) if top_n is not None and len(rerank_documents) >= top_n: break return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: 
dict) -> None: """ @@ -80,21 +80,21 @@ class HuggingfaceTeiRerankModel(RerankModel): :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) - if extra_args.model_type != 'reranker': - raise CredentialsValidateFailedError('Current model is not a rerank model') + if extra_args.model_type != "reranker": + raise CredentialsValidateFailedError("Current model is not a rerank model") - credentials['context_size'] = extra_args.max_input_length + credentials["context_size"] = extra_args.max_input_length self.invoke( model=model, credentials=credentials, - query='Whose kasumi', + query="Whose kasumi", docs=[ 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', - 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', - 'and she leads a team named PopiParty.', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", ], score_threshold=0.8, ) @@ -129,7 +129,7 @@ class HuggingfaceTeiRerankModel(RerankModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 2aa785c89d..56c51e8888 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -31,16 +31,16 @@ class TeiHelper: with cache_lock: if model_name not in cache: cache[model_name] = { - 'expires': time() + 300, - 'value': TeiHelper._get_tei_extra_parameter(server_url), + "expires": time() + 300, + "value": TeiHelper._get_tei_extra_parameter(server_url), } - return cache[model_name]['value'] + return cache[model_name]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -52,40 +52,38 @@ class TeiHelper: get tei model extra parameter like model_type, max_input_length, max_batch_requests """ - url = str(URL(server_url) / 'info') + url = str(URL(server_url) / "info") # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) try: response = session.get(url, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: raise RuntimeError( - f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' + f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}" ) response_json = response.json() - model_type = 
response_json.get('model_type', {}) + model_type = response_json.get("model_type", {}) if len(model_type.keys()) < 1: - raise RuntimeError('model_type is empty') + raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ['embedding', 'reranker']: - raise RuntimeError(f'invalid model_type: {model_type}') - - max_input_length = response_json.get('max_input_length', 512) - max_client_batch_size = response_json.get('max_client_batch_size', 1) + if model_type not in ["embedding", "reranker"]: + raise RuntimeError(f"invalid model_type: {model_type}") + + max_input_length = response_json.get("max_input_length", 512) + max_client_batch_size = response_json.get("max_client_batch_size", 1) return TeiModelExtraParameter( - model_type=model_type, - max_input_length=max_input_length, - max_client_batch_size=max_client_batch_size + model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size ) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: """ @@ -116,12 +114,12 @@ class TeiHelper: :param texts: texts to tokenize """ resp = httpx.post( - f'{server_url}/tokenize', - json={'inputs': texts}, + f"{server_url}/tokenize", + json={"inputs": texts}, ) resp.raise_for_status() return resp.json() - + @staticmethod def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ @@ -149,8 +147,8 @@ class TeiHelper: """ # Use OpenAI compatible API here, which has usage tracking resp = httpx.post( - f'{server_url}/v1/embeddings', - json={'input': texts}, + f"{server_url}/v1/embeddings", + json={"input": texts}, ) resp.raise_for_status() return resp.json() @@ -173,11 +171,11 @@ class TeiHelper: :param texts: texts to rerank :param candidates: candidates to rerank """ - params = {'query': query, 'texts': docs, 'return_text': True} + params = {"query": query, "texts": docs, "return_text": True} response = httpx.post( - server_url + '/rerank', + server_url + "/rerank", json=params, ) - response.raise_for_status() + response.raise_for_status() return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 6897b87f6d..2d04abb277 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -40,12 +40,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] - # get model properties context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -58,7 +57,6 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): - # Check if the number of tokens is larger than the context size num_tokens = len(tokenize_result) @@ -66,20 +64,22 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): # Find the best cutoff point pre_special_token_count = 0 for token in tokenize_result: - if token['special']: + if token["special"]: pre_special_token_count += 1 else: 
break
-                rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
+                rest_special_token_count = (
+                    len([token for token in tokenize_result if token["special"]]) - pre_special_token_count
+                )
 
                 # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
                 token_cutoff = context_size - rest_special_token_count - 20
 
                 # Find the cutoff index
                 cutpoint_token = tokenize_result[token_cutoff]
-                cutoff = cutpoint_token['start']
+                cutoff = cutpoint_token["start"]
 
-                inputs.append(text[0: cutoff])
+                inputs.append(text[0:cutoff])
             else:
                 inputs.append(text)
                 indices += [i]
@@ -92,12 +92,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
             for i in _iter:
                 iter_texts = inputs[i : i + max_chunks]
                 results = TeiHelper.invoke_embeddings(server_url, iter_texts)
-                embeddings = results['data']
-                embeddings = [embedding['embedding'] for embedding in embeddings]
+                embeddings = results["data"]
+                embeddings = [embedding["embedding"] for embedding in embeddings]
                 batched_embeddings.extend(embeddings)
 
-                usage = results['usage']
-                used_tokens += usage['total_tokens']
+                usage = results["usage"]
+                used_tokens += usage["total_tokens"]
         except RuntimeError as e:
             raise InvokeServerUnavailableError(str(e))
@@ -117,9 +117,9 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         num_tokens = 0
-        server_url = credentials['server_url']
+        server_url = credentials["server_url"]
 
-        if server_url.endswith('/'):
+        if server_url.endswith("/"):
             server_url = server_url[:-1]
 
         batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
@@ -135,15 +135,15 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
-            server_url = credentials['server_url']
+            server_url = credentials["server_url"]
             extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
             print(extra_args)
-            if extra_args.model_type != 'embedding':
-                raise CredentialsValidateFailedError('Current model is not a embedding model')
+            if extra_args.model_type != "embedding":
+                raise CredentialsValidateFailedError("Current model is not an embedding model")
 
-            credentials['context_size'] = extra_args.max_input_length
-            credentials['max_chunks'] = extra_args.max_client_batch_size
-            self._invoke(model=model, credentials=credentials, texts=['ping'])
+            credentials["context_size"] = extra_args.max_input_length
+            credentials["max_chunks"] = extra_args.max_client_batch_size
+            self._invoke(model=model, credentials=credentials, texts=["ping"])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -195,8 +195,8 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
             model_properties={
-                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
+                ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
             },
             parameter_rules=[],
         )
diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py
index 5a298d33ac..e65772e7dd 100644
--- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py
+++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py
@@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
 logger = logging.getLogger(__name__)
 
-class HunyuanProvider(ModelProvider):
 
+class
HunyuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +19,9 @@ class HunyuanProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `hunyuan-standard` model for validate, - model_instance.validate_credentials( - model='hunyuan-standard', - credentials=credentials - ) + model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 0bdf6ec005..c056ab7a08 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -23,21 +23,27 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) + class HunyuanLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = self._setup_hunyuan_client(credentials) request = models.ChatCompletionsRequest() messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages) custom_parameters = { - 'Temperature': model_parameters.get('temperature', 0.0), - 'TopP': model_parameters.get('top_p', 1.0), - 'EnableEnhancement': model_parameters.get('enable_enhance', True) + "Temperature": model_parameters.get("temperature", 0.0), + "TopP": model_parameters.get("top_p", 1.0), + "EnableEnhancement": model_parameters.get("enable_enhance", True), } params = { @@ -47,16 +53,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): **custom_parameters, } # add Tools and ToolChoice - if (tools and len(tools) > 0): - params['ToolChoice'] = "auto" - params['Tools'] = [{ - "Type": "function", - "Function": { - "Name": tool.name, - "Description": tool.description, - "Parameters": json.dumps(tool.parameters) + if tools and len(tools) > 0: + params["ToolChoice"] = "auto" + params["Tools"] = [ + { + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters), + }, } - } for tool in tools] + for tool in tools + ] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -76,22 +85,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise 
CredentialsValidateFailedError(f"Credentials validation failed: {e}")
 
     def _setup_hunyuan_client(self, credentials):
-        secret_id = credentials['secret_id']
-        secret_key = credentials['secret_key']
+        secret_id = credentials["secret_id"]
+        secret_key = credentials["secret_key"]
         cred = credential.Credential(secret_id, secret_key)
         httpProfile = HttpProfile()
         httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
@@ -106,92 +112,96 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
         for message in prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 tool_calls = message.tool_calls
-                if (tool_calls and len(tool_calls) > 0):
+                if tool_calls and len(tool_calls) > 0:
                     dict_tool_calls = [
                         {
                             "Id": tool_call.id,
                             "Type": tool_call.type,
                             "Function": {
                                 "Name": tool_call.function.name,
-                                "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}"
-                            }
-                        } for tool_call in tool_calls]
-
-                    dict_list.append({
-                        "Role": message.role.value,
-                        # fix set content = "" while tool_call request
-                        # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
-                        "Content": " ",  # message.content if (message.content is not None) else "",
-                        "ToolCalls": dict_tool_calls
-                    })
+                                "Arguments": tool_call.function.arguments
+                                if tool_call.function.arguments != ""
+                                else "{}",
+                            },
+                        }
+                        for tool_call in tool_calls
+                    ]
+
+                    dict_list.append(
+                        {
+                            "Role": message.role.value,
+                            # set Content to a single space during tool_call requests
+                            # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
+                            "Content": " ",  # message.content if (message.content is not None) else "",
+                            "ToolCalls": dict_tool_calls,
+                        }
+                    )
                 else:
-                    dict_list.append({ "Role": message.role.value, "Content": message.content })
+                    dict_list.append({"Role": message.role.value, "Content": message.content})
             elif isinstance(message, ToolPromptMessage):
-                tool_execute_result = { "result": message.content }
-                content =json.dumps(tool_execute_result, ensure_ascii=False)
-                dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id })
+                tool_execute_result = {"result": message.content}
+                content = json.dumps(tool_execute_result, ensure_ascii=False)
+                dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id})
             else:
-                dict_list.append({ "Role": message.role.value, "Content": message.content })
+                dict_list.append({"Role": message.role.value, "Content": message.content})
         return dict_list
 
     def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
-
         tool_call = None
         tool_calls = []
 
         for index, event in enumerate(resp):
             logging.debug("_handle_stream_chat_response, event: %s", event)
 
-            data_str = event['data']
+            data_str = event["data"]
             data = json.loads(data_str)
 
-            choices = data.get('Choices', [])
+            choices = data.get("Choices", [])
             if not choices:
                 continue
             choice = choices[0]
-            delta = choice.get('Delta', {})
-            message_content = delta.get('Content', '')
-            finish_reason = choice.get('FinishReason', '')
+            delta = choice.get("Delta", {})
+            message_content = delta.get("Content", "")
+            finish_reason = choice.get("FinishReason", "")
 
-            usage = data.get('Usage', {})
-            prompt_tokens = usage.get('PromptTokens', 0)
-            completion_tokens = usage.get('CompletionTokens', 0)
+            usage = data.get("Usage", {})
+            prompt_tokens = usage.get("PromptTokens", 0)
+            completion_tokens =
usage.get("CompletionTokens", 0) - response_tool_calls = delta.get('ToolCalls') - if (response_tool_calls is not None): + response_tool_calls = delta.get("ToolCalls") + if response_tool_calls is not None: new_tool_calls = self._extract_response_tool_calls(response_tool_calls) - if (len(new_tool_calls) > 0): + if len(new_tool_calls) > 0: new_tool_call = new_tool_calls[0] - if (tool_call is None): tool_call = new_tool_call - elif (tool_call.id != new_tool_call.id): + if tool_call is None: + tool_call = new_tool_call + elif tool_call.id != new_tool_call.id: tool_calls.append(tool_call) tool_call = new_tool_call else: tool_call.function.name += new_tool_call.function.name tool_call.function.arguments += new_tool_call.function.arguments - if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0: tool_calls.append(tool_call) tool_call = None - assistant_prompt_message = AssistantPromptMessage( - content=message_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[]) # rewrite content = "" while tool_call to avoid show content on web page - if (len(tool_calls) > 0): assistant_prompt_message.content = "" - + if len(tool_calls) > 0: + assistant_prompt_message.content = "" + # add tool_calls to assistant_prompt_message - if (finish_reason == 'tool_calls'): + if finish_reason == "tool_calls": assistant_prompt_message.tool_calls = tool_calls tool_call = None tool_calls = [] - if (len(finish_reason) > 0): + if len(finish_reason) > 0: usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) delta_chunk = LLMResultChunkDelta( index=index, - role=delta.get('Role', 'assistant'), + role=delta.get("Role", "assistant"), message=assistant_prompt_message, usage=usage, finish_reason=finish_reason, @@ -212,8 +222,9 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): ) def _handle_chat_response(self, credentials, model, prompt_messages, response): - usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens, - response.Usage.CompletionTokens) + usage = self._calc_response_usage( + model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens + ) assistant_prompt_message = AssistantPromptMessage() assistant_prompt_message.content = response.Choices[0].Message.Content result = LLMResult( @@ -225,8 +236,13 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return result - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if len(prompt_messages) == 0: return 0 prompt = self._convert_messages_to_prompt(prompt_messages) @@ -241,10 +257,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -287,10 +300,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return { InvokeError: [TencentCloudSDKException], } - - def 
_extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -300,17 +311,14 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): tool_calls = [] if response_tool_calls: for response_tool_call in response_tool_calls: - response_function = response_tool_call.get('Function', {}) + response_function = response_tool_call.get("Function", {}) function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function.get('Name', ''), - arguments=response_function.get('Arguments', '') + name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get('Id', 0), - type='function', - function=function + id=response_tool_call.get("Id", 0), type="function", function=function ) tool_calls.append(tool_call) - return tool_calls \ No newline at end of file + return tool_calls diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 64d8dcf795..1396e59e18 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -19,14 +19,15 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE logger = logging.getLogger(__name__) + class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for Hunyuan text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +38,9 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ - if model != 'hunyuan-embedding': - raise ValueError('Invalid model name') - + if model != "hunyuan-embedding": + raise ValueError("Invalid model name") + client = self._setup_hunyuan_client(credentials) embeddings = [] @@ -47,9 +48,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): for input in texts: request = models.GetEmbeddingRequest() - params = { - "Input": input - } + params = {"Input": input} request.from_json_string(json.dumps(params)) response = client.GetEmbedding(request) usage = response.Usage.TotalTokens @@ -60,11 +59,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result @@ -79,22 +74,19 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation 
failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -102,7 +94,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): clientProfile.httpProfile = httpProfile client = hunyuan_client.HunyuanClient(cred, "", clientProfile) return client - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -114,10 +106,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -128,11 +117,11 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -146,7 +135,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): return { InvokeError: [TencentCloudSDKException], } - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -170,4 +159,4 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): # response = client.GetTokenCount(request) # num_tokens += response.TokenCount - return num_tokens \ No newline at end of file + return num_tokens diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index cde4313495..33977b6a33 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class JinaProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class JinaProvider(ModelProvider): # Use `jina-embeddings-v2-base-en` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials=credentials - ) + model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index de7e038b9f..d8394f7a4c 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel): Model class for Jina rerank model. 
""" - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -40,37 +47,32 @@ class JinaRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.jina.ai/v1') - if base_url.endswith('/'): + base_url = credentials.get("base_url", "https://api.jina.ai/v1") + if base_url.endswith("/"): base_url = base_url[:-1] try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) - response.raise_for_status() + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -81,7 +83,6 @@ class JinaRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -92,7 +93,7 @@ class JinaRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. 
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -105,23 +106,21 @@ class JinaRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index 50f8c73ed9..d80cbfa83d 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -14,19 +14,19 @@ class JinaTokenizer: with cls._lock: if cls._tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + gpt2_tokenizer_path = join(dirname(base_path), "tokenizer") cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) return cls._tokenizer @classmethod def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ - use jina tokenizer to get num tokens + use jina tokenizer to get num tokens """ tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - + @classmethod def get_num_tokens(cls, text: str) -> int: - return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file + return cls._get_num_tokens_by_jina_base(text) diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 23203491e6..7ed3e4d384 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -24,11 +24,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. 
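The `JinaTokenizer` change above is pure formatting, but the pattern it touches, a lock-guarded, lazily loaded Hugging Face tokenizer shared process-wide, is worth seeing in isolation. A sketch under the same assumptions as the hunk (a local `tokenizer/` directory next to the module); the fast-path check before taking the lock is an addition here, safe because assignment happens only under the lock:

```python
# Lazy, thread-safe tokenizer cache, mirroring the JinaTokenizer hunk.
import threading
from os.path import abspath, dirname, join

from transformers import AutoTokenizer


class LazyTokenizer:
    _tokenizer = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        if cls._tokenizer is None:  # fast path, no lock
            with cls._lock:
                if cls._tokenizer is None:  # re-check under the lock
                    path = join(dirname(abspath(__file__)), "tokenizer")
                    cls._tokenizer = AutoTokenizer.from_pretrained(path)
        return cls._tokenizer

    @classmethod
    def num_tokens(cls, text: str) -> int:
        return len(cls.get().encode(text))
```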
""" - api_base: str = 'https://api.jina.ai/v1' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.jina.ai/v1" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -38,29 +39,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") - base_url = credentials.get('base_url', self.api_base) - if base_url.endswith('/'): + base_url = credentials.get("base_url", self.api_base) + if base_url.endswith("/"): base_url = base_url[:-1] - url = base_url + '/embeddings' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} def transform_jina_input_text(model, text): - if model == 'jina-clip-v1': + if model == "jina-clip-v1": return {"text": text} return text - data = { - 'model': model, - 'input': [transform_jina_input_text(model, text) for text in texts] - } + data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} try: response = post(url, headers=headers, data=dumps(data)) @@ -70,7 +65,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -81,25 +76,20 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): raise InvokeBadRequestError(msg) except JSONDecodeError as e: raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -128,30 +118,18 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: - raise CredentialsValidateFailedError( - f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - 
InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError, - InvokeBadRequestError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -165,10 +143,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,24 +154,21 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) return entity diff --git a/api/core/model_runtime/model_providers/leptonai/leptonai.py b/api/core/model_runtime/model_providers/leptonai/leptonai.py index b035c31ac5..34a55ff192 100644 --- a/api/core/model_runtime/model_providers/leptonai/leptonai.py +++ b/api/core/model_runtime/model_providers/leptonai/leptonai.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class LeptonAIProvider(ModelProvider): +class LeptonAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ class LeptonAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama2-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="llama2-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/leptonai/llm/llm.py b/api/core/model_runtime/model_providers/leptonai/llm/llm.py index 523309bac5..3d69417e45 100644 --- a/api/core/model_runtime/model_providers/leptonai/llm/llm.py +++ b/api/core/model_runtime/model_providers/leptonai/llm/llm.py @@ -8,18 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_PREFIX_MAP = { - 'llama2-7b': 'llama2-7b', - 'gemma-7b': 'gemma-7b', - 'mistral-7b': 
'mistral-7b', - 'mixtral-8x7b': 'mixtral-8x7b', - 'llama3-70b': 'llama3-70b', - 'llama2-13b': 'llama2-13b', - } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + "llama2-7b": "llama2-7b", + "gemma-7b": "gemma-7b", + "mistral-7b": "mistral-7b", + "mixtral-8x7b": "mixtral-8x7b", + "llama3-70b": "llama3-70b", + "llama2-13b": "llama2-13b", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -29,6 +36,5 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1' - \ No newline at end of file + credentials["mode"] = "chat" + credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1" diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 1009995c58..94c03efe7b 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -52,29 +52,48 @@ from core.model_runtime.utils import helper class LocalAILanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for baichuan model - LocalAI does not supports + Calculate num tokens for baichuan model + LocalAI does not supports """ def tokens(text: str): """ - We could not determine which 
tokenizer to use, cause the model is customized. - So we use gpt2 tokenizer to calculate the num tokens for convenience. + We could not determine which tokenizer to use, cause the model is customized. + So we use gpt2 tokenizer to calculate the num tokens for convenience. """ return self._get_num_tokens_by_gpt2(text) @@ -87,10 +106,10 @@ class LocalAILanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -142,30 +161,30 @@ class LocalAILanguageModel(LargeLanguageModel): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -180,102 +199,104 @@ class LocalAILanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={ - 'max_tokens': 10, - }, stop=[], stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + "max_tokens": 10, + }, + stop=[], + stream=False, + ) except Exception as ex: - raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') + raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: completion_model = None - if credentials['completion_type'] == 'chat_completion': + if credentials["completion_type"] == "chat_completion": completion_model = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_model = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - 
name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=2048, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - model_properties = { - ModelPropertyKey.MODE: completion_model, - } if completion_model else {} + model_properties = ( + { + ModelPropertyKey.MODE: completion_model, + } + if completion_model + else {} + ) - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048")) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, - parameter_rules=rules + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) # init model client client = OpenAI(**kwargs) model_name = model - completion_type = credentials['completion_type'] + completion_type = credentials["completion_type"] extra_model_kwargs = { "timeout": 60, } if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] - if completion_type == 'chat_completion': + if completion_type == "chat_completion": result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model_name, @@ -283,36 +304,32 @@ class LocalAILanguageModel(LargeLanguageModel): **model_parameters, **extra_model_kwargs, ) - elif completion_type == 'completion': + elif completion_type == "completion": result = client.completions.create( prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages), model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: raise ValueError(f"Unknown completion type {completion_type}") if stream: - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - 
prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_response( - model=model, credentials=credentials, response=result, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, prompt_messages=prompt_messages ) return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) def _to_client_kwargs(self, credentials: dict) -> dict: @@ -322,13 +339,13 @@ class LocalAILanguageModel(LargeLanguageModel): :param credentials: credentials dict :return: client kwargs """ - if not credentials['server_url'].endswith('/'): - credentials['server_url'] += '/' + if not credentials["server_url"].endswith("/"): + credentials["server_url"] += "/" client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['server_url']) / 'v1'), + "base_url": str(URL(credentials["server_url"]) / "v1"), } return client_kwargs @@ -349,7 +366,7 @@ class LocalAILanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -359,11 +376,7 @@ class LocalAILanguageModel(LargeLanguageModel): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}], } else: raise ValueError(f"Unknown message type {type(message)}") @@ -374,27 +387,29 @@ class LocalAILanguageModel(LargeLanguageModel): """ Convert PromptMessage to completion prompts """ - prompts = '' + prompts = "" for message in messages: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" else: raise ValueError(f"Unknown message type {type(message)}") return prompts - def _handle_completion_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Completion, - ) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Completion, + ) -> LLMResult: """ Handle llm chat response @@ -411,18 +426,16 @@ class LocalAILanguageModel(LargeLanguageModel): assistant_message = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = 
AssistantPromptMessage(content=assistant_message, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -434,11 +447,14 @@ class LocalAILanguageModel(LargeLanguageModel): return response - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ChatCompletion, - tools: list[PromptMessageTool]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ChatCompletion, + tools: list[PromptMessageTool], + ) -> LLMResult: """ Handle llm chat response @@ -459,16 +475,14 @@ class LocalAILanguageModel(LargeLanguageModel): tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -480,12 +494,15 @@ class LocalAILanguageModel(LargeLanguageModel): return response - def _handle_completion_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[Completion], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_completion_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[Completion], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -494,17 +511,11 @@ class LocalAILanguageModel(LargeLanguageModel): delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) @@ -512,8 +523,12 @@ class LocalAILanguageModel(LargeLanguageModel): completion_tokens = 
self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -523,7 +538,7 @@ class LocalAILanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -539,12 +554,15 @@ class LocalAILanguageModel(LargeLanguageModel): full_response += delta.text - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[ChatCompletionChunk], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[ChatCompletionChunk], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -552,7 +570,7 @@ class LocalAILanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -564,22 +582,24 @@ class LocalAILanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -589,7 +609,7 @@ class LocalAILanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -605,9 +625,9 @@ class LocalAILanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -618,15 +638,10 @@ class LocalAILanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = 
AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls @@ -651,15 +666,9 @@ class LocalAILanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 6d2278fd54..4ff898052b 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class LocalAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py index c8ba9a6c7c..2b0f53bc19 100644 --- a/api/core/model_runtime/model_providers/localai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py @@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel): LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here. 
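Nearly every provider file in this diff reformats an `_invoke_error_mapping` table of the same shape: provider or SDK exceptions bucketed under unified invoke-error families. The sketch below shows one plausible way such a table is consumed; the error classes and the two example source exceptions are local stand-ins, not the runtime's actual hierarchy:

```python
# Stand-in sketch of the _invoke_error_mapping idiom seen throughout this diff.
class InvokeError(Exception): ...
class InvokeAuthorizationError(InvokeError): ...
class InvokeRateLimitError(InvokeError): ...


# Same dict[type[InvokeError], list[type[Exception]]] shape as the hunks above;
# the mapped source exceptions here are illustrative only.
INVOKE_ERROR_MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeAuthorizationError: [PermissionError],
    InvokeRateLimitError: [TimeoutError],
}


def to_invoke_error(exc: Exception) -> InvokeError:
    """Translate a raw provider exception into a unified invoke error."""
    for unified, sources in INVOKE_ERROR_MAPPING.items():
        if sources and isinstance(exc, tuple(sources)):
            return unified(str(exc))
    return InvokeError(str(exc))  # default bucket for unmapped exceptions
```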
""" - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -43,45 +50,37 @@ class LocalaiRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model - - if not server_url: - raise CredentialsValidateFailedError('server_url is required') - if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': f"Bearer {credentials.get('api_key')}", - 'Content-Type': 'application/json' - } - data = { - "model": model_name, - "query": query, - "documents": docs, - "top_n": top_n - } + if not server_url: + raise CredentialsValidateFailedError("server_url is required") + if not model_name: + raise CredentialsValidateFailedError("model_name is required") + + url = server_url + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} + + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} try: - response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10) - response.raise_for_status() + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10) + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -92,7 +91,6 @@ class LocalaiRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -103,7 +101,7 @@ class LocalaiRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. 
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -116,21 +114,21 @@ class LocalaiRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={} + model_properties={}, ) return entity diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py index d7403aff4f..4b9d0f5bfe 100644 --- a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py @@ -32,8 +32,8 @@ class LocalAISpeech2text(Speech2TextModel): :param user: unique user id :return: text for given audio file """ - - url = str(URL(credentials['server_url']) / "v1/audio/transcriptions") + + url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions") data = {"model": model} files = {"file": file} @@ -42,7 +42,7 @@ class LocalAISpeech2text(Speech2TextModel): prepared_request = session.prepare_request(request) response = session.send(prepared_request) - if 'error' in response.json(): + if "error" in response.json(): raise InvokeServerUnavailableError("Empty response") return response.json()["text"] @@ -58,7 +58,7 @@ class LocalAISpeech2text(Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -66,36 +66,24 @@ class LocalAISpeech2text(Speech2TextModel): @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError - ], + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git 
a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 954c9d10f2..7d258be81e 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -24,9 +24,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,39 +38,33 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ if len(texts) != 1: - raise InvokeBadRequestError('Only one text is supported') + raise InvokeBadRequestError("Only one text is supported") - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': 'Bearer 123', - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("model_name is required") - data = { - 'model': model_name, - 'input': texts[0] - } + url = server_url + headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"} + + data = {"model": model_name, "input": texts[0]} try: - response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - code = resp['error']['code'] - msg = resp['error']['message'] + code = resp["error"]["code"] + msg = resp["error"]["message"] if code == 500: raise InvokeServerUnavailableError(msg) - + if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -79,23 +74,21 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -114,7 +107,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): # use GPT2Tokenizer to get num tokens 
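The `_get_num_tokens_by_gpt2` helper used in the LocalAI embedding hunk below is a base-class utility whose body is not part of this diff; per the LocalAI LLM docstring earlier in the diff, the served model's tokenizer is unknown, so GPT-2 is used as a rough proxy. A sketch of that approach, assuming `transformers` is available (loading `"gpt2"` pulls from the HF hub unless cached):

```python
# GPT-2 approximation for token counting when the backend model is unknown.
# Counts are a proxy only; the served model's real tokenizer may differ.
from transformers import GPT2TokenizerFast

_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")


def approx_num_tokens(texts: list[str]) -> int:
    """Approximate total token usage across a batch of texts."""
    return sum(len(_gpt2.encode(text)) for text in texts)
```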
num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens - + def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema @@ -130,10 +123,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): features=[], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")), ModelPropertyKey.MAX_CHUNKS: 1, }, - parameter_rules=[] + parameter_rules=[], ) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -145,32 +138,22 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid credentials') + raise CredentialsValidateFailedError("Invalid credentials") except InvokeConnectionError as e: - raise CredentialsValidateFailedError(f'Invalid credentials: {e}') + raise CredentialsValidateFailedError(f"Invalid credentials: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -182,10 +165,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -196,7 +176,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 6c41e0d2a5..96f99c8929 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -17,42 +17,48 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxChatCompletion: """ - Minimax Chat Completion API + Minimax Chat Completion API """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + 
model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - prompt = '你是一个什么都懂的专家' + prompt = "你是一个什么都懂的专家" - role_meta = { - 'user_name': '我', - 'bot_name': '专家' - } + role_meta = {"user_name": "我", "bot_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: prompt = prompt_messages[0].content @@ -60,40 +66,39 @@ class MinimaxChatCompletion: # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - - messages = [{ - 'sender_type': message.role, - 'text': message.content, - } for message in prompt_messages] + raise BadRequestError("At least one user message is required") - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + messages = [ + { + "sender_type": message.role, + "text": message.content, + } + for message in prompt_messages + ] + + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'prompt': prompt, - 'role_meta': role_meta, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "prompt": prompt, + "role_meta": role_meta, + "stream": stream, + **extra_kwargs, } try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001 or code == 1013 or code == 1027: raise InternalServerError(msg) @@ -110,65 +115,52 @@ class MinimaxChatCompletion: def _handle_chat_generate_response(self, 
response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens - } - message.stop_reason = data['choices'][0]['finish_reason'] + if data["reply"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens} + message.stop_reason = data["choices"][0]["finish_reason"] yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['delta'] - yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value - ) \ No newline at end of file + message = choice["delta"] + yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 55747057c9..0a2a67a56d 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -17,86 +17,83 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxChatCompletionPro: """ - Minimax Chat Completion Pro API, supports function calling - however, we do not have enough time and energy to implement 
it, but the parameters are reserved + Minimax Chat Completion Pro API; supports function calling. + However, we have not yet had the time to implement it, so the parameters are reserved for now """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') + raise InvalidAPIKeyError("Invalid API key or group ID") - url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' + url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool: - extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info'] - - if model_parameters.get('plugin_web_search'): - extra_kwargs['plugins'] = [ - 'plugin_web_search' - ] + if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool: + extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"] - bot_setting = { - 'bot_name': '专家', - 'content': '你是一个什么都懂的专家' - } + if model_parameters.get("plugin_web_search"): + extra_kwargs["plugins"] = ["plugin_web_search"] - reply_constraints = { - 'sender_type': 'BOT', - 'sender_name': '专家' - } + bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"} + + reply_constraints = {"sender_type": "BOT", "sender_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') + raise BadRequestError("At least one message is required") if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: - bot_setting['content'] = prompt_messages[0].content + bot_setting["content"] = prompt_messages[0].content prompt_messages = prompt_messages[1:] # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') + raise BadRequestError("At least one user message is required") messages = [message.to_dict() for
message in prompt_messages] - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'bot_setting': [bot_setting], - 'reply_constraints': reply_constraints, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "bot_setting": [bot_setting], + "reply_constraints": reply_constraints, + "stream": stream, + **extra_kwargs, } if tools: - body['functions'] = tools - body['function_call'] = {'type': 'auto'} + body["functions"] = tools + body["function_call"] = {"type": "auto"} try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) @@ -123,78 +120,72 @@ class MinimaxChatCompletionPro: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) # final chunk - if data['reply'] or data.get('usage'): - total_tokens = data['usage']['total_tokens'] - minimax_message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) + if data["reply"] or data.get("usage"): + total_tokens = data["usage"]["total_tokens"] + minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") minimax_message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - minimax_message.stop_reason = 
data['choices'][0]['finish_reason'] + minimax_message.stop_reason = data["choices"][0]["finish_reason"] - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) > 0: for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append function_call message - if 'function_call' in message: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - function_call_message.function_call = message['function_call'] + if "function_call" in message: + function_call_message = MinimaxMessage(content="", role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = message["function_call"] yield function_call_message yield minimax_message return # partial chunk - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append text message - if 'text' in message: - minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value) + if "text" in message: + minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value) yield minimax_message diff --git a/api/core/model_runtime/model_providers/minimax/llm/errors.py b/api/core/model_runtime/model_providers/minimax/llm/errors.py index d9d279e6ca..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/errors.py +++ b/api/core/model_runtime/model_providers/minimax/llm/errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index 1fab20ebbc..76ed704a75 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -34,18 +34,25 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { - 'abab6.5s-chat': MinimaxChatCompletionPro, - 'abab6.5-chat': MinimaxChatCompletionPro, - 'abab6-chat': MinimaxChatCompletionPro, - 'abab5.5s-chat': MinimaxChatCompletionPro, - 'abab5.5-chat': MinimaxChatCompletionPro, - 'abab5-chat': MinimaxChatCompletion + "abab6.5s-chat": MinimaxChatCompletionPro, + "abab6.5-chat": MinimaxChatCompletionPro, + "abab6-chat": MinimaxChatCompletionPro, + "abab5.5s-chat": MinimaxChatCompletionPro, + "abab5.5-chat": MinimaxChatCompletionPro, + "abab5-chat": MinimaxChatCompletion, } - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) 
def validate_credentials(self, model: str, credentials: dict) -> None: @@ -53,82 +60,97 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): Validate credentials for Minimax model """ if model not in self.model_apis: - raise CredentialsValidateFailedError(f'Invalid model: {model}') + raise CredentialsValidateFailedError(f"Invalid model: {model}") - if not credentials.get('minimax_api_key'): - raise CredentialsValidateFailedError('Invalid API key') + if not credentials.get("minimax_api_key"): + raise CredentialsValidateFailedError("Invalid API key") + + if not credentials.get("minimax_group_id"): + raise CredentialsValidateFailedError("Invalid group ID") - if not credentials.get('minimax_group_id'): - raise CredentialsValidateFailedError('Invalid group ID') - # ping instance = MinimaxChatCompletionPro() try: instance.generate( - model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'], - prompt_messages=[ - MinimaxMessage(content='ping', role='USER') - ], + model=model, + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[MinimaxMessage(content="ping", role="USER")], model_parameters={}, - tools=[], stop=[], + tools=[], + stop=[], stream=False, - user='' + user="", ) except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for minimax model + Calculate num tokens for minimax model - not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way - to caculate the num tokens, so we use str() to convert the prompt to string + unlike ChatGLM, Minimax has a special prompt structure; we could not find a proper way + to calculate the num tokens, so we use str() to convert the prompt to a string - Minimax does not provide their own tokenizer of adab5.5 and abab5 model - therefore, we use gpt2 tokenizer instead + Minimax does not provide its own tokenizer for the abab5.5 and abab5 models, + therefore we use the gpt2 tokenizer instead """ messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages] return self._get_num_tokens_by_gpt2(str(messages_dict)) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface + use MinimaxChatCompletionPro as the client type; MinimaxChatCompletion has the same interface anyway """ client: MinimaxChatCompletionPro =
self.model_apis[model]() if tools: - tools = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + tools = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] response = client.generate( model=model, - api_key=credentials['minimax_api_key'], - group_id=credentials['minimax_group_id'], + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages], model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage: """ - convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface + convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface """ if isinstance(prompt_message, SystemPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content) @@ -136,26 +158,27 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.function_call={ - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.function_call = { + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } return message return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) elif isinstance(prompt_message, ToolPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + 
completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -166,31 +189,33 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[MinimaxMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[MinimaxMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), ) elif message.function_call: - if 'name' not in message.function_call or 'arguments' not in message.function_call: + if "name" not in message.function_call or "arguments" not in message.function_call: continue yield LLMResultChunk( @@ -199,15 +224,16 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content='', - tool_calls=[AssistantPromptMessage.ToolCall( - id='', - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.function_call['name'], - arguments=message.function_call['arguments'] + content="", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call["name"], arguments=message.function_call["arguments"] + ), ) - )] + ], ), ), ) @@ -217,10 +243,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) @@ -236,22 +259,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index b33a7ca9ac..88ebe5e2e0 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -4,32 
+4,27 @@ from typing import Any class MinimaxMessage: class Role(Enum): - USER = 'USER' - ASSISTANT = 'BOT' - SYSTEM = 'SYSTEM' - FUNCTION = 'FUNCTION' + USER = "USER" + ASSISTANT = "BOT" + SYSTEM = "SYSTEM" + FUNCTION = "FUNCTION" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" function_call: dict[str, Any] = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: - return { - 'sender_type': 'BOT', - 'sender_name': '专家', - 'text': '', - 'function_call': self.function_call - } - + return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call} + return { - 'sender_type': self.role, - 'sender_name': '我' if self.role == 'USER' else '专家', - 'text': self.content, + "sender_type": self.role, + "sender_name": "我" if self.role == "USER" else "专家", + "text": self.content, } - - def __init__(self, content: str, role: str = 'USER') -> None: + + def __init__(self, content: str, role: str = "USER") -> None: self.content = content - self.role = role \ No newline at end of file + self.role = role diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py index 52f6c2f1d3..5a761903a1 100644 --- a/api/core/model_runtime/model_providers/minimax/minimax.py +++ b/api/core/model_runtime/model_providers/minimax/minimax.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class MinimaxProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class MinimaxProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `abab5.5-chat` model for validate, - model_instance.validate_credentials( - model='abab5.5-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise CredentialsValidateFailedError(f'{ex}') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise CredentialsValidateFailedError(f"{ex}") diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 85dc6ef51d..02a53708be 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -30,11 +30,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ Model class for Minimax text embedding model. 
""" - api_base: str = 'https://api.minimax.chat/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.minimax.chat/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -44,54 +45,43 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['minimax_api_key'] - group_id = credentials['minimax_group_id'] - if model != 'embo-01': - raise ValueError('Invalid model name') + api_key = credentials["minimax_api_key"] + group_id = credentials["minimax_group_id"] + if model != "embo-01": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - url = f'{self.api_base}?GroupId={group_id}' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("api_key is required") + url = f"{self.api_base}?GroupId={group_id}" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'embo-01', - 'texts': texts, - 'type': 'db' - } + data = {"model": "embo-01", "texts": texts, "type": "db"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json() # check if there is an error - if resp['base_resp']['status_code'] != 0: - code = resp['base_resp']['status_code'] - msg = resp['base_resp']['status_msg'] + if resp["base_resp"]["status_code"] != 0: + code = resp["base_resp"]["status_code"] + msg = resp["base_resp"]["status_msg"] self._handle_error(code, msg) - embeddings = resp['vectors'] - total_tokens = resp['total_tokens'] + embeddings = resp["vectors"] + total_tokens = resp["total_tokens"] except InvalidAuthenticationError: - raise InvalidAPIKeyError('Invalid api key') + raise InvalidAPIKeyError("Invalid api key") except KeyError as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -119,9 +109,9 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001: @@ -148,25 +138,17 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, 
InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -178,10 +160,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +171,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py index 01ed8010de..da60bd7661 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/llm.py +++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py @@ -7,14 +7,19 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - + # mistral does not support user/stop arguments stop = [] user = None @@ -27,5 +32,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.mistral.ai/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.mistral.ai/v1" diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py index f1d825f6c6..7f9db8da1c 100644 --- a/api/core/model_runtime/model_providers/mistralai/mistralai.py +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='open-mistral-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py
b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index c233596637..3ea46c2967 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,50 +55,50 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 4096)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + credentials["mode"] = "chat" + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["endpoint_url"] = "https://api.moonshot.cn/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + 
model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ @@ -107,19 +113,13 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -129,14 +129,16 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -162,21 +164,26 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -186,11 +193,12 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> 
LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -201,12 +209,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -220,9 +223,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -244,9 +247,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -255,21 +258,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." 
+ finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -277,19 +280,18 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -305,26 +307,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py index 5654ae1459..4995e235f5 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MoonshotProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MoonshotProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='moonshot-v1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/novita/llm/llm.py 
b/api/core/model_runtime/model_providers/novita/llm/llm.py index c7b223d1b7..23367ed1b4 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llm.py +++ b/api/core/model_runtime/model_providers/novita/llm/llm.py @@ -8,19 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.novita.ai/v3/openai" - credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' } + credentials["endpoint_url"] = "https://api.novita.ai/v3/openai" + credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"} return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + def validate_credentials(self, model: str, credentials: dict) -> None: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) self._add_custom_parameters(credentials, model) @@ -28,21 +34,36 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' + credentials["mode"] = "chat" - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_customizable_model_schema(model, cred_with_endpoint) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) diff --git 
a/api/core/model_runtime/model_providers/novita/novita.py b/api/core/model_runtime/model_providers/novita/novita.py index f1b7224605..76a75b01e2 100644 --- a/api/core/model_runtime/model_providers/novita/novita.py +++ b/api/core/model_runtime/model_providers/novita/novita.py @@ -20,12 +20,9 @@ class NovitaProvider(ModelProvider): # Use `meta-llama/llama-3-8b-instruct` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials=credentials - ) + model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index bc42eaca65..4d3747dc84 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -21,31 +21,36 @@ from core.model_runtime.utils import helper class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_SUFFIX_MAP = { - 'fuyu-8b': 'vlm/adept/fuyu-8b', - 'mistralai/mistral-large': '', - 'mistralai/mixtral-8x7b-instruct-v0.1': '', - 'mistralai/mixtral-8x22b-instruct-v0.1': '', - 'google/gemma-7b': '', - 'google/codegemma-7b': '', - 'snowflake/arctic':'', - 'meta/llama2-70b': '', - 'meta/llama3-8b-instruct': '', - 'meta/llama3-70b-instruct': '', - 'meta/llama-3.1-8b-instruct': '', - 'meta/llama-3.1-70b-instruct': '', - 'meta/llama-3.1-405b-instruct': '', - 'google/recurrentgemma-2b': '', - 'nvidia/nemotron-4-340b-instruct': '', - 'microsoft/phi-3-medium-128k-instruct':'', - 'microsoft/phi-3-mini-128k-instruct':'' + "fuyu-8b": "vlm/adept/fuyu-8b", + "mistralai/mistral-large": "", + "mistralai/mixtral-8x7b-instruct-v0.1": "", + "mistralai/mixtral-8x22b-instruct-v0.1": "", + "google/gemma-7b": "", + "google/codegemma-7b": "", + "snowflake/arctic": "", + "meta/llama2-70b": "", + "meta/llama3-8b-instruct": "", + "meta/llama3-70b-instruct": "", + "meta/llama-3.1-8b-instruct": "", + "meta/llama-3.1-70b-instruct": "", + "meta/llama-3.1-405b-instruct": "", + "google/recurrentgemma-2b": "", + "nvidia/nemotron-4-340b-instruct": "", + "microsoft/phi-3-medium-128k-instruct": "", + "microsoft/phi-3-mini-128k-instruct": "", } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) prompt_messages = self._transform_prompt_messages(prompt_messages) stop = [] @@ -60,16 +65,14 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): for i, p in enumerate(prompt_messages): if isinstance(p, UserPromptMessage) and isinstance(p.content, list): content = p.content - content_text = '' + 
content_text = "" for prompt_content in content: if prompt_content.type == PromptMessageContentType.TEXT: content_text += prompt_content.data else: content_text += f' ' - prompt_message = UserPromptMessage( - content=content_text - ) + prompt_message = UserPromptMessage(content=content_text) prompt_messages[i] = prompt_message return prompt_messages @@ -78,15 +81,15 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): self._validate_credentials(model, credentials) def _add_custom_parameters(self, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - - if self.MODEL_SUFFIX_MAP[model]: - credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}' - credentials.pop('endpoint_url') - else: - credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1' + credentials["mode"] = "chat" - credentials['stream_mode_delimiter'] = '\n' + if self.MODEL_SUFFIX_MAP[model]: + credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}" + credentials.pop("endpoint_url") + else: + credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1" + + credentials["stream_mode_delimiter"] = "\n" def _validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -97,72 +100,67 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation 
failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -176,57 +174,51 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: - headers['Authorization'] = f'Bearer {api_key}' + headers["Authorization"] = f"Bearer {api_key}" if stream: - headers['Accept'] = 'text/event-stream' + headers["Accept"] = "text/event-stream" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") - # annotate tools with names, descriptions, etc. 
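
A note on the endpoint selection above: NVIDIA exposes two API bases, and `_add_custom_parameters` switches between them by looking the model up in `MODEL_SUFFIX_MAP`. A minimal sketch of that rule, assuming the map entries and URLs exactly as in the hunk (`resolve_nvidia_endpoint` is an illustrative name, not part of the patch):

```python
# Sketch of the routing in NVIDIALargeLanguageModel._add_custom_parameters.
# Only two illustrative entries are shown; the full map is in the hunk above.
MODEL_SUFFIX_MAP = {
    "fuyu-8b": "vlm/adept/fuyu-8b",  # non-empty suffix -> dedicated server URL
    "meta/llama3-8b-instruct": "",   # empty suffix -> shared OpenAI-compatible endpoint
}

def resolve_nvidia_endpoint(credentials: dict, model: str) -> dict:
    credentials["mode"] = "chat"
    if MODEL_SUFFIX_MAP[model]:  # direct lookup, as in the patch (unknown models raise KeyError)
        credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{MODEL_SUFFIX_MAP[model]}"
        credentials.pop("endpoint_url", None)  # the hunk pops without a default
    else:
        credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1"
    credentials["stream_mode_delimiter"] = "\n"
    return credentials
```

Downstream, `_validate_credentials` and `_generate` then branch on whichever of `endpoint_url` or `server_url` survives in the credentials dict.
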
- function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -240,16 +232,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if not response.ok: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.py b/api/core/model_runtime/model_providers/nvidia/nvidia.py index e83f8badb5..058fa00346 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.py +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='mistralai/mixtral-8x7b-instruct-v0.1', - credentials=credentials - ) + model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py index 9d33f55bc2..fabebc67ab 100644 --- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py @@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel): """ def _sigmoid(self, logit: float) -> float: - return 1/(1+exp(-logit)) + return 1 / (1 + exp(-logit)) - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -54,16 +61,15 @@ class NvidiaRerankModel(RerankModel): "query": {"text": query}, "passages": [{"text": doc} for doc in docs], } - session = requests.Session() response = session.post(invoke_url, 
headers=headers, json=payload) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['rankings']: - index = result['index'] - logit = result['logit'] + for result in results["rankings"]: + index = result["index"] + logit = result["logit"] rerank_document = RerankDocument( index=index, text=docs[index], @@ -71,7 +77,10 @@ class NvidiaRerankModel(RerankModel): ) rerank_documents.append(rerank_document) - + if rerank_documents: + rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True) + if top_n: + rerank_documents = rerank_documents[:top_n] return RerankResult(model=model, docs=rerank_documents) except requests.HTTPError as e: raise InvokeServerUnavailableError(str(e)) @@ -108,5 +117,5 @@ class NvidiaRerankModel(RerankModel): InvokeServerUnavailableError: [requests.HTTPError], InvokeRateLimitError: [], InvokeAuthorizationError: [requests.HTTPError], - InvokeBadRequestError: [requests.RequestException] + InvokeBadRequestError: [requests.RequestException], } diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index a2adef400d..00cec265d5 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -22,12 +22,13 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Nvidia text embedding model. """ - api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings' - models: list[str] = ['NV-Embed-QA'] - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings" + models: list[str] = ["NV-Embed-QA"] + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,32 +38,25 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if model not in self.models: - raise InvokeBadRequestError('Invalid model name') + raise InvokeBadRequestError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': model, - 'input': texts[0], - 'input_type': 'query' - } + data = {"model": model, "input": texts[0], "input_type": "query"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -72,23 +66,21 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: 
{response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -117,30 +109,20 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -152,10 +134,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -166,7 +145,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py index f7b849fbe2..6ff380bdd9 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py @@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel): """ Model class for NVIDIA NIM large language model. 
""" + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py index 25ab3e8e20..ad890ada22 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class NVIDIANIMProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/oci/__init__.py b/api/core/model_runtime/model_providers/oci/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg new file mode 100644 index 0000000000..0981dfcff2 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg new file mode 100644 index 0000000000..0981dfcff2 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml new file mode 100644 index 0000000000..eb60cbcd90 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml @@ -0,0 +1,52 @@ +model: cohere.command-r-16k +label: + en_US: cohere.command-r-16k v1.2 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + max: 1.0 + - name: topP + use_template: top_p + default: 0.75 + min: 0 + max: 1 + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 0 + min: 0 + max: 500 + - name: presencePenalty + use_template: presence_penalty + min: 0 + max: 1 + default: 0 + - name: frequencyPenalty + use_template: frequency_penalty + min: 0 + max: 1 + default: 0 + - name: maxTokens + use_template: max_tokens + default: 600 + max: 4000 +pricing: + input: '0.004' + output: '0.004' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml new file mode 100644 index 0000000000..df31b0d0df --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml @@ -0,0 +1,52 @@ +model: cohere.command-r-plus +label: + en_US: cohere.command-r-plus v1.2 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + max: 1.0 + - name: topP + use_template: top_p + default: 0.75 + min: 0 + max: 1 + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + default: 0 + min: 0 + max: 500 + - name: presencePenalty + use_template: presence_penalty + min: 0 + max: 1 + default: 0 + - name: frequencyPenalty + use_template: frequency_penalty + min: 0 + max: 1 + default: 0 + - name: maxTokens + use_template: max_tokens + default: 600 + max: 4000 +pricing: + input: '0.0219' + output: '0.0219' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py new file mode 100644 index 0000000000..ad5197a154 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -0,0 +1,470 @@ +import base64 +import copy +import json +import logging +from collections.abc import Generator +from typing import Optional, Union + +import oci +from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + +request_template = { + "compartmentId": "", + "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"}, + "chatRequest": { + "apiFormat": "COHERE", + # "preambleOverride": "You are a helpful assistant.", + # "message": "Hello!", + # "chatHistory": [], + "maxTokens": 600, + "isStream": False, + "frequencyPenalty": 0, + "presencePenalty": 0, + "temperature": 1, + "topP": 0.75, + }, +} +oci_config_template = { + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + +class OCILargeLanguageModel(LargeLanguageModel): + # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm + _supported_models = { + "meta.llama-3-70b-instruct": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "cohere.command-r-16k": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + "cohere.command-r-plus": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + } + + def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["stream_tool_call"] if stream else feature["tool_call"] + + def _is_multimodal_supported(self, model_id: str) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["multimodal"] + + def _is_system_prompt_supported(self, model_id: str) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["system"] + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: 
Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # print("model"+"*"*20) + # print(model) + # print("credentials"+"*"*20) + # print(credentials) + # print("model_parameters"+"*"*20) + # print(model_parameters) + # print("prompt_messages"+"*"*200) + # print(prompt_messages) + # print("tools"+"*"*20) + # print(tools) + + # invoke model + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return:md = genai.GenerativeModel(model) + """ + prompt = self._convert_messages_to_prompt(prompt_messages) + + return self._get_num_tokens_by_gpt2(prompt) + + def get_num_characters( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return:md = genai.GenerativeModel(model) + """ + prompt = self._convert_messages_to_prompt(prompt_messages) + + return len(prompt) + + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: + """ + :param messages: List of PromptMessage to combine. + :return: Combined string with necessary human_prompt and ai_prompt tags. 
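
It is worth spelling out how the two counting helpers above behave: `get_num_tokens` flattens the conversation to a single tagged transcript and approximates with the GPT-2 tokenizer, while `get_num_characters` measures that transcript's length and is what the usage accounting later in this file actually uses. A sketch of the flattening convention, taken from `_convert_one_message_to_text` further down (`flatten_messages` is an illustrative name working on plain (role, content) pairs rather than PromptMessage objects):

```python
def flatten_messages(messages: list[tuple[str, str]]) -> str:
    # (role, content) pairs -> "\n\nuser: ..." / "\n\nmodel: ..." transcript.
    # System and tool messages share the user tag, as in the patch.
    tags = {"user": "\n\nuser:", "assistant": "\n\nmodel:",
            "system": "\n\nuser:", "tool": "\n\nuser:"}
    return "".join(f"{tags[role]} {content}" for role, content in messages).rstrip()
```
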
+ """ + messages = messages.copy() # don't mutate the original list + + text = "".join(self._convert_one_message_to_text(message) for message in messages) + + return text.rstrip() + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + # Setup basic variables + # Auth Config + try: + ping_message = SystemPromptMessage(content="ping") + self._generate(model, credentials, [ping_message], {"maxTokens": 5}) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: credentials kwargs + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # config_kwargs = model_parameters.copy() + # config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + # if stop: + # config_kwargs["stop_sequences"] = stop + + # initialize client + # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat + oci_config = copy.deepcopy(oci_config_template) + if "oci_config_content" in credentials: + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") + config_items = oci_config_content.split("/") + if len(config_items) != 5: + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) + oci_config["user"] = config_items[0] + oci_config["fingerprint"] = config_items[1] + oci_config["tenancy"] = config_items[2] + oci_config["region"] = config_items[3] + oci_config["compartment_id"] = config_items[4] + else: + raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") + if "oci_key_content" in credentials: + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") + oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") + else: + raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") + + # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) + compartment_id = oci_config["compartment_id"] + client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) + # call embedding model + request_args = copy.deepcopy(request_template) + request_args["compartmentId"] = compartment_id + request_args["servingMode"]["modelId"] = model + + chat_history = [] + system_prompts = [] + # if "meta.llama" in model: + # request_args["chatRequest"]["apiFormat"] = "GENERIC" + request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600) + request_args["chatRequest"].update(model_parameters) + frequency_penalty = model_parameters.get("frequencyPenalty", 0) + presence_penalty = model_parameters.get("presencePenalty", 0) + if frequency_penalty > 0 and presence_penalty > 0: + raise InvokeBadRequestError("Cannot set both frequency 
penalty and presence penalty") + + # for msg in prompt_messages: # makes message roles strictly alternating + # content = self._format_message_to_glm_content(msg) + # if history and history[-1]["role"] == content["role"]: + # history[-1]["parts"].extend(content["parts"]) + # else: + # history.append(content) + + # temporary not implement the tool call function + valid_value = self._is_tool_call_supported(model, stream) + if tools is not None and len(tools) > 0: + if not valid_value: + raise InvokeBadRequestError("Does not support function calling") + if model.startswith("cohere"): + # print("run cohere " * 10) + for message in prompt_messages[:-1]: + text = "" + if isinstance(message.content, str): + text = message.content + if isinstance(message, UserPromptMessage): + chat_history.append({"role": "USER", "message": text}) + else: + chat_history.append({"role": "CHATBOT", "message": text}) + if isinstance(message, SystemPromptMessage): + if isinstance(message.content, str): + system_prompts.append(message.content) + args = { + "apiFormat": "COHERE", + "preambleOverride": " ".join(system_prompts), + "message": prompt_messages[-1].content, + "chatHistory": chat_history, + } + request_args["chatRequest"].update(args) + elif model.startswith("meta"): + # print("run meta " * 10) + meta_messages = [] + for message in prompt_messages: + text = message.content + meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]}) + args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1} + request_args["chatRequest"].update(args) + + if stream: + request_args["chatRequest"]["isStream"] = True + # print("final request" + "|" * 20) + # print(request_args) + response = client.chat(request_args) + # print(vars(response)) + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: + """ + Handle llm response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response + """ + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text) + + # calculate num tokens + prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) + completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + ) + + return result + + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + index = -1 + events = response.data.events() + for stream in events: + chunk = json.loads(stream.data) + # print(chunk) + # chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} + + # for chunk in 
response: + # for part in chunk.parts: + # if part.function_call: + # assistant_prompt_message.tool_calls = [ + # AssistantPromptMessage.ToolCall( + # id=part.function_call.name, + # type='function', + # function=AssistantPromptMessage.ToolCall.ToolCallFunction( + # name=part.function_call.name, + # arguments=json.dumps(dict(part.function_call.args.items())) + # ) + # ) + # ] + + if "finishReason" not in chunk: + assistant_prompt_message = AssistantPromptMessage(content="") + if model.startswith("cohere"): + if chunk["text"]: + assistant_prompt_message.content += chunk["text"] + elif model.startswith("meta"): + assistant_prompt_message.content += chunk["message"]["content"][0]["text"] + index += 1 + # transform assistant message to prompt message + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), + ) + else: + # calculate num tokens + prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) + completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + finish_reason=str(chunk["finishReason"]), + usage=usage, + ), + ) + + def _convert_one_message_to_text(self, message: PromptMessage) -> str: + """ + Convert a single message to a string. + + :param message: PromptMessage to convert. + :return: String representation of the message. + """ + human_prompt = "\n\nuser:" + ai_prompt = "\n\nmodel:" + + content = message.content + if isinstance(content, list): + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) + + if isinstance(message, UserPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, AssistantPromptMessage): + message_text = f"{ai_prompt} {content}" + elif isinstance(message, SystemPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, ToolPromptMessage): + message_text = f"{human_prompt} {content}" + else: + raise ValueError(f"Got unknown type {message}") + + return message_text + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
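
To make the Cohere branch of `_generate` above concrete: everything but the final message becomes `chatHistory`, system contents are joined into `preambleOverride`, and the last message is sent as `message`. A simplified sketch under those assumptions (`build_cohere_chat_request` is an illustrative name; the patch works with PromptMessage objects rather than dicts):

```python
def build_cohere_chat_request(messages: list[dict]) -> dict:
    chat_history, system_prompts = [], []
    for m in messages[:-1]:
        # User messages map to USER; everything else, including system
        # messages, is recorded as CHATBOT, mirroring the hunk.
        role = "USER" if m["role"] == "user" else "CHATBOT"
        chat_history.append({"role": role, "message": m["content"]})
        if m["role"] == "system":
            system_prompts.append(m["content"])
    return {
        "apiFormat": "COHERE",
        "preambleOverride": " ".join(system_prompts),
        "message": messages[-1]["content"],
        "chatHistory": chat_history,
    }
```

For `meta.*` models the same method instead emits `apiFormat: "GENERIC"` with a `messages` list whose content parts are `{"type": "TEXT", "text": ...}` dicts.
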
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [], + } diff --git a/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml new file mode 100644 index 0000000000..dd5be107c0 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml @@ -0,0 +1,51 @@ +model: meta.llama-3-70b-instruct +label: + zh_Hans: meta.llama-3-70b-instruct + en_US: meta.llama-3-70b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + max: 2.0 + - name: topP + use_template: top_p + default: 0.75 + min: 0 + max: 1 + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 0 + min: 0 + max: 500 + - name: presencePenalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 + - name: frequencyPenalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: maxTokens + use_template: max_tokens + default: 600 + max: 8000 +pricing: + input: '0.015' + output: '0.015' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/oci.py b/api/core/model_runtime/model_providers/oci/oci.py new file mode 100644 index 0000000000..e182d2d043 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/oci.py @@ -0,0 +1,28 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class OCIGENAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # Use `cohere.command-r-plus` model for validate, + model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/oci/oci.yaml b/api/core/model_runtime/model_providers/oci/oci.yaml new file mode 100644 index 0000000000..f2f23e18f1 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/oci.yaml @@ -0,0 +1,42 @@ +provider: oci +label: + en_US: OCIGenerativeAI +description: + en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+. 
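
Since both the OCI LLM and embedding classes parse credentials the same way, it may help to show how a user would produce the two fields the provider schema asks for (`oci_config_content` and `oci_key_content`). A sketch with placeholder OCIDs and file names; none of these values are real:

```python
import base64

# oci_config_content: base64 of "user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid"
parts = [
    "ocid1.user.oc1..example",         # user OCID (placeholder)
    "11:22:33:44:55:66",               # API key fingerprint (placeholder)
    "ocid1.tenancy.oc1..example",      # tenancy OCID (placeholder)
    "us-ashburn-1",                    # region
    "ocid1.compartment.oc1..example",  # compartment OCID (placeholder)
]
oci_config_content = base64.b64encode("/".join(parts).encode("utf-8")).decode()

# oci_key_content: base64 of the PEM private key file's content
with open("oci_api_key.pem", "rb") as f:  # placeholder path
    oci_key_content = base64.b64encode(f.read()).decode()
```

The provider splits the decoded config on `/` and rejects anything that does not yield exactly five items, so OCIDs, fingerprints, and region names (none of which contain `/`) are safe in this scheme.
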
+ zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FFFFFF" +help: + title: + en_US: Get your API Key from OCI + zh_Hans: 从 OCI 获取 API Key + url: + en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm +supported_model_types: + - llm + - text-embedding + #- rerank +configurate_methods: + - predefined-model + #- customizable-model +provider_credential_schema: + credential_form_schemas: + - variable: oci_config_content + label: + en_US: oci api key config file's content + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) ) + en_US: Enter your oci api key config file's content(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) ) + - variable: oci_key_content + label: + en_US: oci api key file's content + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8'))) + en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8'))) diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/__init__.py b/api/core/model_runtime/model_providers/oci/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml new file mode 100644 index 0000000000..149f1e3797 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml @@ -0,0 +1,5 @@ +- cohere.embed-english-light-v2.0 +- cohere.embed-english-light-v3.0 +- cohere.embed-english-v3.0 +- cohere.embed-multilingual-light-v3.0 +- cohere.embed-multilingual-v3.0 diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml new file mode 100644 index 0000000000..259d5b45b7 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-english-light-v2.0 +model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml new file mode 100644 index 0000000000..065e7474c0 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-english-light-v3.0 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml new file mode 100644 index 0000000000..3e2deea16a --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-english-v3.0 +model_type: text-embedding +model_properties: + context_size: 
1024 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml new file mode 100644 index 0000000000..0d2b892c64 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-multilingual-light-v3.0 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml new file mode 100644 index 0000000000..9ebe260b32 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-multilingual-v3.0 +model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py new file mode 100644 index 0000000000..df77db47d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -0,0 +1,216 @@ +import base64 +import copy +import time +from typing import Optional + +import numpy as np +import oci + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +request_template = { + "compartmentId": "", + "servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"}, + "truncate": "NONE", + "inputs": [""], +} +oci_config_template = { + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + +class OCITextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Cohere text embedding model. 
+ """ + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + # get model properties + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + # call embedding model + embeddings_batch, embedding_used_tokens = self._embedding_invoke( + model=model, credentials=credentials, texts=inputs[i : i + max_chunks] + ) + + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + # calc usage + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + characters = 0 + for text in texts: + characters += len(text) + return characters + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # call embedding model + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]: + """ + Invoke embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: embeddings and used tokens + """ + + # oci + # initialize client + oci_config = copy.deepcopy(oci_config_template) + if "oci_config_content" in credentials: + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") + config_items = oci_config_content.split("/") + if len(config_items) != 5: + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) + oci_config["user"] = config_items[0] + oci_config["fingerprint"] = config_items[1] + oci_config["tenancy"] = config_items[2] + oci_config["region"] = config_items[3] + oci_config["compartment_id"] = config_items[4] + else: + raise 
CredentialsValidateFailedError("need to set oci_config_content in credentials ") + if "oci_key_content" in credentials: + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") + oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") + else: + raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") + # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) + compartment_id = oci_config["compartment_id"] + client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) + # call embedding model + request_args = copy.deepcopy(request_template) + request_args["compartmentId"] = compartment_id + request_args["servingMode"]["modelId"] = model + request_args["inputs"] = texts + response = client.embed_text(request_args) + return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts) + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ :return: Invoke error mapping + """ + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], + } diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 42a588e3dd..160eea0148 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -121,9 +121,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -145,13 +143,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {ex.description}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {str(ex)}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def _generate( self, @@ -201,9 +195,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if completion_type is LLMMode.CHAT: endpoint_url = urljoin(endpoint_url, "api/chat") - data["messages"] = [ - self._convert_prompt_message_to_dict(m) for m in prompt_messages - ] + data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] else: endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] @@ -216,14 +208,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel): images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) + message_content = cast(ImagePromptMessageContent, message_content) image_data = re.sub( r"^data:image\/[a-zA-Z]+;base64,", "", @@ -235,24 +223,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): data["images"] = images # send a post request to validate the credentials - response = requests.post( - endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError( - f"API request failed with status code {response.status_code}: {response.text}" - ) + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") if stream: - return self._handle_generate_stream_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) - return 
self._handle_generate_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) def _handle_generate_response( self, @@ -292,9 +272,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) # transform response result = LLMResult( @@ -335,9 +313,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( model=model, @@ -394,15 +370,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = chunk_json["eval_count"] else: # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2( - prompt_messages[0].content - ) + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( model=chunk_json["model"], @@ -439,17 +411,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) - image_data = re.sub( - r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -479,9 +445,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_customizable_model_schema( - self, model: str, credentials: dict - ) -> AIModelEntity: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ Get customizable model schema. @@ -502,9 +466,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get("context_size", 4096) - ), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), }, parameter_rules=[ ParameterRule( @@ -568,9 +530,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): en_US="Maximum number of tokens to predict when generating text. 
" "(Default: 128, -1 = infinite generation, -2 = fill context)" ), - default=( - 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128 - ), + default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128), min=-2, max=int(credentials.get("max_tokens", 4096)), ), @@ -612,22 +572,23 @@ class OllamaLargeLanguageModel(LargeLanguageModel): label=I18nObject(en_US="Size of context window"), type=ParameterType.INT, help=I18nObject( - en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)" + en_US="Sets the size of the context window used to generate the next token. " "(Default: 2048)" ), default=2048, min=1, ), ParameterRule( - name='num_gpu', + name="num_gpu", label=I18nObject(en_US="GPU Layers"), type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to offload to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable." - "As long as a model fits into one gpu it stays in one. " - "It does not set the number of GPU(s). "), + help=I18nObject( + en_US="The number of layers to offload to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." + "As long as a model fits into one gpu it stays in one. " + "It does not set the number of GPU(s). " + ), min=-1, - default=1 + default=1, ), ParameterRule( name="num_thread", @@ -688,8 +649,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): label=I18nObject(en_US="Format"), type=ParameterType.STRING, help=I18nObject( - en_US="the format to return a response in." - " Currently the only accepted value is json." + en_US="the format to return a response in." " Currently the only accepted value is json." ), options=["json"], ), diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py index f8a17b98a0..115280193a 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.py +++ b/api/core/model_runtime/model_providers/ollama/ollama.py @@ -6,7 +6,6 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 9e26d35afc..2cfb79b241 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -37,9 +37,9 @@ class OllamaEmbeddingModel(TextEmbeddingModel): Model class for an Ollama text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,15 +51,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials.get('base_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("base_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'api/embed') + endpoint_url = urljoin(endpoint_url, "api/embed") # get model properties context_size = self._get_context_size(model, credentials) @@ -74,45 +72,29 @@ class OllamaEmbeddingModel(TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) # Prepare the payload for the request - payload = { - 'input': inputs, - 'model': model, - } + payload = {"input": inputs, "model": model, "options": {"use_mmap": True}} - # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + # Make the request to the Ollama API + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings = response_data['embeddings'] + embeddings = response_data["embeddings"] embedding_used_tokens = self.get_num_tokens(model, credentials, inputs) used_tokens += embedding_used_tokens # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -134,19 +116,15 @@ class OllamaEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -154,15 +132,15 @@ class OllamaEmbeddingModel(TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, 
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -178,10 +156,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +167,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -220,10 +195,10 @@ class OllamaEmbeddingModel(TextEmbeddingModel): ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 467a51daf2..2181bb4f08 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -22,7 +22,7 @@ class _CommonOpenAI: :return: """ credentials_kwargs = { - "api_key": credentials['openai_api_key'], + "api_key": credentials["openai_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } @@ -31,8 +31,8 @@ class _CommonOpenAI: openai_api_base = credentials["openai_api_base"].rstrip("/") credentials_kwargs["base_url"] = openai_api_base + "/v1" - if 'openai_organization' in credentials: - credentials_kwargs['organization'] = credentials['openai_organization'] + if "openai_organization" in credentials: + credentials_kwargs["organization"] = credentials["openai_organization"] return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 06135c9584..5950b77a96 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -39,16 +39,23 @@ if you are not sure about the structure. """ + class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ Model class for OpenAI large language model. 
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,8 +71,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -80,7 +87,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -91,26 +98,34 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) # transform response format - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -123,7 +138,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) else: self._transform_completion_json_prompts( @@ -135,9 +150,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -147,14 +162,21 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: 
dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -167,25 +189,35 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - - def _transform_completion_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + + def _transform_completion_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -202,25 +234,30 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): break if user_message: - if prompt_messages[i].content[-11:] == 'Assistant: ': + if prompt_messages[i].content[-11:] == "Assistant: ": # now we are in the chat app, remove the last assistant message prompt_messages[i].content = prompt_messages[i].content[:-11] prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"Assistant:\n```{response_format}\n" else: prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"\n```{response_format}\n" - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: 
Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -231,8 +268,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :return: """ # handle fine tune remote models - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] else: base_model = model @@ -262,14 +299,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # handle fine tune remote models base_model = model # fine-tuned model name likes ft:gpt-3.5-turbo-0613:personal::xxxxx - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # check if model exists remote_models = self.remote_models(credentials) remote_model_map = {model.model: model for model in remote_models} if model not in remote_model_map: - raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found') + raise CredentialsValidateFailedError(f"Fine-tuned model {model} not found") # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -277,7 +314,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if model_mode == LLMMode.CHAT: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -286,7 +323,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -313,11 +350,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # get all remote models remote_models = client.models.list() - fine_tune_models = [model for model in remote_models if model.id.startswith('ft:')] + fine_tune_models = [model for model in remote_models if model.id.startswith("ft:")] ai_model_entities = [] for model in fine_tune_models: - base_model = model.id.split(':')[1] + base_model = model.id.split(":")[1] base_model_schema = None for predefined_model_name, predefined_model in predefined_models_map.items(): @@ -329,30 +366,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ai_model_entity = AIModelEntity( model=model.id, - label=I18nObject( - zh_Hans=model.id, - en_US=model.id - ), + label=I18nObject(zh_Hans=model.id, en_US=model.id), model_type=ModelType.LLM, features=base_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=base_model_schema.model_properties, parameter_rules=base_model_schema.parameter_rules, - pricing=PriceConfig( - input=0.003, - output=0.006, - unit=0.001, - currency='USD' - ) + pricing=PriceConfig(input=0.003, output=0.006, unit=0.001, currency="USD"), ) ai_model_entities.append(ai_model_entity) return ai_model_entities - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, 
Generator]: """ Invoke llm completion model @@ -374,23 +410,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - "include_usage": True - } - + extra_model_kwargs["stream_options"] = {"include_usage": True} + # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -398,8 +428,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm completion response @@ -412,9 +443,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -440,8 +469,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm completion stream response @@ -451,7 +481,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" prompt_tokens = 0 completion_tokens = 0 @@ -460,8 +490,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -474,14 +504,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -494,7 +522,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -504,7 +532,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( 
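# --- Illustrative aside (editor's sketch, not part of the upstream diff) ---
# The stream handlers above accumulate delta text and read token usage from a
# trailing chunk, which OpenAI emits when stream_options={"include_usage": True}
# is set: that final chunk has no choices and carries only usage totals. A
# reduced model of the pattern, with hypothetical stand-in types:
from dataclasses import dataclass
from typing import Iterable, Optional

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int

@dataclass
class Chunk:
    text: str = ""
    usage: Optional[Usage] = None  # set only on the trailing usage-only chunk

def accumulate(stream: Iterable[Chunk]) -> tuple[str, Optional[Usage]]:
    full_text, usage = "", None
    for chunk in stream:
        if chunk.usage is not None:  # trailing chunk: totals, no delta text
            usage = chunk.usage
            continue
        full_text += chunk.text
    return full_text, usage

text, usage = accumulate([Chunk("Hel"), Chunk("lo"), Chunk(usage=Usage(3, 2))])
assert text == "Hello" and usage.completion_tokens == 2
# --- end aside ---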
index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -520,10 +548,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): yield final_chunk - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -552,7 +587,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): try: schema = json.loads(json_schema) except: - raise ValueError(f"not currect json_schema format: {json_schema}") + raise ValueError(f"not correct json_schema format: {json_schema}") model_parameters.pop("json_schema") model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: @@ -562,22 +597,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - 'include_usage': True - } + extra_model_kwargs["stream_options"] = {"include_usage": True} # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) @@ -596,9 +627,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -619,10 +655,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -648,9 +681,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> 
Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -660,7 +698,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None prompt_tokens = 0 completion_tokens = 0 @@ -670,8 +708,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -685,8 +723,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -708,7 +749,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -720,11 +761,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -735,7 +775,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -745,7 +785,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -753,8 +793,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -764,9 +803,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: 
list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -777,21 +816,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -801,14 +838,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -821,7 +855,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: cleaned prompt messages """ - checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] if model in checklist: # count how many user messages are there @@ -830,11 +864,16 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - item.data if item.type == PromptMessageContentType.TEXT else - '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' - for item in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) return prompt_messages @@ -851,19 +890,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -889,11 
+922,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -902,8 +931,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -924,13 +952,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model.startswith('ft:'): - model = model.split(':')[1] + if model.startswith("ft:"): + model = model.split(":")[1] # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. if model == "chatgpt-4o-latest": @@ -969,10 +998,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -1011,37 +1040,37 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += 
len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) @@ -1049,26 +1078,26 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - OpenAI supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + OpenAI supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ - if not model.startswith('ft:'): + if not model.startswith("ft:"): base_model = model else: # get base_model - base_model = model.split(':')[1] + base_model = model.split(":")[1] # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} if base_model not in model_map: - raise ValueError(f'Base model {base_model} not found') - + raise ValueError(f"Base model {base_model} not found") + base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] @@ -1077,16 +1106,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index b1d0e57ad2..619044d808 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -14,9 +14,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): Model class for OpenAI text moderation model. 
""" - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -34,10 +32,10 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): # chars per chunk length = self._get_max_characters_per_chunk(model, credentials) - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] max_text_chunks = self._get_max_chunks(model, credentials) - chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] + chunks = [text_chunks[i : i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] for text_chunk in chunks: moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk) @@ -65,7 +63,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): self._moderation_invoke( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index 66efd4797f..175d7db73c 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -9,7 +9,6 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: """ Validate provider credentials @@ -22,12 +21,9 @@ class OpenAIProvider(ModelProvider): # Use `gpt-3.5-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index efbdd054f9..18f97e45f3 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -12,9 +12,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. 
""" - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -37,7 +35,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index e23a2edf87..535d8388bc 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -18,9 +18,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): Model class for OpenAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +37,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" # get model properties context_size = self._get_context_size(model, credentials) @@ -56,11 +56,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -69,10 +67,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,10 +83,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -101,17 +93,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, 
model: str, credentials: dict, texts: list[str]) -> int: """ @@ -152,17 +136,13 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -179,10 +159,12 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens @@ -197,10 +179,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -211,7 +190,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index afa5d4b88a..bfb443698c 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -14,8 +14,9 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): Model class for OpenAI Speech to text model. 
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -28,14 +29,12 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) # if streaming: - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -50,14 +49,13 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -71,31 +69,38 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): yield from future.result().__enter__().iter_bytes(1024) else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, 
credentials: dict): """ _tts_invoke openai text2speech model api diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 51950ca377..1234e44f80 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,4 +1,3 @@ - import requests from core.model_runtime.errors.invoke import ( @@ -11,7 +10,7 @@ from core.model_runtime.errors.invoke import ( ) -class _CommonOAI_API_Compat: +class _CommonOaiApiCompat: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -35,10 +34,10 @@ class _CommonOAI_API_Compat: ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] - } \ No newline at end of file + requests.exceptions.ReadTimeout, # Timeout + ], + } diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 2b729d4293..24317b488c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -35,22 +35,28 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat from core.model_runtime.utils import helper logger = logging.getLogger(__name__) -class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): +class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): """ Model class for OpenAI large language model. 
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -77,8 +83,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -99,93 +110,85 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials['endpoint_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - endpoint_url = urljoin(endpoint_url, 'chat/completions') + endpoint_url = urljoin(endpoint_url, "chat/completions") elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - endpoint_url = urljoin(endpoint_url, 'completions') + data["prompt"] = "ping" + endpoint_url = urljoin(endpoint_url, "completions") else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if (completion_type is LLMMode.CHAT and json_result['object'] == ''): - json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): - json_result['object'] = 'text_completion' + if completion_type is LLMMode.CHAT and json_result.get("object", "") == "": + 
json_result["object"] = "chat.completion" + elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "": + json_result["object"] = "text_completion" - if (completion_type is LLMMode.CHAT - and ('object' not in json_result or json_result['object'] != 'chat.completion')): + if completion_type is LLMMode.CHAT and ( + "object" not in json_result or json_result["object"] != "chat.completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'chat.completion\'') - elif (completion_type is LLMMode.COMPLETION - and ('object' not in json_result or json_result['object'] != 'text_completion')): + "Credentials validation failed: invalid response object, must be 'chat.completion'" + ) + elif completion_type is LLMMode.COMPLETION and ( + "object" not in json_result or json_result["object"] != "text_completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'text_completion\'') + "Credentials validation failed: invalid response object, must be 'text_completion'" + ) except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ features = [] - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type in ['function_call']: + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type in ["function_call"]: features.append(ModelFeature.TOOL_CALL) - elif function_calling_type in ['tool_call']: + elif function_calling_type in ["tool_call"]: features.append(ModelFeature.MULTI_TOOL_CALL) - stream_function_calling = credentials.get('stream_function_calling', 'supported') - if stream_function_calling == 'supported': + stream_function_calling = credentials.get("stream_function_calling", "supported") + if stream_function_calling == "supported": features.append(ModelFeature.STREAM_TOOL_CALL) - vision_support = credentials.get('vision_support', 'not_support') - if vision_support == 'support': + vision_support = credentials.get("vision_support", "not_support") + if vision_support == "support": features.append(ModelFeature.VISION) entity = AIModelEntity( @@ -195,43 +198,43 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), - ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")), + ModelPropertyKey.MODE: credentials.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(credentials.get('temperature', 0.7)), + default=float(credentials.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(credentials.get('top_p', 1)), + default=float(credentials.get("top_p", 1)), min=0, 
max=1, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -239,20 +242,20 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(credentials.get('max_tokens_to_sample', 4096)), - ) + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), ) - if credentials['mode'] == 'chat': + if credentials["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials['mode'] == 'completion': + elif credentials["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") @@ -260,10 +263,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return entity # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. 
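Since the hunk above builds the entire customizable model schema from the credentials dict, here is a minimal sketch of a credentials payload that exercises it. All values and the model id are illustrative, and it assumes the class can be constructed directly; the key names are the ones read in the hunk.

from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
    OAIAPICompatLargeLanguageModel,
)

# Illustrative credentials; each key below is read by get_customizable_model_schema.
credentials = {
    "endpoint_url": "https://example.com/v1",  # hypothetical endpoint
    "mode": "chat",                            # becomes LLMMode.CHAT in model_properties
    "context_size": "8192",
    "max_tokens_to_sample": "4096",
    "function_calling_type": "tool_call",      # adds ModelFeature.MULTI_TOOL_CALL
    "stream_function_calling": "supported",    # adds ModelFeature.STREAM_TOOL_CALL
    "vision_support": "not_support",
    "input_price": "0.001",
    "output_price": "0.002",
    "unit": "0.000001",
    "currency": "USD",
}

entity = OAIAPICompatLargeLanguageModel().get_customizable_model_schema(
    model="my-model",  # hypothetical model id
    credentials=credentials,
)
print(entity.features)         # [MULTI_TOOL_CALL, STREAM_TOOL_CALL]
print(entity.parameter_rules)  # temperature, top_p, penalties, max_tokens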
- def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -277,52 +287,47 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - extra_headers = credentials.get('extra_headers') + extra_headers = credentials.get("extra_headers") if extra_headers is not None: headers = { - **headers, - **extra_headers, + **headers, + **extra_headers, } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + if not endpoint_url.endswith("/"): + endpoint_url += "/" - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'chat/completions') - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "chat/completions") + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, 'completions') - data['prompt'] = prompt_messages[0].content + endpoint_url = urljoin(endpoint_url, "completions") + data["prompt"] = prompt_messages[0].content else: raise ValueError("Unsupported completion type for model configuration.") # annotate tools with names, descriptions, etc. 
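The hunk below picks one of two wire formats depending on function_calling_type. As a rough sketch, these are the request fields each branch contributes; the tool definition is made up, and the tools-array wrapper shape is assumed from OpenAI's tools API, since the hunk truncates before formatted_tools is attached to the payload.

# A made-up tool, in the dict shape the hunk serializes tools into.
tool = {
    "name": "get_weather",
    "description": "Look up the weather for a city",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

# function_calling_type == "function_call": legacy OpenAI functions field.
data_function_call = {"functions": [tool]}

# function_calling_type == "tool_call": tools API, with automatic tool choice.
data_tool_call = {
    "tool_choice": "auto",
    "tools": [{"type": "function", "function": tool}],  # assumed wrapper shape
}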
- function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -336,16 +341,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if response.status_code != 200: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") @@ -355,8 +354,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -366,11 +366,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -381,16 +382,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) # delimiter for stream response, need unicode_escape import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, "unicode_escape") @@ -406,10 +403,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = AssistantPromptMessage.ToolCall( id=tool_call_id, type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name="", - arguments="" - ) + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), ) 
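# The blank ToolCall above is created the first time a streamed tool_call id
# appears; the append below registers it, and later deltas concatenate their
# partial function name/arguments onto it. When the stream ends, the
# accumulated tools_calls are flushed in one chunk (see the
# AssistantPromptMessage(tool_calls=tools_calls, content="") further down).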
tools_calls.append(tool_call) @@ -434,10 +428,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): chunk = chunk.strip() if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() - if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue try: @@ -447,30 +441,31 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") assistant_message_tool_calls = None - if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': - assistant_message_tool_calls = delta.get('tool_calls', None) - elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': - assistant_message_tool_calls = [{ - 'id': 'tool_call_id', - 'type': 'function', - 'function': delta.get('function_call', {}) - }] + if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call": + assistant_message_tool_calls = delta.get("tool_calls", None) + elif ( + "function_call" in delta + and credentials.get("function_calling_type", "no_call") == "function_call" + ): + assistant_message_tool_calls = [ + {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} + ] # assistant_message_function_call = delta.delta.function_call @@ -479,7 +474,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message @@ -490,9 +485,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # reset tool calls tool_calls = [] full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -507,7 +502,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 @@ -518,47 +513,42 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, 
content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason ) - def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> LLMResult: - + def _handle_generate_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> LLMResult: response_json = response.json() - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) - output = response_json['choices'][0] + output = response_json["choices"][0] - response_content = '' + response_content = "" tool_calls = None - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") if completion_type is LLMMode.CHAT: - response_content = output.get('message', {})['content'] - if function_calling_type == 'tool_call': - tool_calls = output.get('message', {}).get('tool_calls') - elif function_calling_type == 'function_call': - tool_calls = output.get('message', {}).get('function_call') + response_content = output.get("message", {})["content"] + if function_calling_type == "tool_call": + tool_calls = output.get("message", {}).get("tool_calls") + elif function_calling_type == "function_call": + tool_calls = output.get("message", {}).get("function_call") elif completion_type is LLMMode.COMPLETION: - response_content = output['text'] + response_content = output["text"] assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) if tool_calls: - if function_calling_type == 'tool_call': + if function_calling_type == "tool_call": assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == 'function_call': + elif function_calling_type == "function_call": assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] usage = response_json.get("usage") @@ -597,19 +587,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -618,11 +602,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] - elif function_calling_type == 'function_call': + function_calling_type = 
credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] + elif function_calling_type == "function_call": function_call = message.tool_calls[0] message_dict["function_call"] = { "name": function_call.function.name, @@ -633,19 +616,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } - elif function_calling_type == 'function_call': - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif function_calling_type == "function_call": + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -654,8 +629,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Approximate num tokens for model with gpt2 tokenizer. @@ -667,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if isinstance(text, str): full_text = text else: - full_text = '' + full_text = "" for message_content in text: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) @@ -680,8 +656,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: dict = None, + ) -> int: """ Approximate num tokens with GPT2 tokenizer. 
""" @@ -700,10 +681,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -741,46 +722,44 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2('type') - num_tokens += self._get_num_tokens_by_gpt2('function') - num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2("name") num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2("description") num_tokens += self._get_num_tokens_by_gpt2(tool.description) parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2('parameters') - if 'title' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2("type") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if 'properties' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): num_tokens += self._get_num_tokens_by_gpt2(key) for field_key, field_value in value.items(): num_tokens += self._get_num_tokens_by_gpt2(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(enum_field) else: num_tokens += self._get_num_tokens_by_gpt2(field_key) num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if 'required' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -792,20 +771,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.get("function", {}).get("name", ""), - 
arguments=response_tool_call.get("function", {}).get("arguments", "") + arguments=response_tool_call.get("function", {}).get("arguments", ""), ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get("id", ""), - type=response_tool_call.get("type", ""), - function=function + id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -815,14 +791,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.get('name', ''), - arguments=response_function_call.get('arguments', '') + name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.get('id', ''), - type="function", - function=function + id=response_function_call.get("id", ""), type="function", function=function ) return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 3445ebbaf7..ca6f185287 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class OAICompatProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 00702ba936..405096578c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -6,17 +6,15 @@ import requests from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): +class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel): """ Model class for OpenAI Compatible Speech to text model. 
""" - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 363054b084..e83cfdf873 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -19,17 +19,17 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,27 +39,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -70,7 +68,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -78,7 +75,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -88,42 +85,25 @@ class 
OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -145,45 +125,35 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -191,7 +161,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom 
model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -199,20 +169,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -224,10 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -238,7 +204,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 8ea5819bde..b560afca39 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -38,88 +38,115 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Baichuan model """ - if not credentials.get('server_url'): - raise CredentialsValidateFailedError('Invalid server URL') + if not credentials.get("server_url"): + raise CredentialsValidateFailedError("Invalid server URL") # ping instance = OpenLLMGenerate() try: instance.generate( - server_url=credentials['server_url'], - model_name=model, - prompt_messages=[ - OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user') - ], + server_url=credentials["server_url"], + model_name=model, + prompt_messages=[OpenLLMGenerateMessage(content="ping\nAnswer: ", role="user")], model_parameters={ - 'max_tokens': 64, - 'temperature': 0.8, - 'top_p': 0.9, - 'top_k': 15, + "max_tokens": 64, + "temperature": 0.8, + 
"top_p": 0.9, + "top_k": 15, }, stream=False, - user='', + user="", stop=[], ) except InvalidAuthenticationError as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for OpenLLM model - it's a generate model, so we just join them by spe + Calculate num tokens for OpenLLM model + it's a generate model, so we just join them by spe """ - messages = ','.join([message.content for message in messages]) + messages = ",".join([message.content for message in messages]) return self._get_num_tokens_by_gpt2(messages) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( model_name=model, - server_url=credentials['server_url'], + server_url=credentials["server_url"], prompt_messages=[self._convert_prompt_message_to_openllm_message(message) for message in prompt_messages], model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessage) -> OpenLLMGenerateMessage: """ - convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface + convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface """ if isinstance(prompt_message, UserPromptMessage): return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, 
prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -130,25 +157,27 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[OpenLLMGenerateMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), @@ -159,73 +188,55 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', + name="top_k", type=ParameterType.INT, - use_template='top_k', + use_template="top_k", min=1, default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + label=I18nObject(zh_Hans="Top K", en_US="Top K"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + 
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ + model_properties={ ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, - parameter_rules=rules + parameter_rules=rules, ) return entity @@ -241,22 +252,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 1c3f084207..e754479ec0 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -15,32 +15,38 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMGenerateMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, + self, + server_url: str, + model_name: str, + stream: bool, + model_parameters: dict[str, Any], + stop: list[str], + prompt_messages: list[OpenLLMGenerateMessage], + user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: - raise InvalidAuthenticationError('Invalid server URL') + raise InvalidAuthenticationError("Invalid server URL") default_llm_config = { "max_new_tokens": 128, @@ -72,40 +78,37 @@ class OpenLLMGenerate: "frequency_penalty": 0, "use_beam_search": False, "ignore_eos": False, - "skip_special_tokens": True + "skip_special_tokens": True, } - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - default_llm_config['max_new_tokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + default_llm_config["max_new_tokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - default_llm_config['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + default_llm_config["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and 
type(model_parameters['top_p']) == float: - default_llm_config['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + default_llm_config["top_p"] = model_parameters["top_p"] - if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: - default_llm_config['top_k'] = model_parameters['top_k'] + if "top_k" in model_parameters and type(model_parameters["top_k"]) == int: + default_llm_config["top_k"] = model_parameters["top_k"] - if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: - default_llm_config['use_cache'] = model_parameters['use_cache'] + if "use_cache" in model_parameters and type(model_parameters["use_cache"]) == bool: + default_llm_config["use_cache"] = model_parameters["use_cache"] - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + headers = {"Content-Type": "application/json", "accept": "application/json"} if stream: - url = f'{server_url}/v1/generate_stream' + url = f"{server_url}/v1/generate_stream" timeout = 10 else: - url = f'{server_url}/v1/generate' + url = f"{server_url}/v1/generate" timeout = 120 data = { - 'stop': stop if stop else [], - 'prompt': '\n'.join([message.content for message in prompt_messages]), - 'llm_config': default_llm_config, + "stop": stop if stop else [], + "prompt": "\n".join([message.content for message in prompt_messages]), + "llm_config": default_llm_config, } try: @@ -113,10 +116,10 @@ class OpenLLMGenerate: except (ConnectionError, InvalidSchema, MissingSchema) as e: # cloud not connect to the server raise InvalidAuthenticationError(f"Invalid server URL: {e}") - + if not response.ok: resp = response.json() - msg = resp['msg'] + msg = resp["msg"] if response.status_code == 400: raise BadRequestError(msg) elif response.status_code == 404: @@ -125,69 +128,71 @@ class OpenLLMGenerate: raise InternalServerError(msg) else: raise InternalServerError(msg) - + if stream: return self._handle_chat_stream_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_chat_generate_response(self, response: Response) -> OpenLLMGenerateMessage: try: data = response.json() except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - message = data['outputs'][0] - text = message['text'] - token_ids = message['token_ids'] - prompt_token_ids = data['prompt_token_ids'] - stop_reason = message['finish_reason'] + message = data["outputs"][0] + text = message["text"] + token_ids = message["token_ids"] + prompt_token_ids = data["prompt_token_ids"] + stop_reason = message["finish_reason"] message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) message.stop_reason = stop_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': len(token_ids), - 'total_tokens': len(prompt_token_ids) + len(token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_ids), + "total_tokens": len(prompt_token_ids) + len(token_ids), } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[OpenLLMGenerateMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[OpenLLMGenerateMessage, None, None]: completion_usage = 0 for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = 
line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() - if line == '[DONE]': + if line == "[DONE]": return try: data = loads(line) except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") - - output = data['outputs'] + + output = data["outputs"] for choice in output: - text = choice['text'] - token_ids = choice['token_ids'] + text = choice["text"] + token_ids = choice["token_ids"] completion_usage += len(token_ids) message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) - if choice.get('finish_reason'): - finish_reason = choice['finish_reason'] - prompt_token_ids = data['prompt_token_ids'] + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + prompt_token_ids = data["prompt_token_ids"] message.stop_reason = finish_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': completion_usage, - 'total_tokens': completion_usage + len(prompt_token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": completion_usage, + "total_tokens": completion_usage + len(prompt_token_ids), } - - yield message \ No newline at end of file + + yield message diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py index d9d279e6ca..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 4dbd0678e7..00e583cc79 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ Model class for OpenLLM text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,16 +36,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] if not server_url: - raise CredentialsValidateFailedError('server_url is required') - - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + raise CredentialsValidateFailedError("server_url is required") - url = f'{server_url}/v1/embeddings' + headers = {"Content-Type": "application/json", "accept": "application/json"} + + url = f"{server_url}/v1/embeddings" data = texts try: @@ -54,7 +52,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(f"Invalid server URL: {e}") except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: if response.status_code == 400: raise InvokeBadRequestError(response.text) @@ -62,21 +60,17 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(response.text) elif response.status_code == 500: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json()[0] - embeddings = resp['embeddings'] - total_tokens = resp['num_tokens'] + embeddings = resp["embeddings"] + total_tokens = resp["num_tokens"] except KeyError as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -104,9 +98,9 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid server_url') + raise CredentialsValidateFailedError("Invalid server_url") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -119,23 +113,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -147,10 +131,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, 
price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -161,7 +142,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index e78ac4caf1..71b5745f7d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -8,18 +8,23 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_credential(self, model: str, credentials: dict): - credentials['endpoint_url'] = "https://openrouter.ai/api/v1" - credentials['mode'] = self.get_model_mode(model).value - credentials['function_calling_type'] = 'tool_call' + credentials["endpoint_url"] = "https://openrouter.ai/api/v1" + credentials["mode"] = self.get_model_mode(model).value + credentials["function_calling_type"] = "tool_call" return - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -29,9 +34,17 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, credentials) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,8 +54,13 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().get_customizable_model_schema(model, credentials) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: self._update_credential(model, credentials) return super().get_num_tokens(model, credentials, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.py 
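The OpenRouter hunk above shows the thin-adapter pattern these providers share: the subclass pins credentials to a fixed endpoint and capability flags, then delegates every call to the generic OpenAI-compatible implementation. A minimal sketch of the idea, with `BaseCompatModel` as a hypothetical stand-in for `OAIAPICompatLargeLanguageModel`:

```python
class BaseCompatModel:
    """Stand-in for the OpenAI-API-compatible base model (illustrative only)."""

    def invoke(self, model: str, credentials: dict, prompt: str) -> str:
        return f"POST {credentials['endpoint_url']}/chat/completions ({model})"


class OpenRouterModel(BaseCompatModel):
    def _update_credential(self, model: str, credentials: dict) -> None:
        # Pin the vendor-specific endpoint and capability flags, then reuse
        # the generic OpenAI-compatible implementation unchanged.
        credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
        credentials["function_calling_type"] = "tool_call"

    def invoke(self, model: str, credentials: dict, prompt: str) -> str:
        self._update_credential(model, credentials)
        return super().invoke(model, credentials, prompt)


print(OpenRouterModel().invoke("openai/gpt-3.5-turbo", {}, "hi"))
```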
b/api/core/model_runtime/model_providers/openrouter/openrouter.py index 613f71deb1..2e59ab5059 100644 --- a/api/core/model_runtime/model_providers/openrouter/openrouter.py +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.py @@ -8,17 +8,13 @@ logger = logging.getLogger(__name__) class OpenRouterProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='openai/gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="openai/gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise ex \ No newline at end of file + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py index c9116bf685..89cac665aa 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py @@ -13,11 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -27,8 +33,7 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -46,8 +51,9 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. 
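The PerfXCloud hunks rework `_num_tokens_from_string` and `_num_tokens_from_messages`, which the inline comments say were lifted from the OpenAI runtime and count with tiktoken's `cl100k_base` encoding. A minimal sketch of that counting approach, assuming the `tiktoken` package is available:

```python
import tiktoken


def num_tokens_cl100k(text: str) -> int:
    """Count tokens with the cl100k_base encoding used for gpt-3.5/gpt-4."""
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


print(num_tokens_cl100k("who are you? what are you doing?"))
```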
Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -67,10 +73,10 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -101,10 +107,10 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://cloud.perfxlab.cn' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://cloud.perfxlab.cn" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py index 0854ef5185..450d22fb75 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py +++ b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class PerfXCloudProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class PerfXCloudProvider(ModelProvider): # Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='Qwen2-72B-Instruct-GPTQ-Int4', - credentials=credentials - ) + model_instance.validate_credentials(model="Qwen2-72B-Instruct-GPTQ-Int4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 11d57e3749..b62a2d2aaf 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -19,17 +19,17 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class 
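`_add_custom_parameters` above derives `openai_api_base` from the user-supplied `endpoint_url`, keeping only the scheme and host so a pasted path or query string cannot break the OpenAI client. A standalone sketch of that normalization (the default base mirrors the hunk):

```python
from urllib.parse import urlparse


def normalize_api_base(endpoint_url: str, default: str = "https://cloud.perfxlab.cn") -> str:
    """Reduce a user-supplied endpoint to scheme://host, falling back to a default."""
    if not endpoint_url:
        return default
    parsed = urlparse(endpoint_url)
    return f"{parsed.scheme}://{parsed.netloc}"


print(normalize_api_base("https://example.com/v1/extra?key=1"))  # https://example.com
print(normalize_api_base(""))                                    # https://cloud.perfxlab.cn
```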
OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,30 +39,28 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -73,7 +71,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -81,7 +78,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -91,42 +88,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = 
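The embedding `_invoke` above guards against over-long inputs with a proportional cutoff: when the GPT-2 token estimate exceeds the context size, only the leading `len(text) * context_size / num_tokens` characters are kept. A standalone sketch of that approximation, using a naive whitespace estimator in place of `_get_num_tokens_by_gpt2` (illustrative only):

```python
import math


def truncate_to_context(text: str, context_size: int, estimate_tokens) -> str:
    """Approximate truncation: keep a character prefix proportional to the token budget."""
    num_tokens = estimate_tokens(text)
    if num_tokens < context_size:
        return text
    cutoff = int(math.floor(len(text) * (context_size / num_tokens)))
    return text[:cutoff]


def words(s: str) -> int:
    """Naive stand-in estimator: one token per whitespace-separated word."""
    return len(s.split())


print(truncate_to_context("one two three four five six", context_size=3, estimate_tokens=words))
# 'one two three'
```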
response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -148,48 +128,38 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -197,7 +167,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -205,20 +175,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + 
input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -230,10 +199,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -244,7 +210,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 29d8427d8e..915f6e0eef 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError class _CommonReplicate: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - ReplicateError, - ModelError - ] - } + return {InvokeBadRequestError: [ReplicateError, ModelError]} diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 31b81a829e..87c8bc4a91 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -28,16 +28,22 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -48,39 +54,43 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): inputs = {**model_parameters} if prompt_messages[0].role == PromptMessageRole.SYSTEM: - if 'system_prompt' in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - inputs['system_prompt'] = prompt_messages[0].content - inputs['prompt'] = prompt_messages[1].content + if "system_prompt" in 
model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"]: + inputs["system_prompt"] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[1].content else: - inputs['prompt'] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[0].content - prediction = client.predictions.create( - version=model_info_version, input=inputs - ) + prediction = client.predictions.create(version=model_info_version, input=inputs) if stream: return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages) return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] if model.count("/") != 1: - raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' - 'format: {user_name}/{model_name}') + raise CredentialsValidateFailedError( + "Replicate Model Name must be provided, " "format: {user_name}/{model_name}" + ) try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -91,45 +101,44 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): self._check_text_generation_model(model_info_version, model, model_version, model_info.description) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}" + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod def _check_text_generation_model(model_info_version, model_name, version, description): - if 'language model' in description.lower(): + if "language model" in description.lower(): return - if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: + if ( + "temperature" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_p" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_k" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + ): raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a 
Text Generation model.") def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - model_type = LLMMode.CHAT if model.endswith('-chat') else LLMMode.COMPLETION + model_type = LLMMode.CHAT if model.endswith("-chat") else LLMMode.COMPLETION entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: model_type.value - }, - parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) + model_properties={ModelPropertyKey.MODE: model_type.value}, + parameter_rules=self._get_customizable_model_parameter_rules(model, credentials), ) return entity @classmethod def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -140,15 +149,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): parameter_rules = [] input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for key, value in input_properties: - if key not in ['system_prompt', 'prompt'] and 'stop' not in key: - value_type = value.get('type') + if key not in ["system_prompt", "prompt"] and "stop" not in key: + value_type = value.get("type") if not value_type: continue @@ -157,28 +164,28 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): rule = ParameterRule( name=key, - label={ - 'en_US': value['title'] - }, + label={"en_US": value["title"]}, type=param_type, help={ - 'en_US': value.get('description'), + "en_US": value.get("description"), }, required=False, - default=value.get('default'), - min=value.get('minimum'), - max=value.get('maximum') + default=value.get("default"), + min=value.get("minimum"), + max=value.get("maximum"), ) parameter_rules.append(rule) return parameter_rules - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prediction: Prediction, - stop: list[str], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> Generator: index = -1 current_completion: str = "" stop_condition_reached = False @@ -189,7 +196,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): for output in prediction.output_iterator(): current_completion += output - if not is_prediction_output_finished and prediction.status == 'succeeded': + if not is_prediction_output_finished and prediction.status == "succeeded": prediction_output_length = len(prediction.output) - 1 is_prediction_output_finished = True @@ -207,18 +214,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=output if output 
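`_get_customizable_model_parameter_rules` above builds Dify parameter rules by sorting the Replicate model's OpenAPI `Input` properties on their `x-order` extension, skipping prompt and stop keys. A pure-Python sketch of that extraction over an illustrative schema fragment:

```python
def extract_parameter_specs(input_properties: dict) -> list[dict]:
    """Sort schema properties by x-order and keep tunable, typed parameters."""
    ordered = sorted(input_properties.items(), key=lambda item: item[1].get("x-order", 0))
    specs = []
    for key, value in ordered:
        if key in ("system_prompt", "prompt") or "stop" in key:
            continue  # prompts and stop sequences are not user-tunable rules
        if not value.get("type"):
            continue
        specs.append({"name": key, "type": value["type"], "default": value.get("default")})
    return specs


schema = {
    "prompt": {"type": "string", "x-order": 0},
    "temperature": {"type": "number", "x-order": 2, "default": 0.7},
    "top_p": {"type": "number", "x-order": 1, "default": 0.9},
}
print(extract_parameter_specs(schema))  # top_p first, then temperature
```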
else '' - ) + assistant_prompt_message = AssistantPromptMessage(content=output if output else "") if index < prediction_output_length: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -229,15 +231,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) - def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> LLMResult: current_completion: str = "" stop_condition_reached = False for output in prediction.output_iterator(): @@ -255,9 +259,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): if stop_condition_reached: break - assistant_prompt_message = AssistantPromptMessage( - content=current_completion - ) + assistant_prompt_message = AssistantPromptMessage(content=current_completion) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -275,21 +277,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): @classmethod def _get_parameter_type(cls, param_type: str) -> str: - type_mapping = { - 'integer': 'int', - 'number': 'float', - 'boolean': 'boolean', - 'string': 'string' - } + type_mapping = {"integer": "int", "number": "float", "boolean": "boolean", "string": "string"} return type_mapping.get(param_type) def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/replicate/replicate.py b/api/core/model_runtime/model_providers/replicate/replicate.py index 3a5c9b84a0..ca137579c9 100644 --- a/api/core/model_runtime/model_providers/replicate/replicate.py +++ b/api/core/model_runtime/model_providers/replicate/replicate.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class ReplicateProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 0e4cdbf5bc..f6b7754d74 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -13,32 +13,27 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: 
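Both Replicate response handlers above accumulate output and break once a `stop_condition_reached` flag trips. A small sketch of the truncate-at-stop-word helper such logic typically relies on (a hypothetical helper, not the module's actual code):

```python
def truncate_at_stop(completion: str, stop_words: list[str]) -> tuple[str, bool]:
    """Cut the completion at the earliest stop word; report whether one was hit."""
    earliest = None
    for stop in stop_words:
        idx = completion.find(stop)
        if idx != -1 and (earliest is None or idx < earliest):
            earliest = idx
    if earliest is None:
        return completion, False
    return completion[:earliest], True


print(truncate_at_stop("Hello world\nHuman: hi", ["\nHuman:"]))  # ('Hello world', True)
```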
Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - texts) + embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -47,39 +42,35 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - ['Hello worlds!']) + self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, ["Hello worlds!"] + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 4096, - 'max_chunks': 1 - } + model_properties={"context_size": 4096, "max_chunks": 1}, ) return entity @@ -90,49 +81,45 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): # sort through the openapi schema to get the name of text, texts or inputs input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + 
model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for input_property in input_properties: - if input_property[0] in ('text', 'texts', 'inputs'): + if input_property[0] in ("text", "texts", "inputs"): text_input_key = input_property[0] return text_input_key - return '' + return "" @staticmethod - def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_model_version: str, - text_input_key: str, texts: list[str]) -> list[list[float]]: - - if text_input_key in ('text', 'inputs'): + def _generate_embeddings_by_text_input_key( + client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] + ) -> list[list[float]]: + if text_input_key in ("text", "inputs"): embeddings = [] for text in texts: - result = client.run(replicate_model_version, input={ - text_input_key: text - }) - embeddings.append(result[0].get('embedding')) + result = client.run(replicate_model_version, input={text_input_key: text}) + embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif 'texts' == text_input_key: - result = client.run(replicate_model_version, input={ - 'texts': json.dumps(texts), - "batch_size": 4, - "convert_to_numpy": False, - "normalize_embeddings": True - }) + elif "texts" == text_input_key: + result = client.run( + replicate_model_version, + input={ + "texts": json.dumps(texts), + "batch_size": 4, + "convert_to_numpy": False, + "normalize_embeddings": True, + }, + ) return result else: - raise ValueError(f'embeddings input key is invalid: {text_input_key}') + raise ValueError(f"embeddings input key is invalid: {text_input_key}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -143,7 +130,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index f8e7757a96..2edd13d56d 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -1,17 +1,36 @@ import json import logging -from collections.abc import Generator -from typing import Any, Optional, Union +import re +from collections.abc import Generator, Iterator +from typing import Any, Optional, Union, cast +# from openai.types.chat import ChatCompletion, ChatCompletionChunk import boto3 +from sagemaker import Predictor, serializers +from sagemaker.session import Session -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, + PromptMessageContent, + PromptMessageContentType, PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities 
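The Replicate embedding hunks above first discover whether a model's OpenAPI schema names its input `text`, `inputs`, or `texts`, then dispatch accordingly: single-text keys get one prediction per string, while a `texts` key takes the whole batch as a JSON-encoded array. A sketch of that dispatch with a generic `run` callable standing in for `ReplicateClient.run`:

```python
import json
from typing import Callable


def embed_by_input_key(run: Callable, model_version: str, key: str, texts: list[str]) -> list:
    """Dispatch embedding calls based on the schema's discovered input key."""
    if key in ("text", "inputs"):
        # One prediction per text; each result carries a single embedding.
        return [run(model_version, input={key: text})[0].get("embedding") for text in texts]
    if key == "texts":
        # Batched models take the whole list as a JSON-encoded array.
        return run(model_version, input={"texts": json.dumps(texts), "batch_size": 4})
    raise ValueError(f"embeddings input key is invalid: {key}")


def fake_run(version, input):
    """Illustrative stand-in for ReplicateClient.run, returning a fixed vector."""
    return [{"embedding": [0.0, 1.0]}]


print(embed_by_input_key(fake_run, "owner/model:v1", "text", ["hello"]))  # [[0.0, 1.0]]
```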
import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -26,17 +45,160 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) +def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False): + """ + params: + predictor : Sagemaker Predictor + messages (List[Dict[str,Any]]): message list。 + messages = [ + {"role": "system", "content":"please answer in Chinese"}, + {"role": "user", "content": "who are you? what are you doing?"}, + ] + params (Dict[str,Any]): model parameters for LLM。 + stream (bool): False by default。 + + response: + result of inference if stream is False + Iterator of Chunks if stream is True + """ + payload = { + "model": params.get("model_name"), + "stop": stop, + "messages": messages, + "stream": stream, + "max_tokens": params.get("max_new_tokens", params.get("max_tokens", 2048)), + "temperature": params.get("temperature", 0.1), + "top_p": params.get("top_p", 0.9), + } + + if not stream: + response = predictor.predict(payload) + return response + else: + response_stream = predictor.predict_stream(payload) + return response_stream + + class SageMakerLargeLanguageModel(LargeLanguageModel): """ Model class for Cohere large language model. """ - sagemaker_client: Any = None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + sagemaker_client: Any = None + sagemaker_sess: Any = None + predictor: Any = None + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: bytes, + ) -> LLMResult: + """ + handle normal chat generate response + """ + resp_obj = json.loads(resp.decode("utf-8")) + resp_str = resp_obj.get("choices")[0].get("message").get("content") + + if len(resp_str) == 0: + raise InvokeServerUnavailableError("Empty response") + + assistant_prompt_message = AssistantPromptMessage(content=resp_str, tool_calls=[]) + + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=None, + usage=usage, + message=assistant_prompt_message, + ) + + return response + + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[bytes], + ) -> Generator: + """ + handle stream chat generate response + """ + full_response = "" + buffer = "" + for chunk_bytes in resp: + buffer += chunk_bytes.decode("utf-8") + last_idx = 0 + for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer): + try: + data = json.loads(match.group(1).strip()) + last_idx = match.span()[1] + + if "content" in 
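The new module-level `inference` helper above wraps a SageMaker `Predictor` around an OpenAI-style chat payload, as its docstring describes. A hedged usage sketch, assuming a deployed endpoint (the endpoint and model names are placeholders) and mirroring the Predictor construction that `_invoke` performs later in this diff:

```python
# Usage sketch; "my-llm-endpoint" and "my-model" are placeholders, not values
# from this codebase, and a deployed SageMaker endpoint is assumed.
from sagemaker import Predictor, serializers
from sagemaker.session import Session

from core.model_runtime.model_providers.sagemaker.llm.llm import inference

predictor = Predictor(
    endpoint_name="my-llm-endpoint",
    sagemaker_session=Session(),
    serializer=serializers.JSONSerializer(),
)

messages = [
    {"role": "system", "content": "please answer in Chinese"},
    {"role": "user", "content": "who are you? what are you doing?"},
]
params = {"model_name": "my-model", "temperature": 0.1, "top_p": 0.9}

# Non-streaming call: returns the raw response bytes from the endpoint.
response = inference(predictor, messages, params, stop=[], stream=False)
```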
data["choices"][0]["delta"]: + chunk_content = data["choices"][0]["delta"]["content"] + assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) + + if data["choices"][0]["finish_reason"] is not None: + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=None, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + finish_reason=data["choices"][0]["finish_reason"], + usage=usage, + ), + ) + else: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=None, + delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message), + ) + + full_response += chunk_content + except (json.JSONDecodeError, KeyError, IndexError) as e: + logger.info("json parse exception, content: {}".format(match.group(1).strip())) + pass + + buffer = buffer[last_idx:] + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -50,58 +212,153 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - # get model mode - model_mode = self.get_model_mode(model, credentials) - if not self.sagemaker_client: - access_key = credentials.get('access_key') - secret_key = credentials.get('secret_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("access_key") + secret_key = credentials.get("secret_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") + sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) + self.predictor = Predictor( + endpoint_name=credentials.get("sagemaker_endpoint"), + sagemaker_session=sagemaker_session, + serializer=serializers.JSONSerializer(), + ) - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - response_model = self.sagemaker_client.invoke_endpoint( - EndpointName=sagemaker_endpoint, - Body=json.dumps( - { - "inputs": prompt_messages[0].content, - "parameters": { "stop" : stop}, - "history" : [] - } - ), - ContentType="application/json", - ) - - assistant_text = response_model['Body'].read().decode('utf8') - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text + messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages] + response = inference( + 
predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream ) - usage = self._calc_response_usage(model, credentials, 0, 0) + if stream: + if tools and len(tools) > 0: + raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode") - response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response ) - return response + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI Compatibility API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + sub_message_dict = {"type": "text", "text": message_content.data} + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + return message_dict + + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: + def tokens(text: str): + return self._get_num_tokens_by_gpt2(text) + + if is_completion_model: + return sum(tokens(str(message.content)) for message in messages) + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + if isinstance(value, list): + text = "" + for item in value: + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += tokens(t_key) 
+ if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + if key == "function_call": + for t_key, t_value in value.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + else: + num_tokens += tokens(str(value)) + + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -112,10 +369,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: """ # get model mode - model_mode = self.get_model_mode(model) - try: - return 0 + return self._num_tokens_from_messages(prompt_messages, tools) except Exception as e: raise self._transform_invoke_error(e) @@ -129,7 +384,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): """ try: # get model mode - model_mode = self.get_model_mode(model) + pass except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -144,95 +399,63 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - completion_type = LLMMode.value_of(credentials["mode"]) - - if completion_type == LLMMode.CHAT: - print(f"completion_type : {LLMMode.CHAT.value}") - - if completion_type == LLMMode.COMPLETION: - print(f"completion_type : {LLMMode.COMPLETION.value}") + completion_type = 
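The `_num_tokens_from_messages` hunk above follows the OpenAI cookbook convention: each message costs a fixed 3-token envelope, a `name` field costs 1 more, and 3 tokens prime the assistant's reply. A simplified sketch of that arithmetic with a pluggable tokenizer (the word-count tokenizer is a crude stand-in):

```python
from typing import Callable


def count_chat_tokens(messages: list[dict], tokens: Callable[[str], int]) -> int:
    """OpenAI-cookbook-style estimate: per-message envelope, content, reply priming."""
    tokens_per_message = 3  # <|start|>role ... <|end|> framing
    tokens_per_name = 1
    total = 0
    for message in messages:
        total += tokens_per_message
        for key, value in message.items():
            total += tokens(str(value))
            if key == "name":
                total += tokens_per_name
    total += 3  # every reply is primed with <|start|>assistant<|message|>
    return total


def words(s: str) -> int:
    """Crude stand-in tokenizer: one token per whitespace-separated word."""
    return len(s.split())


print(count_chat_tokens([{"role": "user", "content": "hello there"}], words))  # 9
```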
LLMMode.value_of(credentials["mode"]).value features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 0b06f54ef1..7e7614055c 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -20,34 +20,36 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel logger = logging.getLogger(__name__) + class SageMakerRerankModel(RerankModel): """ - Model class for Cohere rerank model. + Model class for SageMaker rerank model. """ + sagemaker_client: Any = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -63,22 +65,21 @@ class SageMakerRerankModel(RerankModel): line = 0 try: if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = 
boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -86,22 +87,20 @@ line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") candidate_docs = [] scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) for idx in range(len(scores)): - candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + candidate_docs.append({"content": docs[idx], "score": scores[idx]}) - sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True) line = 3 rerank_documents = [] for idx, result in enumerate(candidate_docs): rerank_document = RerankDocument( - index=idx, - text=result.get('content'), - score=result.get('score', -100.0) + index=idx, text=result.get("content"), score=result.get("score", -100.0) ) if score_threshold is not None: @@ -110,13 +109,10 @@ else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -137,7 +133,7 @@ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States.
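The rerank `_invoke` above scores every document against the query, orders the candidates by score, and drops those under `score_threshold`. A standalone sketch of that post-processing; note that `sorted()` returns a new list rather than sorting in place, so its result has to be bound:

```python
from typing import Optional


def postprocess_rerank(docs: list[str], scores: list[float], score_threshold: Optional[float]) -> list[dict]:
    """Pair docs with scores, sort descending, and apply an optional threshold."""
    candidates = [{"content": doc, "score": score} for doc, score in zip(docs, scores)]
    candidates = sorted(candidates, key=lambda x: x["score"], reverse=True)
    if score_threshold is not None:
        candidates = [c for c in candidates if c["score"] >= score_threshold]
    return candidates


print(postprocess_rerank(["a", "b"], [0.2, 0.9], score_threshold=0.5))
# [{'content': 'b', 'score': 0.9}]
```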
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -153,38 +149,24 @@ class SageMakerRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py index 02d05f406c..042155b152 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -1,4 +1,6 @@ import logging +import uuid +from typing import IO, Any from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -15,3 +17,25 @@ class SageMakerProvider(ModelProvider): :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
""" pass + + +def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str: + """ + return s3_uri of this file + """ + s3_key = f"{s3_prefix}{uuid.uuid4()}.mp3" + s3_client.put_object(Body=file.read(), Bucket=bucket, Key=s3_key, ContentType="audio/mp3") + return s3_key + + +def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str: + object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix) + try: + response = s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket_name, "Key": object_key}, ExpiresIn=expiration + ) + except Exception as e: + print(f"Error generating presigned URL: {e}") + return None + + return response diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml index 290cb0edab..87cd50f50c 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -21,6 +21,8 @@ supported_model_types: - llm - text-embedding - rerank + - speech2text + - tts configurate_methods: - customizable-model model_credential_schema: @@ -45,14 +47,10 @@ model_credential_schema: zh_Hans: 选择对话类型 en_US: Select completion mode options: - - value: completion - label: - en_US: Completion - zh_Hans: 补全 - value: chat label: en_US: Chat - zh_Hans: 对话 + zh_Hans: Chat - variable: sagemaker_endpoint label: en_US: sagemaker endpoint @@ -61,6 +59,76 @@ model_credential_schema: placeholder: zh_Hans: 请输出你的Sagemaker推理端点 en_US: Enter your Sagemaker Inference endpoint + - variable: audio_s3_cache_bucket + show_on: + - variable: __model_type + value: speech2text + label: + zh_Hans: 音频缓存桶(s3 bucket) + en_US: audio cache bucket(s3 bucket) + type: text-input + required: true + placeholder: + zh_Hans: sagemaker-us-east-1-******207838 + en_US: sagemaker-us-east-1-*******7838 + - variable: audio_model_type + show_on: + - variable: __model_type + value: tts + label: + en_US: Audio model type + type: select + required: true + placeholder: + zh_Hans: 语音模型类型 + en_US: Audio model type + options: + - value: PresetVoice + label: + en_US: preset voice + zh_Hans: 内置音色 + - value: CloneVoice + label: + en_US: clone voice + zh_Hans: 克隆音色 + - value: CloneVoice_CrossLingual + label: + en_US: crosslingual clone voice + zh_Hans: 跨语种克隆音色 + - value: InstructVoice + label: + en_US: Instruct voice + zh_Hans: 文字指令音色 + - variable: prompt_audio + show_on: + - variable: __model_type + value: tts + label: + en_US: Mock Audio Source + type: text-input + required: false + placeholder: + zh_Hans: 被模仿的音色音频 + en_US: source audio to be mocked + - variable: prompt_text + show_on: + - variable: __model_type + value: tts + label: + en_US: Prompt Audio Text + type: text-input + required: false + placeholder: + zh_Hans: 模仿音色的对应文本 + en_US: text for the mocked source audio + - variable: instruct_text + show_on: + - variable: __model_type + value: tts + label: + en_US: instruct text for speaker + type: text-input + required: false - variable: aws_access_key_id required: false label: diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py new file mode 100644 index 
diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py new file mode 100644 index 0000000000..6aa8c9995f --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -0,0 +1,125 @@ +import json +import logging +from typing import IO, Any, Optional + +import boto3 + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url + +logger = logging.getLogger(__name__) + + +class SageMakerSpeech2TextModel(Speech2TextModel): + """ + Model class for SageMaker speech to text model. + """ + + sagemaker_client: Any = None + s3_client: Any = None + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + asr_text = None + + try: + if not self.sagemaker_client: + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client( + "sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + + s3_prefix = "dify/speech2text/" + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + bucket = credentials.get("audio_s3_cache_bucket") + + s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) + payload = {"audio_s3_presign_uri": s3_presign_url} + + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=sagemaker_endpoint, Body=json.dumps(payload), ContentType="application/json" + ) + json_str = response_model["Body"].read().decode("utf8") + json_obj = json.loads(json_str) + asr_text = json_obj["text"] + except Exception as e: + logger.exception(f"Exception {e}") + + return asr_text + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + pass + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller.
+ + :return: Invoke error mapping + """ + return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 4b2858b1a2..d55144f8a7 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -10,21 +10,22 @@ from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel BATCH_SIZE = 20 -CONTEXT_SIZE=8192 +CONTEXT_SIZE = 8192 logger = logging.getLogger(__name__) + def batch_generator(generator, batch_size): while True: batch = list(itertools.islice(generator, batch_size)) @@ -32,33 +33,28 @@ def batch_generator(generator, batch_size): break yield batch + class SageMakerEmbeddingModel(TextEmbeddingModel): """ - Model class for Cohere text embedding model. + Model class for SageMaker text embedding model.
""" + sagemaker_client: Any = None - def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]): response_model = sm_client.invoke_endpoint( EndpointName=endpoint_name, - Body=json.dumps( - { - "inputs": content_list, - "parameters": {}, - "is_query" : False, - "instruction" : '' - } - ), + Body=json.dumps({"inputs": content_list, "parameters": {}, "is_query": False, "instruction": ""}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - embeddings = json_obj['embeddings'] + embeddings = json_obj["embeddings"] return embeddings - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -72,25 +68,27 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): try: line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") line = 3 - truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + truncated_texts = [item[:CONTEXT_SIZE] for item in texts] batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) all_embeddings = [] @@ -105,18 +103,14 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): usage = self._calc_response_usage( model=model, credentials=credentials, - tokens=0 # It's not SAAS API, usage is meaningless + tokens=0, # It's not SAAS API, usage is meaningless ) line = 6 - return TextEmbeddingResult( - embeddings=all_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -153,10 +147,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -167,7 +158,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return 
usage @@ -175,40 +166,28 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/__init__.py b/api/core/model_runtime/model_providers/sagemaker/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py new file mode 100644 index 0000000000..3dd5f8f64c --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -0,0 +1,275 @@ +import concurrent.futures +import copy +import json +import logging +from enum import Enum +from typing import Any, Optional + +import boto3 +import requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.tts_model import TTSModel + +logger = logging.getLogger(__name__) + + +class TTSModelType(Enum): + PresetVoice = "PresetVoice" + CloneVoice = "CloneVoice" + CloneVoice_CrossLingual = "CloneVoice_CrossLingual" + InstructVoice = "InstructVoice" + + +class SageMakerText2SpeechModel(TTSModel): + sagemaker_client: Any = None + s3_client: Any = None + comprehend_client: Any = None + + def __init__(self): + # preset voices, need support custom voice + self.model_voices = { + "__default": { + "all": [ + {"name": "Default", "value": "default"}, + ] + }, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, + ], + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, + ], + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, + ], + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, + ], + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, + } + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + pass + + def 
_detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = map_dict or {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} + + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response["Languages"][0]["LanguageCode"] + + return map_dict.get(language_code, "<|zh|>") + + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): + if model_type == TTSModelType.PresetVoice.value and model_role: + return {"tts_text": content_text, "role": model_role} + if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + lang_tag = self._detect_lang_code(content_text) + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} + + raise RuntimeError(f"Invalid params for {model_type}") + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param voice: model timbre + :param content_text: text content to be translated + :param user: unique user id + :return: text translated to audio file + """ + if not self.sagemaker_client: + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client( + "sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) + self.comprehend_client = boto3.client( + "comprehend", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region, + ) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + self.comprehend_client = boto3.client("comprehend") + + model_type = credentials.get("audio_model_type", "PresetVoice") + prompt_text = credentials.get("prompt_text") + prompt_audio = credentials.get("prompt_audio") + instruct_text = credentials.get("instruct_text") + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + payload = self._build_tts_payload(model_type, content_text, voice, prompt_text, prompt_audio, instruct_text) + + return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS,
model_properties={}, + parameter_rules=[], + ) + + return entity + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], + } + + def _get_model_default_voice(self, model: str, credentials: dict) -> Any: + return "" + + def _get_model_word_limit(self, model: str, credentials: dict) -> int: + return 15 + + def _get_model_audio_type(self, model: str, credentials: dict) -> str: + return "mp3" + + def _get_model_workers_limit(self, model: str, credentials: dict) -> int: + return 5 + + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + audio_model_name = "CosyVoice" + for key, voices in self.model_voices.items(): + if key in audio_model_name: + if language and language in voices: + return voices[language] + elif "all" in voices: + return voices["all"] + + return self.model_voices["__default"]["all"] + + def _invoke_sagemaker(self, payload: dict, endpoint: str): + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + json_str = response_model["Body"].read().decode("utf8") + json_obj = json.loads(json_str) + return json_obj + + def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> Any: + """ + _tts_invoke_streaming text2speech model + + :param model_type: audio model type (e.g. PresetVoice) + :param payload: request payload for the SageMaker endpoint + :param sagemaker_endpoint: SageMaker inference endpoint name + :return: generator yielding the synthesized audio in byte chunks + """ + try: + lang_tag = "" + if model_type == TTSModelType.CloneVoice_CrossLingual.value: + lang_tag = payload.pop("lang_tag") + + word_limit = self._get_model_word_limit(model="", credentials={}) + content_text = payload.get("tts_text") + if len(content_text) > word_limit: + split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)] + len_sent = len(sentences) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) + payloads = [copy.deepcopy(payload) for i in range(len_sent)] + for idx in range(len_sent): + payloads[idx]["tts_text"] = sentences[idx] + + futures = [ + executor.submit( + self._invoke_sagemaker, + payload=payload, + endpoint=sagemaker_endpoint, + ) + for payload in payloads + ] + + for index, future in enumerate(futures): + resp = future.result() + audio_bytes = requests.get(resp.get("s3_presign_url")).content + for i in range(0, len(audio_bytes), 1024): + yield audio_bytes[i : i + 1024] + else: + resp = self._invoke_sagemaker(payload, sagemaker_endpoint) + audio_bytes = requests.get(resp.get("s3_presign_url")).content + + for i in range(0, len(audio_bytes), 1024): + yield audio_bytes[i : i + 1024] + except Exception as ex: + raise InvokeBadRequestError(str(ex))
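_tts_invoke_streaming above is a generator: each SageMaker response carries a presigned audio URL whose bytes are downloaded and re-yielded in 1024-byte chunks. A caller can therefore stream the synthesized audio straight to disk; a minimal sketch, with an illustrative output path:

    from collections.abc import Iterable

    def save_streamed_audio(chunks: Iterable[bytes], path: str = "output.mp3") -> None:
        """Drain a byte-chunk generator (such as _tts_invoke_streaming) into a file."""
        with open(path, "wb") as f:
            for chunk in chunks:
                f.write(chunk)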
diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index a9ce7b98c3..c1868b6ad0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 6835915816..6f652e9d52 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -16,39 +16,39 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel class SiliconflowRerankModel(RerankModel): - - def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') - if base_url.endswith('/'): + base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") + if base_url.endswith("/"): base_url = base_url[:-1] try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n, - "return_documents": True - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: +
if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -57,7 +57,6 @@ class SiliconflowRerankModel(RerankModel): def validate_credentials(self, model: str, credentials: dict) -> None: try: - self._invoke( model=model, credentials=credentials, @@ -68,7 +67,7 @@ class SiliconflowRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,5 +82,5 @@ class SiliconflowRerankModel(RerankModel): InvokeServerUnavailableError: [httpx.RemoteProtocolError], InvokeRateLimitError: [], InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] - } \ No newline at end of file + InvokeBadRequestError: [httpx.RequestError], + } diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index dd0eea362a..e121ab8c7e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class SiliconflowProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class SiliconflowProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py index 6ad3cab587..8d1932863e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -8,9 +8,7 @@ class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): Model class for Siliconflow Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c58765cecb..6cdf4933b4 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -10,20 +10,21 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel): """ Model class for Siliconflow text embedding model. 
""" + def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, texts, user) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: self._add_custom_parameters(credentials) return super().get_num_tokens(model, credentials, texts) - + @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' \ No newline at end of file + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index 10da265701..25223e8340 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -15,51 +15,35 @@ import websocket class SparkLLMClient: def __init__(self, model: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' + domain = "spark-api.xf-yun.com" + endpoint = "chat" if api_domain: domain = api_domain - if model == 'spark-v3': - endpoint = 'multimodal' model_api_configs = { - 'spark-1.5': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-2': { - 'version': 'v2.1', - 'chat_domain': 'generalv2' - }, - 'spark-3': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-3.5': { - 'version': 'v3.5', - 'chat_domain': 'generalv3.5' - }, - 'spark-4': { - 'version': 'v4.0', - 'chat_domain': '4.0Ultra' - } + "spark-lite": {"version": "v1.1", "chat_domain": "general"}, + "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"}, + "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"}, + "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"}, + "spark-4.0-ultra": {"version": "v4.0", "chat_domain": "4.0Ultra"}, } - api_version = model_api_configs[model]['version'] + api_version = model_api_configs[model]["version"] + + self.chat_domain = model_api_configs[model]["chat_domain"] + + if model == "spark-pro-128k": + self.api_base = f"wss://{domain}/{endpoint}/{api_version}" + else: + self.api_base = f"wss://{domain}/{api_version}/{endpoint}" - self.chat_domain = model_api_configs[model]['chat_domain'] - self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( - urlparse(self.api_base).netloc, - urlparse(self.api_base).path, - self.api_base, - api_key, - api_secret + urlparse(self.api_base).netloc, urlparse(self.api_base).path, self.api_base, api_key, api_secret ) self.queue = queue.Queue() - self.blocking_message = '' + self.blocking_message = "" def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: # generate timestamp by RFC1123 @@ -71,33 +55,29 @@ class SparkLLMClient: signature_origin += "GET " + path + " HTTP/1.1" # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + api_secret.encode("utf-8"), 
signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") - v = { - "authorization": authorization, - "date": date, - "host": host - } + v = {"authorization": authorization, "date": date, "host": host} # generate url - url = api_base + '?' + urlencode(v) + url = api_base + "?" + urlencode(v) return url - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): + def run(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None, streaming: bool = False): websocket.enableTrace(False) ws = websocket.WebSocketApp( self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open + on_open=self.on_open, ) ws.messages = messages ws.user_id = user_id @@ -106,86 +86,71 @@ class SparkLLMClient: ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) + self.queue.put({"status_code": error.status_code, "error": error.resp_body.decode("utf-8")}) ws.close() def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) + self.queue.put({"done": True}) def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) + self.blocking_message = "" + data = json.dumps(self.gen_params(messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs)) ws.send(data) def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) + self.queue.put({"status_code": 400, "error": f"Code: {code}, Error: {data['header']['message']}"}) ws.close() else: choices = data["payload"]["choices"] status = choices["status"] content = choices["text"][0]["content"] if ws.streaming: - self.queue.put({'data': content}) + self.queue.put({"data": content}) else: self.blocking_message += content if status == 2: if not ws.streaming: - self.queue.put({'data': self.blocking_message}) + self.queue.put({"data": self.blocking_message}) ws.close() - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: + def gen_params(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None) -> dict: data = { "header": { "app_id": self.app_id, # resolve this error message => $.header.uid' length must be less or equal than 32 - "uid": user_id[:32] if user_id else None + "uid": user_id[:32] if user_id else None, }, - "parameter": { - "chat": { - "domain": self.chat_domain - } - }, - "payload": { - "message": { - "text": messages - } - } + "parameter": {"chat": {"domain": self.chat_domain}}, + "payload": {"message": {"text": messages}}, } if model_kwargs: - data['parameter']['chat'].update(model_kwargs) + 
data["parameter"]["chat"].update(model_kwargs) return data def subscribe(self): while True: content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " - "Please try again after obtaining the necessary permissions.") + if "error" in content: + if content["status_code"] == 401: + raise SparkError( + "[Spark] The credentials you provided are incorrect. " + "Please double-check and fill them in again." + ) + elif content["status_code"] == 403: + raise SparkError( + "[Spark] Sorry, the credentials you provided are access denied. " + "Please try again after obtaining the necessary permissions." + ) else: raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") - if 'data' not in content: + if "data" not in content: break yield content diff --git a/api/core/model_runtime/model_providers/spark/llm/_position.yaml b/api/core/model_runtime/model_providers/spark/llm/_position.yaml index e49ee97db7..458397f2aa 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/_position.yaml @@ -1,3 +1,8 @@ +- spark-4.0-ultra +- spark-max +- spark-pro-128k +- spark-pro +- spark-lite - spark-4 - spark-3.5 - spark-3 diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 65beae517c..0c42acf5aa 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -25,12 +25,17 @@ from ._client import SparkLLMClient class SparkLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -47,8 +52,13 @@ class SparkLargeLanguageModel(LargeLanguageModel): # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -80,15 +90,21 @@ class SparkLargeLanguageModel(LargeLanguageModel): model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + 
credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -103,7 +119,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -113,21 +129,33 @@ class SparkLargeLanguageModel(LargeLanguageModel): **credentials_kwargs, ) - thread = threading.Thread(target=client.run, args=( - [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages], - user, - model_parameters, - stream - )) + thread = threading.Thread( + target=client.run, + args=( + [ + {"role": prompt_message.role.value, "content": prompt_message.content} + for prompt_message in prompt_messages + ], + user, + model_parameters, + stream, + ), + ) thread.start() if stream: return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages) return self._handle_generate_response(thread, model, credentials, client, prompt_messages) - - def _handle_generate_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_generate_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -140,7 +168,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): for content in client.subscribe(): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content @@ -148,9 +176,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): thread.join() # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + assistant_prompt_message = AssistantPromptMessage(content=completion) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -168,9 +194,15 @@ class SparkLargeLanguageModel(LargeLanguageModel): ) return result - - def _handle_generate_stream_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> Generator: + + def _handle_generate_stream_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -183,12 +215,12 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ for index, content in enumerate(client.subscribe()): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content assistant_prompt_message = AssistantPromptMessage( - content=delta if delta else '', + content=delta if delta else "", ) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -199,11 +231,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) thread.join() 
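The streaming path above rests on the queue hand-off inside SparkLLMClient: run() executes on a background thread, on_message() pushes frames onto self.queue, and subscribe() drains them on the caller's thread until the sentinel enqueued by on_close() arrives. A stripped-down sketch of that producer/consumer shape, with illustrative tokens and no websocket:

    import queue
    import threading

    q: queue.Queue = queue.Queue()

    def producer() -> None:
        for token in ("Hello", ", ", "world"):
            q.put({"data": token})  # mirrors on_message() enqueueing {"data": content}
        q.put({"done": True})       # mirrors on_close() enqueueing the sentinel

    threading.Thread(target=producer).start()
    while True:
        item = q.get()
        if "data" not in item:      # sentinel reached; the stream is finished
            break
        print(item["data"], end="")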
@@ -216,9 +244,9 @@ class SparkLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "app_id": credentials['app_id'], - "api_secret": credentials['api_secret'], - "api_key": credentials['api_key'], + "app_id": credentials["app_id"], + "api_secret": credentials["api_secret"], + "api_key": credentials["api_key"], } return credentials_kwargs @@ -244,7 +272,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -254,10 +282,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -277,5 +302,5 @@ class SparkLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml index 41b8765fe6..fcd65c24e0 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml @@ -1,4 +1,5 @@ model: spark-1.5 +deprecated: true label: en_US: Spark V1.5 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml index 6d24932ea8..86617a53d0 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml @@ -1,4 +1,5 @@ model: spark-3.5 +deprecated: true label: en_US: Spark V3.5 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml index 2ef9e10f45..9f296c684d 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml @@ -1,4 +1,5 @@ model: spark-3 +deprecated: true label: en_US: Spark V3.0 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-4.0-ultra.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-4.0-ultra.yaml new file mode 100644 index 0000000000..bbf85764f1 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-4.0-ultra.yaml @@ -0,0 +1,42 @@ +model: spark-4.0-ultra +label: + en_US: Spark 4.0 Ultra +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. 
+ - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false + - name: show_ref_label + label: + zh_Hans: 联网检索 + en_US: web search + type: boolean + default: false + help: + zh_Hans: 该参数仅4.0 Ultra版本支持,当设置为true时,如果输入内容触发联网检索插件,会先返回检索信源列表,然后再返回星火回复结果,否则仅返回星火回复结果 + en_US: The parameter is only supported in the 4.0 Ultra version. When set to true, if the input triggers the online search plugin, it will first return a list of search sources and then return the Spark response. Otherwise, it will only return the Spark response. diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml index 4b0bf27029..4b5529e81c 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml @@ -1,4 +1,5 @@ model: spark-4 +deprecated: true label: en_US: Spark V4.0 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-lite.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-lite.yaml new file mode 100644 index 0000000000..1f6141a816 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-lite.yaml @@ -0,0 +1,33 @@ +model: spark-lite +label: + en_US: Spark Lite +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-max.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-max.yaml new file mode 100644 index 0000000000..71eb2b86d3 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-max.yaml @@ -0,0 +1,33 @@ +model: spark-max +label: + en_US: Spark Max +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). 
+ required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-pro-128k.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-pro-128k.yaml new file mode 100644 index 0000000000..da1fead6da --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-pro-128k.yaml @@ -0,0 +1,33 @@ +model: spark-pro-128k +label: + en_US: Spark Pro-128K +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-pro.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-pro.yaml new file mode 100644 index 0000000000..9ee479f15b --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-pro.yaml @@ -0,0 +1,33 @@ +model: spark-pro +label: + en_US: Spark Pro +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). 
+ required: false diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 6f6ffc8faa..dab666e4d0 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,51 +55,51 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 1024)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.stepfun.com/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.stepfun.com/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] 
= "tool_call" - def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict: + def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ Convert PromptMessage to dict for OpenAI API format """ @@ -106,10 +112,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -117,7 +120,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): "type": "image_url", "image_url": { "url": message_content.data, - } + }, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -127,14 +130,16 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -160,21 +165,26 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -184,11 +194,12 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, 
finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -199,12 +210,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -218,9 +224,9 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -242,9 +248,9 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -253,21 +259,21 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." 
+ finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -275,19 +281,18 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -303,26 +308,21 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py index 50b17392b5..e1c41a9153 100644 --- a/api/core/model_runtime/model_providers/stepfun/stepfun.py +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class StepfunProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class StepfunProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='step-1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="step-1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py 
b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py index c3e3b7c258..9fd4a45f45 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py @@ -67,10 +67,10 @@ class FlashRecognitionRequest: class FlashRecognizer: """ - reponse: + response: request_id string - status Integer - message String + status Integer + message String audio_duration Integer flash_result Result Array @@ -81,16 +81,16 @@ class FlashRecognizer: Sentence: text String - start_time Integer - end_time Integer - speaker_id Integer + start_time Integer + end_time Integer + speaker_id Integer word_list Word Array Word: - word String - start_time Integer - end_time Integer - stable_flag: Integer + word String + start_time Integer + end_time Integer + stable_flag: Integer """ def __init__(self, appid, credential): @@ -100,13 +100,13 @@ class FlashRecognizer: def _format_sign_string(self, param): signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/" for t in param: - if 'appid' in t: + if "appid" in t: signstr += str(t[1]) break signstr += "?" for x in param: tmp = x - if 'appid' in x: + if "appid" in x: continue for t in tmp: signstr += str(t) @@ -121,10 +121,9 @@ class FlashRecognizer: return header def _sign(self, signstr, secret_key): - hmacstr = hmac.new(secret_key.encode('utf-8'), - signstr.encode('utf-8'), hashlib.sha1).digest() + hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest() s = base64.b64encode(hmacstr) - s = s.decode('utf-8') + s = s.decode("utf-8") return s def _build_req_with_signature(self, secret_key, params, header): @@ -132,20 +131,28 @@ class FlashRecognizer: signstr = self._format_sign_string(query) signature = self._sign(signstr, secret_key) header["Authorization"] = signature - requrl = "https://" - requrl += signstr[4::] - return requrl + req_url = "https://" + req_url += signstr[4::] + return req_url def _create_query_arr(self, req): return { - 'appid': self.appid, 'secretid': self.credential.secret_id, 'timestamp': str(int(time.time())), - 'engine_type': req.engine_type, 'voice_format': req.voice_format, - 'speaker_diarization': req.speaker_diarization, 'hotword_id': req.hotword_id, - 'customization_id': req.customization_id, 'filter_dirty': req.filter_dirty, - 'filter_modal': req.filter_modal, 'filter_punc': req.filter_punc, - 'convert_num_mode': req.convert_num_mode, 'word_info': req.word_info, - 'first_channel_only': req.first_channel_only, 'reinforce_hotword': req.reinforce_hotword, - 'sentence_max_length': req.sentence_max_length + "appid": self.appid, + "secretid": self.credential.secret_id, + "timestamp": str(int(time.time())), + "engine_type": req.engine_type, + "voice_format": req.voice_format, + "speaker_diarization": req.speaker_diarization, + "hotword_id": req.hotword_id, + "customization_id": req.customization_id, + "filter_dirty": req.filter_dirty, + "filter_modal": req.filter_modal, + "filter_punc": req.filter_punc, + "convert_num_mode": req.convert_num_mode, + "word_info": req.word_info, + "first_channel_only": req.first_channel_only, + "reinforce_hotword": req.reinforce_hotword, + "sentence_max_length": req.sentence_max_length, } def recognize(self, req, data): diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py index 00ec5aa9c8..5b427663ca 100644 --- 
a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py @@ -18,9 +18,7 @@ from core.model_runtime.model_providers.tencent.speech2text.flash_recognizer imp class TencentSpeech2TextModel(Speech2TextModel): - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -43,7 +41,7 @@ class TencentSpeech2TextModel(Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,10 +81,6 @@ class TencentSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - requests.exceptions.ConnectionError - ], - InvokeAuthorizationError: [ - CredentialsValidateFailedError - ] + InvokeConnectionError: [requests.exceptions.ConnectionError], + InvokeAuthorizationError: [CredentialsValidateFailedError], } diff --git a/api/core/model_runtime/model_providers/tencent/tencent.py b/api/core/model_runtime/model_providers/tencent/tencent.py index dd9f90bb47..79c6f577b8 100644 --- a/api/core/model_runtime/model_providers/tencent/tencent.py +++ b/api/core/model_runtime/model_providers/tencent/tencent.py @@ -18,12 +18,9 @@ class TencentProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.SPEECH2TEXT) - model_instance.validate_credentials( - model='tencent', - credentials=credentials - ) + model_instance.validate_credentials(model="tencent", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index bb802d4071..b96d43979e 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -22,16 +22,21 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.together.xyz/v1" + credentials["endpoint_url"] = "https://api.together.xyz/v1" return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, 
stream, user) @@ -41,12 +46,22 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, cred_with_endpoint) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -61,45 +76,45 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")), - ModelPropertyKey.MODE: cred_with_endpoint.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get("context_size", "4096")), + ModelPropertyKey.MODE: cred_with_endpoint.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('temperature', 0.7)), + default=float(cred_with_endpoint.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('top_p', 1)), + default=float(cred_with_endpoint.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=TOP_K, label=I18nObject(en_US="Top K"), type=ParameterType.INT, - default=int(cred_with_endpoint.get('top_k', 50)), + default=int(cred_with_endpoint.get("top_k", 50)), min=-2147483647, max=2147483647, - precision=0 + precision=0, ), ParameterRule( name=REPETITION_PENALTY, label=I18nObject(en_US="Repetition Penalty"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('repetition_penalty', 1)), + default=float(cred_with_endpoint.get("repetition_penalty", 1)), min=-3.4, max=3.4, - precision=1 + precision=1, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -107,46 +122,49 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)), + max=int(cred_with_endpoint.get("max_tokens_to_sample", 4096)), ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - 
default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ], pricing=PriceConfig( - input=Decimal(cred_with_endpoint.get('input_price', 0)), - output=Decimal(cred_with_endpoint.get('output_price', 0)), - unit=Decimal(cred_with_endpoint.get('unit', 0)), - currency=cred_with_endpoint.get('currency', "USD") + input=Decimal(cred_with_endpoint.get("input_price", 0)), + output=Decimal(cred_with_endpoint.get("output_price", 0)), + unit=Decimal(cred_with_endpoint.get("unit", 0)), + currency=cred_with_endpoint.get("currency", "USD"), ), ) - if cred_with_endpoint['mode'] == 'chat': + if cred_with_endpoint["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif cred_with_endpoint['mode'] == 'completion': + elif cred_with_endpoint["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}") return entity - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) - - diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index ffce4794e7..aa4100a7c9 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class TogetherAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index fab18b41fd..8a50c7aa05 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -21,7 +21,7 @@ class _CommonTongyi: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: credentials_kwargs = { - "dashscope_api_key": credentials['dashscope_api_key'], + "dashscope_api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -51,5 +51,5 @@ class _CommonTongyi: InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 4e1bb0a5a4..72c319d395 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -17,7 +17,6 @@ from dashscope.common.error import ( UnsupportedModel, ) -from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -46,11 +45,17 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class TongyiLargeLanguageModel(LargeLanguageModel): tokenizers = {} - def _invoke(self, model: str, credentials: dict, - prompt_messages: 
list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,90 +69,16 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - # invoke model + # invoke model without code wrapper return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _code_block_mode_wrapper(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \ - -> LLMResult | Generator: - """ - Wrapper for code block mode - """ - block_prompts = """You should always follow the instructions and output a valid {{block}} object. -The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure -if you are not sure about the structure. - - -{{instructions}} - -You should also complete the text started with ``` but not tell ``` directly. -""" - - code_block = model_parameters.get("response_format", "") - if not code_block: - return self._invoke( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user - ) - - model_parameters.pop("response_format") - stop = stop or [] - stop.extend(["\n```", "```\n"]) - block_prompts = block_prompts.replace("{{block}}", code_block) - - # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): - # override the system message - prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", prompt_messages[0].content) - ) - else: - # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} with markdown codeblocks.") - )) - - if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): - # add ```JSON\n to the last message - prompt_messages[-1].content += f"\n```{code_block}\n" - else: - # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) - - response = self._invoke( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user - ) - - if isinstance(response, Generator): - return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=response - ) - - return response - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: 
Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -157,10 +88,10 @@ You should also complete the text started with ``` but not tell ``` directly. :param tools: tools for tool calling :return: """ - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') - if model == 'farui-plus': - model = 'qwen-farui-plus' + if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + model = model.replace("-chat", "") + if model == "farui-plus": + model = "qwen-farui-plus" if model in self.tokenizers: tokenizer = self.tokenizers[model] @@ -191,16 +122,22 @@ You should also complete the text started with ``` but not tell ``` directly. model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -219,18 +156,18 @@ You should also complete the text started with ``` but not tell ``` directly. mode = self.get_model_mode(model, credentials) - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') + if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + model = model.replace("-chat", "") extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = self._convert_tools(tools) + extra_model_kwargs["tools"] = self._convert_tools(tools) if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop params = { - 'model': model, + "model": model, **model_parameters, **credentials_kwargs, **extra_model_kwargs, @@ -238,23 +175,22 @@ You should also complete the text started with ``` but not tell ``` directly. 
model_schema = self.get_model_schema(model, credentials) if ModelFeature.VISION in (model_schema.features or []): - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) response = MultiModalConversation.call(**params, stream=stream) else: # nothing different between chat model and completion model in tongyi - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) - response = Generation.call(**params, - result_format='message', - stream=stream) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + response = Generation.call(**params, result_format="message", stream=stream) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -265,9 +201,7 @@ You should also complete the text started with ``` but not tell ``` directly. :return: llm response """ if response.status_code != 200 and response.status_code != HTTPStatus.OK: - raise ServiceUnavailableError( - response.message - ) + raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=response.output.choices[0].message.content, @@ -286,9 +220,13 @@ You should also complete the text started with ``` but not tell ``` directly. return result - def _handle_generate_stream_response(self, model: str, credentials: dict, - responses: Generator[GenerationResponse, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + responses: Generator[GenerationResponse, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -298,7 +236,7 @@ You should also complete the text started with ``` but not tell ``` directly. :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" tool_calls = [] for index, response in enumerate(responses): if response.status_code != 200 and response.status_code != HTTPStatus.OK: @@ -309,22 +247,22 @@ You should also complete the text started with ``` but not tell ``` directly. 
resp_finish_reason = response.output.choices[0].finish_reason - if resp_finish_reason is not None and resp_finish_reason != 'null': + if resp_finish_reason is not None and resp_finish_reason != "null": resp_content = response.output.choices[0].message.content assistant_prompt_message = AssistantPromptMessage( - content='', + content="", ) - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] elif resp_content: # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message - assistant_prompt_message.content = resp_content.replace(full_text, '', 1) + assistant_prompt_message.content = resp_content.replace(full_text, "", 1) full_text = resp_content @@ -332,12 +270,11 @@ You should also complete the text started with ``` but not tell ``` directly. message_tool_calls = [] for tool_call_obj in tool_calls: message_tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_obj['function']['name'], - type='function', + id=tool_call_obj["function"]["name"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call_obj['function']['name'], - arguments=tool_call_obj['function']['arguments'] - ) + name=tool_call_obj["function"]["name"], arguments=tool_call_obj["function"]["arguments"] + ), ) message_tool_calls.append(message_tool_call) @@ -351,26 +288,23 @@ You should also complete the text started with ``` but not tell ``` directly. model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=resp_finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=resp_finish_reason, usage=usage + ), ) else: resp_content = response.output.choices[0].message.content if not resp_content: - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] continue # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=resp_content.replace(full_text, '', 1), + content=resp_content.replace(full_text, "", 1), ) full_text = resp_content @@ -378,10 +312,7 @@ You should also complete the text started with ``` but not tell ``` directly. yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -392,7 +323,7 @@ You should also complete the text started with ``` but not tell ``` directly. :return: """ credentials_kwargs = { - "api_key": credentials['dashscope_api_key'], + "api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -437,16 +368,14 @@ You should also complete the text started with ``` but not tell ``` directly. 
""" messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage], - rich_content: bool = False) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages( + self, prompt_messages: list[PromptMessage], rich_content: bool = False + ) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -456,24 +385,28 @@ You should also complete the text started with ``` but not tell ``` directly. tongyi_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): - tongyi_messages.append({ - 'role': 'system', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "system", + "content": prompt_message.content if not rich_content else [{"text": prompt_message.content}], + } + ) elif isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, str): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "user", + "content": prompt_message.content + if not rich_content + else [{"text": prompt_message.content}], + } + ) else: sub_messages = [] for message_content in prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -483,35 +416,25 @@ You should also complete the text started with ``` but not tell ``` directly. 
# convert image base64 data to file in /tmp image_url = self._save_base64_image_to_file(message_content.data) - sub_message_dict = { - "image": image_url - } + sub_message_dict = {"image": image_url} sub_messages.append(sub_message_dict) # resort sub_messages to ensure text is always at last - sub_messages = sorted(sub_messages, key=lambda x: 'text' in x) + sub_messages = sorted(sub_messages, key=lambda x: "text" in x) - tongyi_messages.append({ - 'role': 'user', - 'content': sub_messages - }) + tongyi_messages.append({"role": "user", "content": sub_messages}) elif isinstance(prompt_message, AssistantPromptMessage): content = prompt_message.content if not content: - content = ' ' - message = { - 'role': 'assistant', - 'content': content if not rich_content else [{"text": content}] - } + content = " " + message = {"role": "assistant", "content": content if not rich_content else [{"text": content}]} if prompt_message.tool_calls: - message['tool_calls'] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] + message["tool_calls"] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] tongyi_messages.append(message) elif isinstance(prompt_message, ToolPromptMessage): - tongyi_messages.append({ - "role": "tool", - "content": prompt_message.content, - "name": prompt_message.tool_call_id - }) + tongyi_messages.append( + {"role": "tool", "content": prompt_message.content, "name": prompt_message.tool_call_id} + ) else: raise ValueError(f"Got unknown type {prompt_message}") @@ -526,7 +449,7 @@ You should also complete the text started with ``` but not tell ``` directly. :return: image file path """ # get mime type and encoded string - mime_type, encoded_string = base64_image.split(',')[0].split(';')[0].split(':')[1], base64_image.split(',')[1] + mime_type, encoded_string = base64_image.split(",")[0].split(";")[0].split(":")[1], base64_image.split(",")[1] # save image to file temp_dir = tempfile.gettempdir() @@ -544,19 +467,18 @@ You should also complete the text started with ``` but not tell ``` directly. """ tool_definitions = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] properties_definitions = {} for p_key, p_val in properties.items(): - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]" properties_definitions[p_key] = { - 'description': desc, - 'type': p_val['type'], + "description": desc, + "type": p_val["type"], } tool_definition = { @@ -565,8 +487,8 @@ You should also complete the text started with ``` but not tell ``` directly. "name": tool.name, "description": tool.description, "parameters": properties_definitions, - "required": required_properties - } + "required": required_properties, + }, } tool_definitions.append(tool_definition) @@ -598,5 +520,5 @@ You should also complete the text started with ``` but not tell ``` directly. 
InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml index b2cf3dd486..33b3435eb6 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml @@ -24,7 +24,7 @@ parameter_rules: type: int default: 2000 min: 1 - max: 2000 + max: 6000 help: zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 97dcb72f7c..5783d2e383 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -46,7 +46,6 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -71,12 +70,8 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=used_tokens - ) - return TextEmbeddingResult( - embeddings=batched_embeddings, usage=usage, model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -108,16 +103,12 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents( - credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] - ) + self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents( - credentials_kwargs: dict, model: str, texts: list[str] - ) -> tuple[list[list[float]], int]: + def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. 
Args: @@ -145,7 +136,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): raise ValueError("Embedding data is missing in the response.") else: raise ValueError("Response output is missing or does not contain embeddings.") - + if response.usage and "total_tokens" in response.usage: embedding_used_tokens += response.usage["total_tokens"] else: @@ -153,9 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def _calc_response_usage( - self, model: str, credentials: dict, tokens: int - ) -> EmbeddingUsage: + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.py b/api/core/model_runtime/model_providers/tongyi/tongyi.py index d5e25e6ecf..a084512de9 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.py +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.py @@ -20,12 +20,9 @@ class TongyiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `qwen-turbo` model for validate, - model_instance.validate_credentials( - model='qwen-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="qwen-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 664b02cd92..48a38897a8 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -18,8 +18,9 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model. 
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -31,14 +32,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -53,14 +52,13 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -82,15 +80,21 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): else: sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl)) for sentence in sentences: - SpeechSynthesizer.call(model=v, sample_rate=16000, - api_key=api_key, - text=sentence.strip(), - callback=cb, - format=at, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) + SpeechSynthesizer.call( + model=v, + sample_rate=16000, + api_key=api_key, + text=sentence.strip(), + callback=cb, + format=at, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) - threading.Thread(target=invoke_remote, args=( - content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start() + threading.Thread( + target=invoke_remote, + args=(content_text, voice, credentials.get("dashscope_api_key"), callback, audio_type, word_limit), + ).start() while True: audio = audio_queue.get() @@ -112,16 +116,18 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param audio_type: audio file type :return: text translated to audio file """ - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - api_key=credentials.get('dashscope_api_key'), - text=sentence.strip(), - format=audio_type) + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, + sample_rate=48000, + api_key=credentials.get("dashscope_api_key"), + text=sentence.strip(), + format=audio_type, + ) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() class Callback(ResultCallback): - def __init__(self, queue: Queue): self._queue = queue diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py index 95272a41c2..cf7e3f14be 100644 --- 
a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -33,198 +33,223 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class TritonInferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={}, stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={}, + stream=False, + ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during connection: {str(ex)}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause TritonInference LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + since TritonInference LLM is a customized model, we cannot detect which tokenizer to use + so we just take the GPT2 tokenizer as the default """ return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): - text += f'User: {item.content}' + text += f"User: {item.content}" elif isinstance(item, SystemPromptMessage): - text += f'System: {item.content}' + text += f"System: {item.content}" elif isinstance(item, 
AssistantPromptMessage):
-            text += f'Assistant: {item.content}'
+            text += f"Assistant: {item.content}"
         else:
-            raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
+            raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
 
         return text
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
-            used to define customizable model schema
+        used to define customizable model schema
         """
         rules = [
             ParameterRule(
-                name='temperature',
+                name="temperature",
                 type=ParameterType.FLOAT,
-                use_template='temperature',
-                label=I18nObject(
-                    zh_Hans='温度',
-                    en_US='Temperature'
-                ),
+                use_template="temperature",
+                label=I18nObject(zh_Hans="温度", en_US="Temperature"),
             ),
             ParameterRule(
-                name='top_p',
+                name="top_p",
                 type=ParameterType.FLOAT,
-                use_template='top_p',
-                label=I18nObject(
-                    zh_Hans='Top P',
-                    en_US='Top P'
-                )
+                use_template="top_p",
+                label=I18nObject(zh_Hans="Top P", en_US="Top P"),
             ),
             ParameterRule(
-                name='max_tokens',
+                name="max_tokens",
                 type=ParameterType.INT,
-                use_template='max_tokens',
+                use_template="max_tokens",
                 min=1,
-                max=int(credentials.get('context_length', 2048)),
-                default=min(512, int(credentials.get('context_length', 2048))),
-                label=I18nObject(
-                    zh_Hans='最大生成长度',
-                    en_US='Max Tokens'
-                )
-            )
+                max=int(credentials.get("context_length", 2048)),
+                default=min(512, int(credentials.get("context_length", 2048))),
+                label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
+            ),
         ]
 
         completion_type = None
 
-        if 'completion_type' in credentials:
-            if credentials['completion_type'] == 'chat':
+        if "completion_type" in credentials:
+            if credentials["completion_type"] == "chat":
                 completion_type = LLMMode.CHAT.value
-            elif credentials['completion_type'] == 'completion':
+            elif credentials["completion_type"] == "completion":
                 completion_type = LLMMode.COMPLETION.value
             else:
                 raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
-    
+
         entity = AIModelEntity(
             model=model,
-            label=I18nObject(
-                en_US=model
-            ),
+            label=I18nObject(en_US=model),
             parameter_rules=rules,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_length", 2048)),
             },
         )
 
         return entity
-    
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict,
-                  tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
+
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+    ) -> LLMResult | Generator:
         """
-            generate text from LLM
+        generate text from LLM
         """
-        if 'server_url' not in credentials:
-            raise CredentialsValidateFailedError('server_url is required in credentials')
-
-        if 'stream' in credentials and not bool(credentials['stream']) and stream:
-            raise ValueError(f'stream is not supported by model {model}')
+        if "server_url" not in credentials:
+            raise CredentialsValidateFailedError("server_url is required in credentials")
+
+        if "stream" in credentials and not bool(credentials["stream"]) and stream:
+            raise ValueError(f"stream is not supported by model {model}")
 
         try:
             parameters = {}
 
-            if 'temperature' in model_parameters:
-                parameters['temperature'] = model_parameters['temperature']
-            if 'top_p' in model_parameters:
-                parameters['top_p'] = model_parameters['top_p']
-            if 'top_k' in model_parameters:
-                parameters['top_k'] = model_parameters['top_k']
-            if 'presence_penalty' in model_parameters:
-                parameters['presence_penalty'] = model_parameters['presence_penalty']
-            if 'frequency_penalty' in model_parameters:
-                parameters['frequency_penalty'] = model_parameters['frequency_penalty']
+            if "temperature" in model_parameters:
+                parameters["temperature"] = model_parameters["temperature"]
+            if "top_p" in model_parameters:
+                parameters["top_p"] = model_parameters["top_p"]
+            if "top_k" in model_parameters:
+                parameters["top_k"] = model_parameters["top_k"]
+            if "presence_penalty" in model_parameters:
+                parameters["presence_penalty"] = model_parameters["presence_penalty"]
+            if "frequency_penalty" in model_parameters:
+                parameters["frequency_penalty"] = model_parameters["frequency_penalty"]
 
-            response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
-                'text_input': self._convert_prompt_message_to_text(prompt_messages),
-                'max_tokens': model_parameters.get('max_tokens', 512),
-                'parameters': {
-                    'stream': False,
-                    **parameters
+            response = post(
+                str(URL(credentials["server_url"]) / "v2" / "models" / model / "generate"),
+                json={
+                    "text_input": self._convert_prompt_message_to_text(prompt_messages),
+                    "max_tokens": model_parameters.get("max_tokens", 512),
+                    "parameters": {"stream": False, **parameters},
                 },
-            }, timeout=(10, 120))
+                timeout=(10, 120),
+            )
             response.raise_for_status()
             if response.status_code != 200:
-                raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
-
+                raise InvokeBadRequestError(f"Invoke failed with status code {response.status_code}, {response.text}")
+
             if stream:
-                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                         tools=tools, resp=response)
-            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                                                       tools=tools, resp=response)
+                return self._handle_chat_stream_response(
+                    model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response
+                )
+            return self._handle_chat_generate_response(
+                model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response
+            )
         except Exception as ex:
-            raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
-
-    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                       tools: list[PromptMessageTool],
-                                       resp: Response) -> LLMResult:
+            raise InvokeConnectionError(f"An error occurred during connection: {str(ex)}")
+
+    def _handle_chat_generate_response(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: list[PromptMessageTool],
+        resp: Response,
+    ) -> LLMResult:
         """
-            handle normal chat generate response
+        handle normal chat generate response
         """
-        text = resp.json()['text_output']
+        text = resp.json()["text_output"]
 
         usage = LLMUsage.empty_usage()
         usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
         usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
 
         return LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=text
-            ),
-            usage=usage
+            model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
         )
 
-    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                     tools: list[PromptMessageTool],
-                                     resp: Response) -> Generator:
+    def _handle_chat_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: list[PromptMessageTool],
+        resp: Response,
+    ) -> Generator:
         """
-            handle normal chat generate response
+        handle normal chat generate response
         """
-        text = resp.json()['text_output']
+        text = resp.json()["text_output"]
 
         usage = LLMUsage.empty_usage()
         usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -233,13 +258,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
         yield LLMResultChunk(
             model=model,
             prompt_messages=prompt_messages,
-            delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(
-                    content=text
-                ),
-                usage=usage
-            )
+            delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text), usage=usage),
         )
 
     @property
@@ -253,15 +272,9 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-            ],
-            InvokeServerUnavailableError: [
-            ],
-            InvokeRateLimitError: [
-            ],
-            InvokeAuthorizationError: [
-            ],
-            InvokeBadRequestError: [
-                ValueError
-            ]
-        }
\ No newline at end of file
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: [ValueError],
+        }
diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
index 06846825ab..d85f7c82e7 100644
--- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
+++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
@@ -4,6 +4,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
 
 logger = logging.getLogger(__name__)
 
+
 class XinferenceAIProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         pass
diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py
index 13b73181e9..47ebaccd84 100644
--- a/api/core/model_runtime/model_providers/upstage/_common.py
+++ b/api/core/model_runtime/model_providers/upstage/_common.py
@@ -1,4 +1,3 @@
-
 from collections.abc import Mapping
 
 import openai
@@ -20,13 +19,13 @@ class _CommonUpstage:
         Transform credentials to kwargs for model instance
 
         :param credentials:
-        :return: 
+        :return:
         """
         credentials_kwargs = {
-            "api_key": credentials['upstage_api_key'],
+            "api_key": credentials["upstage_api_key"],
             "base_url": "https://api.upstage.ai/v1/solar",
             "timeout": Timeout(315.0, read=300.0, write=20.0, connect=10.0),
-            "max_retries": 1
+            "max_retries": 1,
         }
 
         return credentials_kwargs
@@ -53,5 +52,3 @@ class _CommonUpstage:
                 openai.APIError,
             ],
         }
-
-
diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py
index d1ed4619d6..1014b53f39 100644
--- a/api/core/model_runtime/model_providers/upstage/llm/llm.py
+++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py
@@ -36,15 +36,23 @@ if you are not sure about the structure.
""" + class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ - Model class for Upstage large language model. + Model class for Upstage large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -67,15 +75,25 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, - model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: stop = stop or [] self._transform_chat_json_prompts( model=model, @@ -86,9 +104,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -98,15 +116,23 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ - Transform json prompts + Transform json prompts """ if stop is None: stop = [] @@ -117,20 +143,29 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): prompt_messages[0] = SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", 
response_format) + content=UPSTAGE_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: - prompt_messages.insert(0, SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=UPSTAGE_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -155,30 +190,31 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): client = OpenAI(**credentials_kwargs) client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=10, - stream=False + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False ) except Exception as e: raise CredentialsValidateFailedError(str(e)) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) extra_model_kwargs = {} if tools: - extra_model_kwargs["functions"] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: extra_model_kwargs["stop"] = stop @@ -198,10 +234,15 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -222,10 +263,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tool_calls = [function_call] if function_call 
else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -251,9 +289,14 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -263,7 +306,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None prompt_tokens = 0 completion_tokens = 0 @@ -273,8 +316,8 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -288,8 +331,11 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -311,7 +357,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -323,11 +369,10 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -338,7 +383,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -348,7 +393,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -356,8 +401,7 @@ class 
UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -367,9 +411,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -380,21 +424,19 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -404,14 +446,11 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -429,19 +468,13 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -467,11 +500,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": 
message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -483,16 +512,17 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for solar with Huggingface Solar tokenizer. - Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer + Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer """ tokenizer = self._get_tokenizer() - tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> - tokens_prefix = 1 # <|startoftext|> - tokens_suffix = 3 # <|im_start|>assistant\n + tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> + tokens_prefix = 1 # <|startoftext|> + tokens_suffix = 3 # <|im_start|>assistant\n num_tokens = 0 num_tokens += tokens_prefix @@ -502,10 +532,10 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "tool_calls": @@ -538,37 +568,37 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(tokenizer.encode('type')) - num_tokens += len(tokenizer.encode('function')) + num_tokens += len(tokenizer.encode("type")) + num_tokens += len(tokenizer.encode("function")) # calculate num tokens for function object - num_tokens += len(tokenizer.encode('name')) + num_tokens += len(tokenizer.encode("name")) num_tokens += len(tokenizer.encode(tool.name)) - num_tokens += len(tokenizer.encode('description')) + num_tokens += len(tokenizer.encode("description")) num_tokens += len(tokenizer.encode(tool.description)) parameters = tool.parameters - num_tokens += len(tokenizer.encode('parameters')) - if 'title' in parameters: - num_tokens += len(tokenizer.encode('title')) + num_tokens += len(tokenizer.encode("parameters")) + if "title" in parameters: + num_tokens += len(tokenizer.encode("title")) num_tokens += len(tokenizer.encode(parameters.get("title"))) - num_tokens += len(tokenizer.encode('type')) + num_tokens += len(tokenizer.encode("type")) num_tokens += len(tokenizer.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(tokenizer.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(tokenizer.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(tokenizer.encode(key)) for field_key, field_value in value.items(): num_tokens += len(tokenizer.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(tokenizer.encode(enum_field)) else: num_tokens += len(tokenizer.encode(field_key)) num_tokens += len(tokenizer.encode(str(field_value))) - if 'required' in parameters: - num_tokens += 
len(tokenizer.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(tokenizer.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(tokenizer.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 05ae8665d6..edd4a36d98 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -18,6 +18,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): """ Model class for Upstage text embedding model. """ + def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") @@ -53,9 +54,9 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): for i, text in enumerate(texts): token = tokenizer.encode(text, add_special_tokens=False).tokens for j in range(0, len(token), context_size): - tokens += [token[j:j+context_size]] + tokens += [token[j : j + context_size]] indices += [i] - + batched_embeddings = [] _iter = range(0, len(tokens), max_chunks) @@ -63,20 +64,20 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i:i+max_chunks], + texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch - + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) - + for i in range(len(texts)): _result = results[i] if len(_result) == 0: @@ -91,15 +92,11 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() - - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: tokenizer = self._get_tokenizer() """ @@ -122,7 +119,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): total_num_tokens += len(tokenized_text) return total_num_tokens - + def validate_credentials(self, model: str, credentials: Mapping) -> None: """ Validate model credentials @@ -137,16 +134,13 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + + def _embedding_invoke( + self, model: str, client: OpenAI, texts: 
+    ) -> tuple[list[list[float]], int]:
         """
         Invoke embedding model
 
         :param model: model name
@@ -155,17 +149,19 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
         :param extra_model_kwargs: extra model kwargs
         :return: embeddings and used tokens
         """
-        response = client.embeddings.create(
-            model=model,
-            input=texts,
-            **extra_model_kwargs
-        )
+        response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs)
+
+        if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64":
+            return (
+                [
+                    list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32))
+                    for embedding in response.data
+                ],
+                response.usage.total_tokens,
+            )
 
-        if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
-            return ([list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) for embedding in response.data], response.usage.total_tokens)
-
         return [data.embedding for data in response.data], response.usage.total_tokens
-    
+
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
         Calculate response usage
@@ -176,10 +172,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
         :return: usage
         """
         input_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            tokens=tokens,
-            price_type=PriceType.INPUT
+            model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT
         )
 
         usage = EmbeddingUsage(
@@ -189,7 +182,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
             price_unit=input_price_info.unit,
             total_price=input_price_info.total_amount,
             currency=input_price_info.currency,
-            latency=time.perf_counter() - self.started_at
+            latency=time.perf_counter() - self.started_at,
         )
 
         return usage
diff --git a/api/core/model_runtime/model_providers/upstage/upstage.py b/api/core/model_runtime/model_providers/upstage/upstage.py
index 56c91c0061..e45d4aae19 100644
--- a/api/core/model_runtime/model_providers/upstage/upstage.py
+++ b/api/core/model_runtime/model_providers/upstage/upstage.py
@@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
 
 
 class UpstageProvider(ModelProvider):
-
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
@@ -19,14 +18,10 @@ class UpstageProvider(ModelProvider):
         try:
             model_instance = self.get_model_instance(ModelType.LLM)
 
-            model_instance.validate_credentials(
-                model="solar-1-mini-chat",
-                credentials=credentials
-            )
+            model_instance.validate_credentials(model="solar-1-mini-chat", credentials=credentials)
         except CredentialsValidateFailedError as e:
-            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
             raise e
         except Exception as e:
-            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
             raise e
-
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
index 1a7368a2cf..09a7f53f28 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -49,12 +49,17 @@ logger = logging.getLogger(__name__)
 
 
 class VertexAiLargeLanguageModel(LargeLanguageModel):
-
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                stream: bool = True, user: Optional[str] = None) \
-            -> Union[LLMResult, Generator]:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -74,8 +79,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         # invoke Gemini model
         return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
-    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate_anthropic(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke Anthropic large language model
 
@@ -92,7 +105,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
         project_id = credentials["vertex_project_id"]
         SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
-        token = ''
+        token = ""
 
         # get access token from service account credential
         if service_account_info:
@@ -102,40 +115,32 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             token = credentials.token
 
         # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region
-        if 'opus' or 'claude-3-5-sonnet' in model:
-            location = 'us-east5'
+        if "opus" in model or "claude-3-5-sonnet" in model:
+            location = "us-east5"
         else:
-            location = 'us-central1'
-        
+            location = "us-central1"
+
         # use access token to authenticate
         if token:
-            client = AnthropicVertex(
-                region=location,
-                project_id=project_id,
-                access_token=token
-            )
+            client = AnthropicVertex(region=location, project_id=project_id, access_token=token)
         # When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable
         else:
             client = AnthropicVertex(
-                region=location, 
+                region=location,
                 project_id=project_id,
             )
 
         extra_model_kwargs = {}
         if stop:
-            extra_model_kwargs['stop_sequences'] = stop
+            extra_model_kwargs["stop_sequences"] = stop
 
         system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
 
         if system:
-            extra_model_kwargs['system'] = system
+            extra_model_kwargs["system"] = system
 
         response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
+            model=model, messages=prompt_message_dicts, stream=stream, **model_parameters, **extra_model_kwargs
         )
 
         if stream:
@@ -143,8 +148,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         return self._handle_claude_response(model, credentials, response, prompt_messages)
 
-    def _handle_claude_response(self, model: str, credentials: dict, response: Message,
-                                prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_claude_response(
+        self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage]
+    ) -> LLMResult:
         """
         Handle llm chat response
 
@@ -156,9 +162,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         """
 
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
-        )
+        assistant_prompt_message = AssistantPromptMessage(content=response.content[0].text)
 
         # calculate num tokens
         if response.usage:
@@ -175,16 +179,18 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         # transform response
         response = LLMResult(
-            model=response.model,
-            prompt_messages=prompt_messages,
-            message=assistant_prompt_message,
-            usage=usage
+            model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
         )
 
         return response
 
-    def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
-                                       prompt_messages: list[PromptMessage], ) -> Generator:
+    def _handle_claude_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: Stream[MessageStreamEvent],
+        prompt_messages: list[PromptMessage],
+    ) -> Generator:
         """
         Handle llm chat stream response
 
@@ -196,7 +202,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         """
 
         try:
-            full_assistant_content = ''
+            full_assistant_content = ""
             return_model = None
             input_tokens = 0
             output_tokens = 0
@@ -217,18 +223,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
                             index=index + 1,
-                            message=AssistantPromptMessage(
-                                content=''
-                            ),
+                            message=AssistantPromptMessage(content=""),
                             finish_reason=finish_reason,
-                            usage=usage
-                        )
+                            usage=usage,
+                        ),
                     )
                 elif isinstance(chunk, ContentBlockDeltaEvent):
-                    chunk_text = chunk.delta.text if chunk.delta.text else ''
+                    chunk_text = chunk.delta.text if chunk.delta.text else ""
                     full_assistant_content += chunk_text
                     assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else '',
+                        content=chunk_text if chunk_text else "",
                     )
                     index = chunk.index
                     yield LLMResultChunk(
@@ -237,12 +241,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                         delta=LLMResultChunkDelta(
                             index=index,
                             message=assistant_prompt_message,
-                        )
+                        ),
                     )
         except Exception as ex:
             raise InvokeError(str(ex))
 
-    def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
+    def _calc_claude_response_usage(
+        self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
+    ) -> LLMUsage:
         """
         Calculate response usage
 
@@ -262,10 +268,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         # get completion price info
         completion_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.OUTPUT,
-            tokens=completion_tokens
+            model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
         )
 
         # transform usage
@@ -281,7 +284,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             total_tokens=prompt_tokens + completion_tokens,
             total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
             currency=prompt_price_info.currency,
-            latency=time.perf_counter() - self.started_at
+            latency=time.perf_counter() - self.started_at,
         )
 
         return usage
@@ -295,13 +298,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
-                message.content=message.content.strip()
+                message.content = message.content.strip()
                 if first_loop:
-                    system=message.content
-                    first_loop=False
+                    system = message.content
+                    first_loop = False
                 else:
-                    system+="\n"
-                    system+=message.content
+                    system += "\n"
+                    system += message.content
 
         prompt_message_dicts = []
         for message in prompt_messages:
@@ -323,10 +326,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
-                        sub_message_dict = {
-                            "type": "text",
-                            "text": message_content.data
-                        }
+                        sub_message_dict = {"type": "text", "text": message_content.data}
                         sub_messages.append(sub_message_dict)
                     elif message_content.type == PromptMessageContentType.IMAGE:
                         message_content = cast(ImagePromptMessageContent, message_content)
@@ -336,7 +336,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                                 image_content = requests.get(message_content.data).content
                                 with Image.open(io.BytesIO(image_content)) as img:
                                     mime_type = f"image/{img.format.lower()}"
-                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                                base64_data = base64.b64encode(image_content).decode("utf-8")
                             except Exception as ex:
                                 raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                         else:
@@ -345,16 +345,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                             base64_data = data_split[1]
 
                         if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
-                            raise ValueError(f"Unsupported image type {mime_type}, "
-                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
+                            raise ValueError(
+                                f"Unsupported image type {mime_type}, "
+                                f"only support image/jpeg, image/png, image/gif, and image/webp"
+                            )
 
                         sub_message_dict = {
                             "type": "image",
-                            "source": {
-                                "type": "base64",
-                                "media_type": mime_type,
-                                "data": base64_data
-                            }
+                            "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                         }
                         sub_messages.append(sub_message_dict)
 
@@ -370,8 +368,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         return message_dict
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
         """
         Get number of tokens for given prompt messages
 
@@ -384,7 +387,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         prompt = self._convert_messages_to_prompt(prompt_messages)
 
         return self._get_num_tokens_by_gpt2(prompt)
-    
+
     def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
         """
         Format a list of messages into a full prompt for the Google model
 
@@ -394,13 +397,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         """
         messages = messages.copy()  # don't mutate the original list
 
-        text = "".join(
-            self._convert_one_message_to_text(message)
-            for message in messages
-        )
+        text = "".join(self._convert_one_message_to_text(message) for message in messages)
 
         return text.rstrip()
-    
+
     def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
         """
         Convert tool messages to glm tools
@@ -416,14 +416,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                     type=glm.Type.OBJECT,
                     properties={
                         key: {
-                            'type_': value.get('type', 'string').upper(),
-                            'description': value.get('description', ''),
-                            'enum': value.get('enum', [])
-                        } for key, value in tool.parameters.get('properties', {}).items()
+                            "type_": value.get("type", "string").upper(),
+                            "description": value.get("description", ""),
+                            "enum": value.get("enum", []),
+                        }
+                        for key, value in tool.parameters.get("properties", {}).items()
                     },
-                    required=tool.parameters.get('required', [])
+                    required=tool.parameters.get("required", []),
                 ),
-            ) for tool in tools
+            )
+            for tool in tools
         ]
     )
 
@@ -435,20 +437,25 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :param credentials: model credentials
         :return:
         """
-        
+
         try:
             ping_message = SystemPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
-        
+
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                  stream: bool = True, user: Optional[str] = None
-                  ) -> Union[LLMResult, Generator]:
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -462,7 +469,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
+        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
 
         if stop:
             config_kwargs["stop_sequences"] = stop
@@ -494,26 +501,21 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             else:
                 history.append(content)
 
-        safety_settings={
+        safety_settings = {
             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
         }
 
-        google_model = glm.GenerativeModel(
-            model_name=model,
-            system_instruction=system_instruction
-        )
+        google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction)
 
         response = google_model.generate_content(
             contents=history,
-            generation_config=glm.GenerationConfig(
-                **config_kwargs
-            ),
+            generation_config=glm.GenerationConfig(**config_kwargs),
             stream=stream,
             safety_settings=safety_settings,
-            tools=self._convert_tools_to_glm_tool(tools) if tools else None
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
         )
 
         if stream:
@@ -521,8 +523,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+    ) -> LLMResult:
         """
         Handle llm response
 
@@ -533,9 +536,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :return: llm response
         """
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=response.candidates[0].content.parts[0].text
-        )
+        assistant_prompt_message = AssistantPromptMessage(content=response.candidates[0].content.parts[0].text)
 
         # calculate num tokens
         prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -554,8 +555,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+    ) -> Generator:
         """
         Handle llm stream response
 
@@ -568,9 +570,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         index = -1
         for chunk in response:
             for part in chunk.candidates[0].content.parts:
-                assistant_prompt_message = AssistantPromptMessage(
-                    content=''
-                )
+                assistant_prompt_message = AssistantPromptMessage(content="")
 
                 if part.text:
                     assistant_prompt_message.content += part.text
@@ -579,35 +579,31 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                     assistant_prompt_message.tool_calls = [
                         AssistantPromptMessage.ToolCall(
                             id=part.function_call.name,
-                            type='function',
+                            type="function",
                             function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                 name=part.function_call.name,
-                                arguments=json.dumps(dict(part.function_call.args.items()))
-                            )
+                                arguments=json.dumps(dict(part.function_call.args.items())),
+                            ),
                         )
                     ]
 
                 index += 1
-                
-                if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
+
+                if not hasattr(chunk, "finish_reason") or not chunk.finish_reason:
                     # transform assistant message to prompt message
                     yield LLMResultChunk(
                         model=model,
                         prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=index,
-                            message=assistant_prompt_message
-                        )
+                        delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
                     )
                 else:
-                    
                     # calculate num tokens
                     prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
                     completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
                     # transform usage
                     usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-                    
+
                     yield LLMResultChunk(
                         model=model,
                         prompt_messages=prompt_messages,
@@ -615,8 +611,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                             index=index,
                             message=assistant_prompt_message,
                             finish_reason=chunk.candidates[0].finish_reason,
-                            usage=usage
-                        )
+                            usage=usage,
+                        ),
                     )
 
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
@@ -631,9 +627,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         content = message.content
 
         if isinstance(content, list):
-            content = "".join(
-                c.data for c in content if c.type != PromptMessageContentType.IMAGE
-            )
+            content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
 
         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"
@@ -658,7 +652,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         if isinstance(message, UserPromptMessage):
             glm_content = glm.Content(role="user", parts=[])
-            if (isinstance(message.content, str)):
+            if isinstance(message.content, str):
                 glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
             else:
                 parts = []
@@ -666,8 +660,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                     if c.type == PromptMessageContentType.TEXT:
                         parts.append(glm.Part.from_text(c.data))
                     else:
-                        metadata, data = c.data.split(',', 1)
-                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        metadata, data = c.data.split(",", 1)
+                        mime_type = metadata.split(";", 1)[0].split(":")[1]
                         parts.append(glm.Part.from_data(mime_type=mime_type, data=data))
                 glm_content = glm.Content(role="user", parts=parts)
             return glm_content
@@ -675,52 +669,58 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             if message.content:
                 glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
             if message.tool_calls:
-                glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
-                    name=message.tool_calls[0].function.name,
-                    args=json.loads(message.tool_calls[0].function.arguments),
-                ))])
+                glm_content = glm.Content(
+                    role="model",
+                    parts=[
+                        glm.Part.from_function_response(
+                            glm.FunctionCall(
+                                name=message.tool_calls[0].function.name,
+                                args=json.loads(message.tool_calls[0].function.arguments),
+                            )
+                        )
+                    ],
+                )
             return glm_content
         elif isinstance(message, ToolPromptMessage):
-            glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
-                name=message.name,
-                response={
-                    "response": message.content
-                }
-            ))])
+            glm_content = glm.Content(
+                role="function",
+                parts=[
+                    glm.Part(
+                        function_response=glm.FunctionResponse(
+                            name=message.name, response={"response": message.content}
                        )
+                    )
+                ],
+            )
             return glm_content
         else:
             raise ValueError(f"Got unknown type {message}")
-    
+
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
         Map model invoke error to unified error
-        The key is the ermd = gml.GenerativeModel(model)ror type thrown to the caller
-        The value is the md = gml.GenerativeModel(model)error type thrown by the model,
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
         which needs to be converted into a unified error type for the caller.
 
-        :return: Invoke emd = gml.GenerativeModel(model)rror mapping
+        :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-                exceptions.RetryError
-            ],
+            InvokeConnectionError: [exceptions.RetryError],
             InvokeServerUnavailableError: [
                 exceptions.ServiceUnavailable,
                 exceptions.InternalServerError,
                 exceptions.BadGateway,
                 exceptions.GatewayTimeout,
-                exceptions.DeadlineExceeded
-            ],
-            InvokeRateLimitError: [
-                exceptions.ResourceExhausted,
-                exceptions.TooManyRequests
+                exceptions.DeadlineExceeded,
             ],
+            InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
             InvokeAuthorizationError: [
                 exceptions.Unauthenticated,
                 exceptions.PermissionDenied,
                 exceptions.Unauthenticated,
-                exceptions.Forbidden
+                exceptions.Forbidden,
             ],
             InvokeBadRequestError: [
                 exceptions.BadRequest,
@@ -736,5 +736,5 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                 exceptions.PreconditionFailed,
                 exceptions.RequestRangeNotSatisfiable,
                 exceptions.Cancelled,
-            ]
+            ],
         }
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
index 2404ba5894..519373a7f3 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -29,9 +29,9 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     Model class for Vertex AI text embedding model.
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,23 +51,12 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): client = VertexTextEmbeddingModel.from_pretrained(model) - embeddings_batch, embedding_used_tokens = self._embedding_invoke( - client=client, - texts=texts - ) + embeddings_batch, embedding_used_tokens = self._embedding_invoke(client=client, texts=texts) # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=embedding_used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=embedding_used_tokens) - return TextEmbeddingResult( - embeddings=embeddings_batch, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings_batch, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -115,15 +104,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): client = VertexTextEmbeddingModel.from_pretrained(model) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'] - ) + self._embedding_invoke(model=model, client=client, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore + def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore """ Invoke embedding model @@ -154,10 +139,7 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -168,14 +150,14 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -183,15 +165,15 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py index 
3cbfb088d1..466a86fd36 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -20,12 +20,9 @@ class VertexAiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-1.0-pro-002` model for validate, - model_instance.validate_credentials( - model='gemini-1.0-pro-002', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-1.0-pro-002", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 5100494e58..d6f1356651 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -32,6 +32,9 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) +DEFAULT_V2_ENDPOINT = "maas-api.ml-platform-cn-beijing.volces.com" +DEFAULT_V3_ENDPOINT = "https://ark.cn-beijing.volces.com/api/v3" + class ArkClientV3: endpoint_id: Optional[str] = None @@ -43,33 +46,49 @@ class ArkClientV3: @staticmethod def is_legacy(credentials: dict) -> bool: + # match default v2 endpoint if ArkClientV3.is_compatible_with_legacy(credentials): return False - sdk_version = credentials.get("sdk_version", "v2") - return sdk_version != "v3" + # match default v3 endpoint + if credentials.get("api_endpoint_host") == DEFAULT_V3_ENDPOINT: + return False + # only v3 support api_key + if credentials.get("auth_method") == "api_key": + return False + # these cases are considered as sdk v2 + # - modified default v2 endpoint + # - modified default v3 endpoint and auth without api_key + return True @staticmethod def is_compatible_with_legacy(credentials: dict) -> bool: - sdk_version = credentials.get("sdk_version") endpoint = credentials.get("api_endpoint_host") - return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com" + return endpoint == DEFAULT_V2_ENDPOINT @classmethod def from_credentials(cls, credentials): """Initialize the client using the credentials provided.""" args = { - "base_url": credentials['api_endpoint_host'], - "region": credentials['volc_region'], - "ak": credentials['volc_access_key_id'], - "sk": credentials['volc_secret_access_key'], + "base_url": credentials["api_endpoint_host"], + "region": credentials["volc_region"], } - if cls.is_compatible_with_legacy(credentials): - args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3" + if credentials.get("auth_method") == "api_key": + args = { + **args, + "api_key": credentials["volc_api_key"], + } + else: + args = { + **args, + "ak": credentials["volc_access_key_id"], + "sk": credentials["volc_secret_access_key"], + } - client = ArkClientV3( - **args - ) - client.endpoint_id = credentials['endpoint_id'] + if cls.is_compatible_with_legacy(credentials): + args = {**args, "base_url": DEFAULT_V3_ENDPOINT} + + client = ArkClientV3(**args) + client.endpoint_id = credentials["endpoint_id"] return client @staticmethod @@ -83,54 +102,48 @@ class ArkClientV3: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - 
content.append(ChatCompletionContentPartTextParam( - text=message_content.text, - type='text', - )) + content.append( + ChatCompletionContentPartTextParam( + text=message_content.text, + type="text", + ) + ) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append(ChatCompletionContentPartImageParam( - image_url=ImageURL( - url=image_data, - detail=message_content.detail.value, - ), - type='image_url', - )) - message_dict = ChatCompletionUserMessageParam( - role='user', - content=content - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type="image_url", + ) + ) + message_dict = ChatCompletionUserMessageParam(role="user", content=content) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = ChatCompletionAssistantMessageParam( content=message.content, - role='assistant', - tool_calls=None if not message.tool_calls else [ + role="assistant", + tool_calls=None + if not message.tool_calls + else [ ChatCompletionMessageToolCallParam( id=call.id, - function=Function( - name=call.function.name, - arguments=call.function.arguments - ), - type='function' - ) for call in message.tool_calls - ] + function=Function(name=call.function.name, arguments=call.function.arguments), + type="function", + ) + for call in message.tool_calls + ], ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = ChatCompletionSystemMessageParam( - content=message.content, - role='system' - ) + message_dict = ChatCompletionSystemMessageParam(content=message.content, role="system") elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = ChatCompletionToolMessageParam( - content=message.content, - role='tool', - tool_call_id=message.tool_call_id + content=message.content, role="tool", tool_call_id=message.tool_call_id ) else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -140,23 +153,25 @@ class ArkClientV3: @staticmethod def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: return ChatCompletionToolParam( - type='function', + type="function", function=FunctionDefinition( name=message.name, description=message.description, parameters=message.parameters, - ) + ), ) - def chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> ChatCompletion: + def chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: """Block chat""" return self.ark.chat.completions.create( model=self.endpoint_id, @@ -170,15 +185,17 @@ class ArkClientV3: 
temperature=temperature, ) - def stream_chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> Generator[ChatCompletionChunk]: + def stream_chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: """Stream chat""" chunks = self.ark.chat.completions.create( stream=True, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 1978c11680..266f1216f8 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService class MaaSClient(MaasService): @@ -25,12 +25,12 @@ class MaaSClient(MaasService): self.endpoint_id = endpoint_id @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] + def from_credential(cls, credentials: dict) -> "MaaSClient": + host = credentials["api_endpoint_host"] + region = credentials["volc_region"] + ak = credentials["volc_access_key_id"] + sk = credentials["volc_secret_access_key"] + endpoint_id = credentials["endpoint_id"] client = cls(host, region) client.set_endpoint_id(endpoint_id) @@ -40,8 +40,8 @@ class MaaSClient(MaasService): def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + "parameters": params, + "messages": [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], **extra_model_kwargs, } if not stream: @@ -55,9 +55,7 @@ class MaaSClient(MaasService): ) def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } + req = {"input": texts} return super().embeddings(self.endpoint_id, req) @staticmethod @@ -65,49 +63,40 @@ class MaaSClient(MaasService): if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + message_dict = {"role": ChatRole.USER, "content": message.content} else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + raise ValueError("Content 
object type only support image_url") elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + { + "type": "image_url", + "image_url": { + "url": "", + "image_bytes": image_data, + "detail": message_content.detail, + }, } - }) + ) - message_dict = {'role': ChatRole.USER, 'content': content} + message_dict = {"role": ChatRole.USER, "content": content} elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} + message_dict = {"role": ChatRole.ASSISTANT, "content": message.content} if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict["tool_calls"] = [ + {"name": call.function.name, "arguments": call.function.arguments} for call in message.tool_calls ] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = {"role": ChatRole.SYSTEM, "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = {"role": ChatRole.FUNCTION, "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -117,7 +106,7 @@ class MaaSClient(MaasService): def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: try: resp = fn() - except MaasException as e: + except MaasError as e: raise wrap_error(e) return resp @@ -130,5 +119,5 @@ class MaaSClient(MaasService): "name": tool.name, "description": tool.description, "parameters": tool.parameters, - } + }, } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 21ffaf1258..91dbe21a61 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,144 +1,144 @@ -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError -class ClientSDKRequestError(MaasException): +class ClientSDKRequestError(MaasError): pass -class SignatureDoesNotMatch(MaasException): +class SignatureDoesNotMatchError(MaasError): pass -class RequestTimeout(MaasException): +class RequestTimeoutError(MaasError): pass -class ServiceConnectionTimeout(MaasException): +class ServiceConnectionTimeoutError(MaasError): pass -class MissingAuthenticationHeader(MaasException): +class MissingAuthenticationHeaderError(MaasError): pass -class AuthenticationHeaderIsInvalid(MaasException): +class AuthenticationHeaderIsInvalidError(MaasError): pass -class InternalServiceError(MaasException): +class 
InternalServiceError(MaasError): pass -class MissingParameter(MaasException): +class MissingParameterError(MaasError): pass -class InvalidParameter(MaasException): +class InvalidParameterError(MaasError): pass -class AuthenticationExpire(MaasException): +class AuthenticationExpireError(MaasError): pass -class EndpointIsInvalid(MaasException): +class EndpointIsInvalidError(MaasError): pass -class EndpointIsNotEnable(MaasException): +class EndpointIsNotEnableError(MaasError): pass -class ModelNotSupportStreamMode(MaasException): +class ModelNotSupportStreamModeError(MaasError): pass -class ReqTextExistRisk(MaasException): +class ReqTextExistRiskError(MaasError): pass -class RespTextExistRisk(MaasException): +class RespTextExistRiskError(MaasError): pass -class EndpointRateLimitExceeded(MaasException): +class EndpointRateLimitExceededError(MaasError): pass -class ServiceConnectionRefused(MaasException): +class ServiceConnectionRefusedError(MaasError): pass -class ServiceConnectionClosed(MaasException): +class ServiceConnectionClosedError(MaasError): pass -class UnauthorizedUserForEndpoint(MaasException): +class UnauthorizedUserForEndpointError(MaasError): pass -class InvalidEndpointWithNoURL(MaasException): +class InvalidEndpointWithNoURLError(MaasError): pass -class EndpointAccountRpmRateLimitExceeded(MaasException): +class EndpointAccountRpmRateLimitExceededError(MaasError): pass -class EndpointAccountTpmRateLimitExceeded(MaasException): +class EndpointAccountTpmRateLimitExceededError(MaasError): pass -class ServiceResourceWaitQueueFull(MaasException): +class ServiceResourceWaitQueueFullError(MaasError): pass -class EndpointIsPending(MaasException): +class EndpointIsPendingError(MaasError): pass -class ServiceNotOpen(MaasException): +class ServiceNotOpenError(MaasError): pass AuthErrors = { - 'SignatureDoesNotMatch': SignatureDoesNotMatch, - 'MissingAuthenticationHeader': MissingAuthenticationHeader, - 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, - 'AuthenticationExpire': AuthenticationExpire, - 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, + "SignatureDoesNotMatch": SignatureDoesNotMatchError, + "MissingAuthenticationHeader": MissingAuthenticationHeaderError, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError, + "AuthenticationExpire": AuthenticationExpireError, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError, } BadRequestErrors = { - 'MissingParameter': MissingParameter, - 'InvalidParameter': InvalidParameter, - 'EndpointIsInvalid': EndpointIsInvalid, - 'EndpointIsNotEnable': EndpointIsNotEnable, - 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, - 'ReqTextExistRisk': ReqTextExistRisk, - 'RespTextExistRisk': RespTextExistRisk, - 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, - 'ServiceNotOpen': ServiceNotOpen, + "MissingParameter": MissingParameterError, + "InvalidParameter": InvalidParameterError, + "EndpointIsInvalid": EndpointIsInvalidError, + "EndpointIsNotEnable": EndpointIsNotEnableError, + "ModelNotSupportStreamMode": ModelNotSupportStreamModeError, + "ReqTextExistRisk": ReqTextExistRiskError, + "RespTextExistRisk": RespTextExistRiskError, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError, + "ServiceNotOpen": ServiceNotOpenError, } RateLimitErrors = { - 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, - 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, - 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, + 
"EndpointRateLimitExceeded": EndpointRateLimitExceededError, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError, + "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError, } ServerUnavailableErrors = { - 'InternalServiceError': InternalServiceError, - 'EndpointIsPending': EndpointIsPending, - 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, + "InternalServiceError": InternalServiceError, + "EndpointIsPending": EndpointIsPendingError, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError, } ConnectionErrors = { - 'ClientSDKRequestError': ClientSDKRequestError, - 'RequestTimeout': RequestTimeout, - 'ServiceConnectionTimeout': ServiceConnectionTimeout, - 'ServiceConnectionRefused': ServiceConnectionRefused, - 'ServiceConnectionClosed': ServiceConnectionClosed, + "ClientSDKRequestError": ClientSDKRequestError, + "RequestTimeout": RequestTimeoutError, + "ServiceConnectionTimeout": ServiceConnectionTimeoutError, + "ServiceConnectionRefused": ServiceConnectionRefusedError, + "ServiceConnectionClosed": ServiceConnectionClosedError, } ErrorCodeMap = { @@ -150,7 +150,7 @@ ErrorCodeMap = { } -def wrap_error(e: MaasException) -> Exception: +def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 64f342f16e..8b3eb157be 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole -from .maas import MaasException, MaasService +from .maas import MaasError, MaasService -__all__ = ['MaasService', 'ChatRole', 'MaasException'] +__all__ = ["MaasService", "ChatRole", "MaasError"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 053432a089..7435720252 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -8,12 +8,12 @@ from .util import Util class MetaData: def __init__(self): - self.algorithm = '' - self.credential_scope = '' - self.signed_headers = '' - self.date = '' - self.region = '' - self.service = '' + self.algorithm = "" + self.credential_scope = "" + self.signed_headers = "" + self.date = "" + self.region = "" + self.service = "" def set_date(self, date): self.date = date @@ -36,23 +36,23 @@ class MetaData: class SignResult: def __init__(self): - self.xdate = '' - self.xCredential = '' - self.xAlgorithm = '' - self.xSignedHeaders = '' - self.xSignedQueries = '' - self.xSignature = '' - self.xContextSha256 = '' - self.xSecurityToken = '' + self.xdate = "" + self.xCredential = "" + self.xAlgorithm = "" + self.xSignedHeaders = "" + self.xSignedQueries = "" + self.xSignature = "" + self.xContextSha256 = "" + self.xSecurityToken = "" - self.authorization = '' + self.authorization = "" def __str__(self): - return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) + return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()]) class Credentials: - def __init__(self, ak, sk, service, 
region, session_token=''): + def __init__(self, ak, sk, service, region, session_token=""): self.ak = ak self.sk = sk self.service = service @@ -72,73 +72,88 @@ class Credentials: class Signer: @staticmethod def sign(request, credentials): - if request.path == '': - request.path = '/' - if request.method != 'GET' and not ('Content-Type' in request.headers): - request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + if request.path == "": + request.path = "/" + if request.method != "GET" and "Content-Type" not in request.headers: + request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8" format_date = Signer.get_current_format_date() - request.headers['X-Date'] = format_date - if credentials.session_token != '': - request.headers['X-Security-Token'] = credentials.session_token + request.headers["X-Date"] = format_date + if credentials.session_token != "": + request.headers["X-Security-Token"] = credentials.session_token md = MetaData() - md.set_algorithm('HMAC-SHA256') + md.set_algorithm("HMAC-SHA256") md.set_service(credentials.service) md.set_region(credentials.region) md.set_date(format_date[:8]) hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) - md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) + md.set_credential_scope("/".join([md.date, md.region, md.service, "request"])) - signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + signing_str = "\n".join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) - request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) + request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) return @staticmethod def hashed_canonical_request_v4(request, meta): body_hash = Util.sha256(request.body) - request.headers['X-Content-Sha256'] = body_hash + request.headers["X-Content-Sha256"] = body_hash signed_headers = {} for key in request.headers: - if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): + if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] - if 'host' in signed_headers: - v = signed_headers['host'] - if v.find(':') != -1: - split = v.split(':') + if "host" in signed_headers: + v = signed_headers["host"] + if v.find(":") != -1: + split = v.split(":") port = split[1] - if str(port) == '80' or str(port) == '443': - signed_headers['host'] = split[0] + if str(port) == "80" or str(port) == "443": + signed_headers["host"] = split[0] - signed_str = '' + signed_str = "" for key in sorted(signed_headers.keys()): - signed_str += key + ':' + signed_headers[key] + '\n' + signed_str += key + ":" + signed_headers[key] + "\n" - meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) + meta.set_signed_headers(";".join(sorted(signed_headers.keys()))) - canonical_request = '\n'.join( - [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, - meta.signed_headers, body_hash]) + canonical_request = "\n".join( + [ + request.method, + Util.norm_uri(request.path), + Util.norm_query(request.query), + signed_str, + meta.signed_headers, + body_hash, + ] + ) return Util.sha256(canonical_request) @staticmethod def get_signing_secret_key_v4(sk, date, region, service): - 
date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) + date = Util.hmac_sha256(bytes(sk, encoding="utf-8"), date) region = Util.hmac_sha256(date, region) service = Util.hmac_sha256(region, service) - return Util.hmac_sha256(service, 'request') + return Util.hmac_sha256(service, "request") @staticmethod def build_auth_header_v4(signature, meta, credentials): - credential = credentials.ak + '/' + meta.credential_scope - return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature + credential = credentials.ak + "/" + meta.credential_scope + return ( + meta.algorithm + + " Credential=" + + credential + + ", SignedHeaders=" + + meta.signed_headers + + ", Signature=" + + signature + ) @staticmethod def get_current_format_date(): - return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") + return datetime.datetime.now(tz=pytz.timezone("UTC")).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py index 7271ae63fd..33c41f3eb3 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py @@ -6,7 +6,7 @@ import requests from .auth import Signer -VERSION = 'v1.0.137' +VERSION = "v1.0.137" class Service: @@ -31,7 +31,7 @@ class Service: self.service_info.scheme = scheme def get(self, api, params, doseq=0): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] @@ -40,52 +40,61 @@ class Service: Signer.sign(r, self.service_info.credentials) url = r.build(doseq) - resp = self.session.get(url, headers=r.headers, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.get( + url, headers=r.headers, timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout) + ) if resp.status_code == 200: return resp.text else: raise Exception(resp.text) def post(self, api, params, form): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/x-www-form-urlencoded' + r.headers["Content-Type"] = "application/x-www-form-urlencoded" r.form = self.merge(api_info.form, form) r.body = urlencode(r.form, True) Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.form, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return resp.text else: raise Exception(resp.text) def json(self, api, params, body): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/json' + r.headers["Content-Type"] = "application/json" r.body = body Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.body, - timeout=(self.service_info.connection_timeout, 
self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return json.dumps(resp.json()) else: raise Exception(resp.text.encode("utf-8")) def put(self, url, file_path, headers): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: resp = self.session.put(url, headers=headers, data=f) if resp.status_code == 200: return True, resp.text.encode("utf-8") @@ -105,7 +114,7 @@ class Service: params[key] = str(params[key]) elif type(params[key]) == list: if not doseq: - params[key] = ','.join(params[key]) + params[key] = ",".join(params[key]) connection_timeout = self.service_info.connection_timeout socket_timeout = self.service_info.socket_timeout @@ -117,8 +126,8 @@ class Service: r.set_socket_timeout(socket_timeout) headers = self.merge(api_info.header, self.service_info.header) - headers['Host'] = self.service_info.host - headers['User-Agent'] = 'volc-sdk-python/' + VERSION + headers["Host"] = self.service_info.host + headers["User-Agent"] = "volc-sdk-python/" + VERSION r.set_headers(headers) query = self.merge(api_info.query, params) @@ -143,13 +152,13 @@ class Service: class Request: def __init__(self): - self.schema = '' - self.method = '' - self.host = '' - self.path = '' + self.schema = "" + self.method = "" + self.host = "" + self.path = "" self.headers = OrderedDict() self.query = OrderedDict() - self.body = '' + self.body = "" self.form = {} self.connection_timeout = 0 self.socket_timeout = 0 @@ -182,11 +191,11 @@ class Request: self.socket_timeout = socket_timeout def build(self, doseq=0): - return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq) + return self.schema + "://" + self.host + self.path + "?" 
+ urlencode(self.query, doseq) class ServiceInfo: - def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme="http"): self.host = host self.header = header self.credentials = credentials @@ -204,4 +213,4 @@ class ApiInfo: self.header = header def __str__(self): - return 'method: ' + self.method + ', path: ' + self.path + return "method: " + self.method + ", path: " + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py index 7eb5fdfa91..44f9959965 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py @@ -7,28 +7,28 @@ from urllib.parse import quote class Util: @staticmethod def norm_uri(path): - return quote(path).replace('%2F', '/').replace('+', '%20') + return quote(path).replace("%2F", "/").replace("+", "%20") @staticmethod def norm_query(params): - query = '' + query = "" for key in sorted(params.keys()): if type(params[key]) == list: for k in params[key]: - query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(k, safe="-_.~") + "&" else: - query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(params[key], safe="-_.~") + "&" query = query[:-1] - return query.replace('+', '%20') + return query.replace("+", "%20") @staticmethod def hmac_sha256(key, content): - return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() + return hmac.new(key, bytes(content, encoding="utf-8"), hashlib.sha256).digest() @staticmethod def sha256(content): if isinstance(content, str) is True: - return hashlib.sha256(content.encode('utf-8')).hexdigest() + return hashlib.sha256(content.encode("utf-8")).hexdigest() else: return hashlib.sha256(content).hexdigest() @@ -36,8 +36,8 @@ class Util: def to_hex(content): lst = [] for ch in content: - hv = hex(ch).replace('0x', '') + hv = hex(ch).replace("0x", "") if len(hv) == 1: - hv = '0' + hv + hv = "0" + hv lst.append(hv) return reduce(lambda x, y: x + y, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py index 8b14d026d9..3825fd6574 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py @@ -43,9 +43,7 @@ def json_to_object(json_str, req_id=None): def gen_req_id(): - return datetime.now().strftime("%Y%m%d%H%M%S") + format( - random.randint(0, 2 ** 64 - 1), "020X" - ) + return datetime.now().strftime("%Y%m%d%H%M%S") + format(random.randint(0, 2**64 - 1), "020X") class SSEDecoder: @@ -53,13 +51,13 @@ class SSEDecoder: self.source = source def _read(self): - data = b'' + data = b"" for chunk in self.source: for line in chunk.splitlines(True): data += line - if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): yield data - data = b'' + data = b"" if data: yield data @@ -67,13 +65,13 @@ class SSEDecoder: for chunk in self._read(): for line in chunk.splitlines(): # skip comment - if 
line.startswith(b':'): + if line.startswith(b":"): continue - if b':' in line: - field, value = line.split(b':', 1) + if b":" in line: + field, value = line.split(b":", 1) else: - field, value = line, b'' + field, value = line, b"" - if field == b'data' and len(value) > 0: + if field == b"data" and len(value) > 0: yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py index 3cbe9d9f09..a3836685f1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -9,9 +9,7 @@ from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object class MaasService(Service): def __init__(self, host, region, connection_timeout=60, socket_timeout=60): - service_info = self.get_service_info( - host, region, connection_timeout, socket_timeout - ) + service_info = self.get_service_info(host, region, connection_timeout, socket_timeout) self._apikey = None api_info = self.get_api_info() super().__init__(service_info, api_info) @@ -35,9 +33,7 @@ class MaasService(Service): def get_api_info(): api_info = { "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), - "embeddings": ApiInfo( - "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} - ), + "embeddings": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}), } return api_info @@ -52,9 +48,7 @@ class MaasService(Service): try: req["stream"] = True - res = self._call( - endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True - ) + res = self._call(endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True) decoder = SSEDecoder(res) @@ -64,13 +58,12 @@ class MaasService(Service): return try: - res = json_to_object( - str(data, encoding="utf-8"), req_id=req_id) + res = json_to_object(str(data, encoding="utf-8"), req_id=req_id) except Exception: raise if res.error is not None and res.error.code_n != 0: - raise MaasException( + raise MaasError( res.error.code_n, res.error.code, res.error.message, @@ -79,7 +72,7 @@ class MaasService(Service): yield res return iter_fn() - except MaasException: + except MaasError: raise except Exception as e: raise new_client_sdk_request_error(str(e)) @@ -95,29 +88,28 @@ class MaasService(Service): apikey = self._apikey try: - res = self._call(endpoint_id, api, req_id, params, - json.dumps(req).encode("utf-8"), apikey) + res = self._call(endpoint_id, api, req_id, params, json.dumps(req).encode("utf-8"), apikey) resp = dict_to_object(res.json()) if resp and isinstance(resp, dict): resp["req_id"] = req_id return resp - except MaasException as e: + except MaasError as e: raise e except Exception as e: raise new_client_sdk_request_error(str(e), req_id) def _validate(self, api, req_id): credentials_exist = ( - self.service_info.credentials is not None and - self.service_info.credentials.sk is not None and - self.service_info.credentials.ak is not None + self.service_info.credentials is not None + and self.service_info.credentials.sk is not None + and self.service_info.credentials.ak is not None ) if not self._apikey and not credentials_exist: raise new_client_sdk_request_error("no valid credential", req_id) - if not (api in self.api_info): + if api not in self.api_info: raise new_client_sdk_request_error("no such api", req_id) def _call(self, endpoint_id, api, req_id, params, 
body, apikey=None, stream=False): @@ -150,22 +142,19 @@ class MaasService(Service): raw = res.text.encode() res.close() try: - resp = json_to_object( - str(raw, encoding="utf-8"), req_id=req_id) + resp = json_to_object(str(raw, encoding="utf-8"), req_id=req_id) except Exception: raise new_client_sdk_request_error(raw, req_id) if resp.error: - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, req_id - ) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id) else: raise new_client_sdk_request_error(resp, req_id) return res -class MaasException(Exception): +class MaasError(Exception): def __init__(self, code_n, code, message, req_id): self.code_n = code_n self.code = code @@ -173,15 +162,17 @@ class MaasException(Exception): self.req_id = req_id def __str__(self): - return ("Detailed exception information is listed below.\n" + - "req_id: {}\n" + - "code_n: {}\n" + - "code: {}\n" + - "message: {}").format(self.req_id, self.code_n, self.code, self.message) + return ( + "Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}" + ).format(self.req_id, self.code_n, self.code, self.message) def new_client_sdk_request_error(raw, req_id=""): - return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) class BinaryResponseContent: @@ -189,25 +180,19 @@ class BinaryResponseContent: self.response = response self.request_id = request_id - def stream_to_file( - self, - file: str - ) -> None: + def stream_to_file(self, file: str) -> None: is_first = True - error_bytes = b'' + error_bytes = b"" with open(file, mode="wb") as f: for data in self.response: - if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): + if len(error_bytes) > 0 or (is_first and '"error":' in str(data)): error_bytes += data else: f.write(data) if len(error_bytes) > 0: - resp = json_to_object( - str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, self.request_id - ) + resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) def iter_bytes(self) -> Iterator[bytes]: yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 996c66e604..c25851fc45 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -49,10 +49,17 @@ logger = logging.getLogger(__name__) class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: 
list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: if ArkClientV3.is_legacy(credentials): return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -71,27 +78,36 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): try: client.chat( { - 'max_new_tokens': 16, - 'temperature': 0.7, - 'top_p': 0.9, - 'top_k': 15, + "max_new_tokens": 16, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 15, }, - [UserPromptMessage(content='ping\nAnswer: ')], + [UserPromptMessage(content="ping\nAnswer: ")], ) - except MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) @staticmethod def _validate_credentials_v3(credentials: dict) -> None: client = ArkClientV3.from_credentials(credentials) try: - client.chat(max_tokens=16, temperature=0.7, top_p=0.9, - messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + client.chat( + max_tokens=16, + temperature=0.7, + top_p=0.9, + messages=[UserPromptMessage(content="ping\nAnswer: ")], + ) except Exception as e: raise CredentialsValidateFailedError(e) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if ArkClientV3.is_legacy(credentials): return self._get_num_tokens_v2(prompt_messages) return self._get_num_tokens_v3(prompt_messages) @@ -100,8 +116,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + messages_dict = [MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -113,8 +128,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - ArkClientV3.convert_prompt_message(m) for m in messages] + messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -122,97 +136,108 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): return num_tokens - def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v2( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [ - MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools - ] - resp = MaaSClient.wrap_exception( - 
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) + extra_model_kwargs["tools"] = [MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools] + resp = MaaSClient.wrap_exception(lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) def _handle_stream_chat_response() -> Generator: for index, r in enumerate(resp): - choices = r['choices'] + choices = r["choices"] if not choices: continue choice = choices[0] - message = choice['message'] + message = choice["message"] usage = None - if r.get('usage'): - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=r['usage']['prompt_tokens'], - completion_tokens=r['usage']['completion_tokens'] - ) + if r.get("usage"): + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=r["usage"]["prompt_tokens"], + completion_tokens=r["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] + content=message["content"] if message["content"] else "", tool_calls=[] ), usage=usage, - finish_reason=choice.get('finish_reason'), + finish_reason=choice.get("finish_reason"), ), ) def _handle_chat_response() -> LLMResult: - choices = resp['choices'] + choices = resp["choices"] if not choices: raise ValueError("No choices found") choice = choices[0] - message = choice['message'] + message = choice["message"] # parse tool calls tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: + if message["tool_calls"]: + for call in message["tool_calls"]: tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], + id=call["function"]["name"], + type=call["type"], function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + name=call["function"]["name"], arguments=call["function"]["arguments"] + ), ) tool_calls.append(tool_call) - usage = resp['usage'] + usage = resp["usage"] return LLMResult( model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', + content=message["content"] if message["content"] else "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], + ), ) if not stream: return _handle_chat_response() return _handle_stream_chat_response() - def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v3( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = ArkClientV3.from_credentials(credentials) req_params = get_v3_req_params(credentials, model_parameters, stop) if tools: - 
req_params['tools'] = tools + req_params["tools"] = tools def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: for chunk in chunks: @@ -225,14 +250,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=choice.index, - message=AssistantPromptMessage( - content=choice.delta.content, - tool_calls=[] - ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens - ) if chunk.usage else None, + message=AssistantPromptMessage(content=choice.delta.content, tool_calls=[]), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.usage + else None, finish_reason=choice.finish_reason, ), ) @@ -248,9 +274,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): id=call.id, type=call.type, function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call.function.name, - arguments=call.function.arguments - ) + name=call.function.name, arguments=call.function.arguments + ), ) tool_calls.append(tool_call) @@ -262,10 +287,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): content=message.content if message.content else "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ), ) if not stream: @@ -277,72 +304,56 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ model_config = get_model_config(credentials) rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', - type=ParameterType.INT, - min=1, - default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + name="top_k", type=ParameterType.INT, min=1, default=1, label=I18nObject(zh_Hans="Top K", en_US="Top K") ), ParameterRule( - name='presence_penalty', + name="presence_penalty", type=ParameterType.FLOAT, - use_template='presence_penalty', + use_template="presence_penalty", label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", type=ParameterType.FLOAT, - use_template='frequency_penalty', + use_template="frequency_penalty", label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - 
use_template='max_tokens', + use_template="max_tokens", min=1, max=model_config.properties.max_tokens, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ] @@ -352,9 +363,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index a882f68a36..d8be14b024 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -16,138 +16,127 @@ class ModelConfig(BaseModel): configs: dict[str, ModelConfig] = { - 'Doubao-pro-4k': ModelConfig( + "Doubao-pro-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-4k': ModelConfig( + "Doubao-lite-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-32k': ModelConfig( + "Doubao-pro-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-32k': ModelConfig( + "Doubao-lite-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-128k': ModelConfig( + "Doubao-pro-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Doubao-lite-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Skylark2-pro-4k': ModelConfig( - properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Skylark2-pro-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Llama3-8B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-8B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Llama3-70B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-70B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Moonshot-v1-8k': ModelConfig( + "Moonshot-v1-8k": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-32k': ModelConfig( + "Moonshot-v1-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), - 
features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-128k': ModelConfig( + "Moonshot-v1-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B': ModelConfig( + "GLM3-130B": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B-Fin': ModelConfig( + "GLM3-130B-Fin": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], + ), + "Mistral-7B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[] ), - 'Mistral-7B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), - features=[] - ) } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = configs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_tokens=int(credentials.get('max_tokens', 0)), - mode=LLMMode.value_of(credentials.get('mode', 'chat')), + context_size=int(credentials.get("context_size", 0)), + max_tokens=int(credentials.get("max_tokens", 0)), + mode=LLMMode.value_of(credentials.get("mode", "chat")), ), - features=[] + features=[], ) return model_configs -def get_v2_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_prompt_tokens'] = model_configs.properties.context_size - req_params['max_new_tokens'] = model_configs.properties.max_tokens + req_params["max_prompt_tokens"] = model_configs.properties.context_size + req_params["max_new_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_new_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("top_k"): + req_params["top_k"] = model_parameters.get("top_k") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = 
model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params -def get_v3_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_tokens'] = model_configs.properties.max_tokens + req_params["max_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 74cf26247c..ce4f0c3ab1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -11,20 +11,18 @@ class ModelConfig(BaseModel): ModelConfigs = { - 'Doubao-embedding': ModelConfig( - properties=ModelProperties(context_size=4096, max_chunks=32) - ), + "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = ModelConfigs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_chunks=int(credentials.get('max_chunks', 0)), + context_size=int(credentials.get("context_size", 0)), + max_chunks=int(credentials.get("max_chunks", 0)), ) ) return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index d54aeeb0b1..9cba2cb879 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, 
ServerUnavailableErrors, ) @@ -40,9 +40,9 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): Model class for VolcengineMaaS text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -57,37 +57,27 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): return self._generate_v3(model, credentials, texts, user) - def _generate_v2(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v2( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp["usage"]["total_tokens"]) - result = TextEmbeddingResult( - model=model, - embeddings=[v['embedding'] for v in resp['data']], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v["embedding"] for v in resp["data"]], usage=usage) return result - def _generate_v3(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v3( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = ArkClientV3.from_credentials(credentials) resp = client.embeddings(texts) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp.usage.total_tokens) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp.usage.total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=[v.embedding for v in resp.data], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v.embedding for v in resp.data], usage=usage) return result @@ -120,13 +110,13 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) - except MaasException as e: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except MaasError as e: raise CredentialsValidateFailedError(e.message) def _validate_credentials_v3(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: raise CredentialsValidateFailedError(e) @@ -150,12 +140,12 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ model_config = get_model_config(credentials) model_properties = { ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, - ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks, } entity = AIModelEntity( model=model, @@ -165,10 +155,10 @@ class 
VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): model_properties=model_properties, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -184,10 +174,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -198,7 +185,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index a00c1b7994..13e00da76f 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -30,8 +30,28 @@ model_credential_schema: en_US: Enter your Model Name zh_Hans: 输入模型名称 credential_form_schemas: + - variable: auth_method + required: true + label: + en_US: Authentication Method + zh_Hans: 鉴权方式 + type: select + default: aksk + options: + - label: + en_US: API Key + value: api_key + - label: + en_US: Access Key / Secret Access Key + value: aksk + placeholder: + en_US: Enter your Authentication Method + zh_Hans: 选择鉴权方式 - variable: volc_access_key_id required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Access Key zh_Hans: Access Key @@ -41,6 +61,9 @@ model_credential_schema: zh_Hans: 输入您的 Access Key - variable: volc_secret_access_key required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Secret Access Key zh_Hans: Secret Access Key @@ -48,6 +71,17 @@ model_credential_schema: placeholder: en_US: Enter your Secret Access Key zh_Hans: 输入您的 Secret Access Key + - variable: volc_api_key + required: true + show_on: + - variable: auth_method + value: api_key + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your API Key + zh_Hans: 输入您的 API Key - variable: volc_region required: true label: @@ -64,7 +98,7 @@ model_credential_schema: en_US: API Endpoint Host zh_Hans: API Endpoint Host type: text-input - default: maas-api.ml-platform-cn-beijing.volces.com + default: https://ark.cn-beijing.volces.com/api/v3 placeholder: en_US: Enter your API Endpoint Host zh_Hans: 输入 API Endpoint Host diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index 0230c78b75..d72d1bd83a 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -11,7 +11,7 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import ( RateLimitReachedError, ) -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, "BaiduAccessToken"] = {} baidu_access_tokens_lock = Lock() @@ -22,49 +22,46 @@ class BaiduAccessToken: def __init__(self, 
api_key: str) -> None: self.api_key = api_key - self.access_token = '' + self.access_token = "" self.expires = datetime.now() + timedelta(days=3) @staticmethod def _get_access_token(api_key: str, secret_key: str) -> str: """ - request access token from Baidu + Request an access token from Baidu. """ try: response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, + url=f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}", + headers={"Content-Type": "application/json", "Accept": "application/json"}, ) except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + raise InvalidAuthenticationError(f"Failed to get access token from Baidu: {e}") resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': + if "error" in resp: + if resp["error"] == "invalid_client": raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': + elif resp["error"] == "unknown_error": raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': + elif resp["error"] == "invalid_request": raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': + elif resp["error"] == "rate_limit_exceeded": raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') else: raise Exception(f'Unknown error: {resp["error_description"]}') - return resp['access_token'] + return resp["access_token"] @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + def get_access_token(api_key: str, secret_key: str) -> "BaiduAccessToken": """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) + LLMs from Baidu require an access token to invoke the API. + However, we have the api_key and secret_key, and an access token is valid for 30 days, + so we cache the access token for 3 days (to avoid a memory leak). - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. + It may be more efficient to use a ticker to refresh the access token, but that would add + more complexity, so we just refresh access tokens when get_access_token is called.
""" # loop up cache, remove expired access token @@ -79,11 +76,13 @@ class BaiduAccessToken: # if access token not in cache, request it token = BaiduAccessToken(api_key) baidu_access_tokens[api_key] = token - # release it to enhance performance - # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock - baidu_access_tokens_lock.release() - # try to get access token - token_str = BaiduAccessToken._get_access_token(api_key, secret_key) + try: + # try to get access token + token_str = BaiduAccessToken._get_access_token(api_key, secret_key) + finally: + # release it to enhance performance + # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock + baidu_access_tokens_lock.release() token.access_token = token_str token.expires = now + timedelta(days=3) return token @@ -96,49 +95,49 @@ class BaiduAccessToken: class _CommonWenxin: api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', - 'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en', - 'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh', - 'tao-8k': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k', + "ernie-bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-bot-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-3.5-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-3.5-8k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205", + "ernie-3.5-8k-1222": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222", + "ernie-3.5-4k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-3.5-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k", + "ernie-4.0-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-4.0-8k-latest": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-speed-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed", + "ernie-speed-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k", + "ernie-speed-appbuilder": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas", + "ernie-lite-8k-0922": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-lite-8k-0308": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k", + "ernie-character-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview", + "yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat", + "embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1", + "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", + "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", + "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", } function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' + "ernie-bot", + "ernie-bot-8k", + "ernie-3.5-8k", + "ernie-3.5-8k-0205", + "ernie-3.5-8k-1222", + "ernie-3.5-4k-0205", + "ernie-3.5-128k", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview", + "yi_34b_chat", ] - api_key: str = '' - secret_key: str = '' + api_key: str = "" + secret_key: str = "" def __init__(self, api_key: str, secret_key: str): self.api_key = api_key @@ -146,10 +145,7 @@ class _CommonWenxin: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - credentials_kwargs = { - "api_key": credentials['api_key'], 
- "secret_key": credentials['secret_key'] - } + credentials_kwargs = {"api_key": credentials["api_key"], "secret_key": credentials["secret_key"]} return credentials_kwargs def _handle_error(self, code: int, msg: str): @@ -185,13 +181,13 @@ class _CommonWenxin: 336105: BadRequestError, 336200: InternalServerError, 336303: BadRequestError, - 337006: BadRequestError + 337006: BadRequestError, } if code in error_map: raise error_map[code](msg) else: - raise InternalServerError(f'Unknown error: {msg}') + raise InternalServerError(f"Unknown error: {msg}") def _get_access_token(self) -> str: token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 8109949b1d..07b970f810 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -15,33 +15,39 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import ( class ErnieMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - FUNCTION = 'function' - SYSTEM = 'system' + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - def __init__(self, content: str, role: str = 'user') -> None: + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role + class ErnieBotModel(_CommonWenxin): - - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], - parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ - stop: list[str], user: str) \ - -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: - + def generate( + self, + model: str, + stream: bool, + messages: list[ErnieMessage], + parameters: dict[str, Any], + timeout: int, + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters self._check_parameters(model, parameters, tools, stop) @@ -49,22 +55,23 @@ class ErnieBotModel(_CommonWenxin): access_token = self._get_access_token() # generate request body - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" # clone messages messages_cloned = self._copy_messages(messages=messages) # build body - body = self._build_request_body(model, messages=messages_cloned, stream=stream, - parameters=parameters, tools=tools, stop=stop, user=user) + body = self._build_request_body( + model, messages=messages_cloned, stream=stream, parameters=parameters, tools=tools, stop=stop, user=user + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url=url, data=dumps(body), headers=headers, stream=stream) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") if stream: return self._handle_chat_stream_generate_response(resp) @@ -73,10 +80,11 @@ class ErnieBotModel(_CommonWenxin): def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for 
message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str]) -> None: + def _check_parameters( + self, model: str, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str] + ) -> None: if model not in self.api_bases: - raise BadRequestError(f'Invalid model: {model}') + raise BadRequestError(f"Invalid model: {model}") # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') @@ -85,86 +93,106 @@ class ErnieBotModel(_CommonWenxin): # so, we just disable function calling for now. if tools is not None and len(tools) > 0: - raise BadRequestError('function calling is not supported yet.') + raise BadRequestError("function calling is not supported yet.") if stop is not None: if len(stop) > 4: - raise BadRequestError('stop list should not exceed 4 items.') + raise BadRequestError("stop list should not exceed 4 items.") for s in stop: if len(s) > 20: - raise BadRequestError('stop item should not exceed 20 characters.') + raise BadRequestError("stop item should not exceed 20 characters.") - def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: + def _build_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], - stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_function_calling_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role == 'function': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role == "function": + raise BadRequestError("The first message should be user message.") """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_chat_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError('The number of messages should not be zero.') + raise BadRequestError("The number of messages should not be zero.") # check if the first element is system, shift it - system_message = '' - if messages[0].role == 'system': + system_message = "" + if messages[0].role == "system": message = messages.pop(0) system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if 
messages[0].role != 'user': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role != "user": + raise BadRequestError("The first message should be user message.") body = { - 'messages': [message.to_dict() for message in messages], - 'stream': stream, - 'stop': stop, - 'user_id': user, - **parameters + "messages": [message.to_dict() for message in messages], + "stream": stream, + "stop": stop, + "user_id": user, + **parameters, } - if 'max_tokens' in parameters and type(parameters['max_tokens']) == int: - body['max_output_tokens'] = parameters['max_tokens'] + if "max_tokens" in parameters and type(parameters["max_tokens"]) == int: + body["max_output_tokens"] = parameters["max_tokens"] - if 'presence_penalty' in parameters and type(parameters['presence_penalty']) == float: - body['penalty_score'] = parameters['presence_penalty'] + if "presence_penalty" in parameters and type(parameters["presence_penalty"]) == float: + body["penalty_score"] = parameters["presence_penalty"] if system_message: - body['system'] = system_message + body["system"] = system_message return body def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - result = data['result'] - usage = data['usage'] + result = data["result"] + usage = data["usage"] - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } return message @@ -173,19 +201,19 @@ class ErnieBotModel(_CommonWenxin): for line in response.iter_lines(): if len(line) == 0: continue - line = line.decode('utf-8') - if line[0] == '{': + line = line.decode("utf-8") + if line[0] == "{": try: data = loads(line) - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - if line.startswith('data:'): + if line.startswith("data:"): line = line[5:].strip() else: continue @@ -195,23 +223,23 @@ class ErnieBotModel(_CommonWenxin): try: data = loads(line) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - result = data['result'] - is_end = data['is_end'] + result = data["result"] + is_end = data["is_end"] if is_end: - usage = data['usage'] - finish_reason = data.get('finish_reason', None) - message = ErnieMessage(content=result, role='assistant') + usage = data["usage"] + finish_reason = data.get("finish_reason", None) + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": 
usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } message.stop_reason = finish_reason yield message else: - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") yield message diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 140606298c..1ff0ac7ad2 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -30,42 +30,82 @@ if you are not sure about the structure. You should also complete the text started with ``` but not tell ``` directly. """ -class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: +class ErnieBotLargeLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: - response_format = model_parameters['response_format'] + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + response_format = model_parameters["response_format"] stop = stop or [] - self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) - model_parameters.pop('response_format') + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format + ) + model_parameters.pop("response_format") if stream: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + input_generator=self._invoke( + model=model, + credentials=credentials, + 
prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), ) - + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts to model prompts """ @@ -74,34 +114,44 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." 
+ ).replace("{{block}}", response_format) + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last message prompt_messages[-1].content += "\n```JSON\n{\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content="```JSON\n{\n" - )) + prompt_messages.append(UserPromptMessage(content="```JSON\n{\n")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -113,10 +163,10 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -126,36 +176,53 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: instance = ErnieBotModel( - api_key=credentials['api_key'], - secret_key=credentials['secret_key'], + api_key=credentials["api_key"], + secret_key=credentials["secret_key"], ) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ ErnieMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60, tools=tools, stop=stop, user=user) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + 
parameters=model_parameters, + timeout=60, + tools=tools, + stop=stop, + user=user, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) @@ -180,41 +247,47 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ErnieMessage) -> LLMResult: + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: ErnieMessage + ) -> LLMResult: # convert Ernie message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[ErnieMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[ErnieMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), @@ -225,10 +298,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 10ac1a1861..db323ae4c1 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -29,38 +29,38 @@ class TextEmbedding: class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): access_token = self._get_access_token() - url =
f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url, data=dumps(body), headers=headers) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") return self._handle_embed_response(model, resp) def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: if len(texts) == 0: - raise BadRequestError('The number of texts should not be zero.') + raise BadRequestError("The number of texts should not be zero.") body = { - 'input': texts, - 'user_id': user, + "input": texts, + "user_id": user, } return body def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - embeddings = [v['embedding'] for v in data['data']] - _usage = data['usage'] - tokens = _usage['prompt_tokens'] - total_tokens = _usage['total_tokens'] + embeddings = [v["embedding"] for v in data["data"]] + _usage = data["usage"] + tokens = _usage["prompt_tokens"] + total_tokens = _usage["total_tokens"] return embeddings, tokens, total_tokens @@ -69,22 +69,23 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: return WenxinTextEmbedding(api_key, secret_key) - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ - Invoke text embedding model + Invoke text embedding model - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :return: embeddings result - """ + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -94,7 +95,6 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): used_total_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -110,9 +110,8 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): _iter = range(0, len(inputs), max_chunks) for i in _iter: embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( - model, - inputs[i: i + max_chunks], - user) + model, inputs[i : i + max_chunks], user + ) used_tokens += _used_tokens used_total_tokens += _total_used_tokens batched_embeddings += embeddings_batch @@ -142,12 +141,12 @@ class 
WenxinTextEmbeddingModel(TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: Mapping) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -164,10 +163,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -178,7 +174,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index 04845d06bc..895af20bc8 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class WenxinProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use the `ernie-bot` model for validation, - model_instance.validate_credentials( - model='ernie-bot', - credentials=credentials - ) + model_instance.validate_credentials(model="ernie-bot", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py index 0fbd0f55ec..bd074e0477 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -18,40 +18,37 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception):
+ +class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 4760e8f118..b2c837dee1 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -65,88 +65,108 @@ from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 + if "temperature" in model_parameters: + if model_parameters["temperature"] < 0.01: + model_parameters["temperature"] = 0.01 + elif model_parameters["temperature"] > 1.0: + model_parameters["temperature"] = 0.99 return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), - ) + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), + ), ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' + if "completion_type" not in credentials: + if "chat" in extra_param.model_ability: + credentials["completion_type"] = "chat" + elif "generate" in extra_param.model_ability: + credentials["completion_type"] = "completion" else: raise ValueError( - f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') + f"xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type" + ) if extra_param.support_function_call: - credentials['support_function_call'] = True + credentials["support_function_call"] = True if extra_param.support_vision: - credentials['support_vision'] = True + credentials["support_vision"] = True if extra_param.context_length: - credentials['context_length'] = extra_param.context_length + credentials["context_length"] = extra_param.context_length except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except Exception as e: raise e - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -162,10 +182,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -217,30 +237,30 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens = 0 for tool in tools: # calculate 
num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -248,9 +268,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): text += item.content @@ -259,7 +279,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): elif isinstance(item, AssistantPromptMessage): text += item.content else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: @@ -275,19 +295,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -297,7 +311,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -312,151 +326,144 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + 
used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY, use_template=DefaultParameterName.PRESENCE_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they ' - 'appear in the text so far, increasing the model\'s likelihood to talk about new topics.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they " + "appear in the text so far, increasing the model's likelihood to talk about new topics.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY, use_template=DefaultParameterName.FREQUENCY_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their ' - 'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the ' - 'same line verbatim.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。' + en_US="Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their " + "existing frequency in the text so far, decreasing the model's likelihood to repeat the " + "same line verbatim.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 - ) + precision=2, + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') else: extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'chat' in extra_args.model_ability: + if "chat" in extra_args.model_ability: completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: + elif "generate" in extra_args.model_ability: completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + raise ValueError(f"xinference model ability {extra_args.model_ability} is not supported") features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - extra_model_kwargs can be got by 
`XinferenceHelper.get_xinference_extra_parameter` + extra_model_kwargs can be obtained via `XinferenceHelper.get_xinference_extra_parameter` """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] - api_key = credentials.get('api_key') or "abc" + api_key = credentials.get("api_key") or "abc" client = OpenAI( base_url=f'{credentials["server_url"]}/v1', @@ -466,34 +473,29 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) xinference_client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_model = xinference_client.get_model(credentials['model_uid']) + xinference_model = xinference_client.get_model(credentials["model_uid"]) generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'presence_penalty': model_parameters.get('presence_penalty', 0.0), - 'frequency_penalty': model_parameters.get('frequency_penalty', 0.0), + "temperature": model_parameters.get("temperature", 1.0), + "top_p": model_parameters.get("top_p", 0.7), + "max_tokens": model_parameters.get("max_tokens", 512), + "presence_penalty": model_parameters.get("presence_penalty", 0.0), + "frequency_penalty": model_parameters.get("frequency_penalty", 0.0), } if stop: - generate_config['stop'] = stop + generate_config["stop"] = stop if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] - vision = credentials.get('support_vision', False) + generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] + vision = credentials.get("support_vision", False) if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, @@ -501,34 +503,34 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) if stream: if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + raise InvokeBadRequestError("xinference tool calls do not support stream mode") + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) elif isinstance(xinference_model, RESTfulGenerateModelHandle): resp = client.completions.create( - model=credentials['model_uid'], +
model=credentials["model_uid"], prompt=self._convert_prompt_message_to_text(prompt_messages), stream=stream, user=user, **generate_config, ) if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_completion_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + return self._handle_completion_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_completion_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + raise NotImplementedError(f"xinference model handle type {type(xinference_model)} is not supported") - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -539,21 +541,19 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -563,23 +563,25 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -595,15 +597,15 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message 
assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls + content=assistant_message.content, tool_calls=assistant_prompt_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -615,13 +617,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -629,7 +636,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -646,32 +653,31 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: @@ -687,11 +693,16 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> 
LLMResult: + def _handle_completion_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion, + ) -> LLMResult: """ - handle normal completion generate response + handle normal completion generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -699,14 +710,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): assistant_message = resp.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[assistant_prompt_message], tools=[], is_completion_model=True ) @@ -724,13 +730,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[Completion]) -> Generator: + def _handle_completion_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion], + ) -> Generator: """ - handle stream completion generate response + handle stream completion generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -739,40 +750,33 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: - if delta.text is None or delta.text == '': + if delta.text is None or delta.text == "": continue yield LLMResultChunk( @@ -807,15 
+811,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index d809537479..1582fe43b9 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -22,10 +22,16 @@ class XinferenceRerankModel(RerankModel): Model class for Xinference rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -39,24 +45,16 @@ class XinferenceRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) + return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} - params = { - 'documents': docs, - 'query': query, - 'top_n': top_n, - 'return_documents': True - } + params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True} try: handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) response = handle.rerank(**params) @@ -69,27 +67,24 @@ class XinferenceRerankModel(RerankModel): response = handle.rerank(**params) rerank_documents = [] - for idx, result in enumerate(response['results']): + for idx, result in enumerate(response["results"]): # format document - index = result['index'] - page_content = result['document'] if isinstance(result['document'], str) else result['document']['text'] + index = result["index"] + page_content = result["document"] if isinstance(result["document"], str) else result["document"]["text"] rerank_document = RerankDocument( index=index, text=page_content, - score=result['relevance_score'], + score=result["relevance_score"], ) # score threshold check if score_threshold is not None: - if result['relevance_score'] >= score_threshold: + if result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -100,34 +95,35 @@ class XinferenceRerankModel(RerankModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a rerank model') + "please check model type, the model you want to invoke is not a rerank model" + ) self.invoke( model=model, credentials=credentials, query="Whose kasumi", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -143,53 +139,38 @@ class XinferenceRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): - def rerank( - self, - documents: list[str], - query: str, - top_n: Optional[int] = None, - max_chunks_per_doc: Optional[int] = None, - return_documents: Optional[bool] = None, - **kwargs + self, + documents: list[str], + query: str, + top_n: Optional[int] = None, + max_chunks_per_doc: Optional[int] = None, + return_documents: Optional[bool] = None, + **kwargs, ): url = f"{self._base_url}/v1/rerank" request_body = { @@ -205,8 +186,6 @@ class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): response = requests.post(url, json=request_body, headers=self.auth_headers) if response.status_code != 200: - raise InvokeServerUnavailableError( - f"Failed to rerank documents, detail: {response.json()['detail']}" - ) + raise 
InvokeServerUnavailableError(f"Failed to rerank documents, detail: {response.json()['detail']}") response_data = response.json() return response_data diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 62b77f22e5..54c8b51654 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -21,9 +21,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): Model class for Xinference speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -44,27 +42,28 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulAudioModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + "please check model type, the model you want to invoke is not a audio model" + ) audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self.invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -80,23 +79,11 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _speech2text_invoke( @@ -122,21 +109,17 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. 
:return: text for given audio file """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) response = handle.transcriptions( - audio=file, - language=language, - prompt=prompt, - response_format=response_format, - temperature=temperature + audio=file, language=language, prompt=prompt, response_format=response_format, temperature=temperature ) except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -145,17 +128,15 @@ class XinferenceSpeech2TextModel(Speech2TextModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 3a8d704c25..ac704e7de8 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ Model class for Xinference text embedding model. 
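# --- Editor's note (illustrative sketch, not part of the diff) ---------------
# The text_embedding _invoke below unpacks an OpenAI-style embedding payload.
# A minimal sketch of that shape, with hypothetical values, limited to the
# fields the reformatted code actually reads:
example_embeddings = {
    "data": [{"object": "embedding", "index": 0, "embedding": [0.1, -0.2, 0.3]}],
    "usage": {"total_tokens": 1},
}
vectors = [item["embedding"] for item in example_embeddings["data"]]  # -> [[0.1, -0.2, 0.3]]
total_tokens = example_embeddings["usage"]["total_tokens"]  # fed to _calc_response_usage(tokens=...)
# -----------------------------------------------------------------------------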
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -41,12 +42,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) @@ -70,13 +71,11 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): embedding: List[float] """ - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = embeddings["usage"] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage + model=model, embeddings=[embedding["embedding"] for embedding in embeddings["data"]], usage=usage ) return result @@ -105,12 +104,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=server_url, model_uid=model_uid, @@ -118,8 +117,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ) if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens - if server_url.endswith('/'): + credentials["max_tokens"] = extra_args.max_tokens + if server_url.endswith("/"): server_url = server_url[:-1] client = Client( @@ -133,32 +132,24 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(e) if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a text embedding model" + ) - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + raise CredentialsValidateFailedError(f"Failed to validate credentials for model {model}: {e}") except RuntimeError as e: raise CredentialsValidateFailedError(e) @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -172,10 +163,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -186,28 +174,26 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + 
ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials and credentials["max_tokens"] or 512, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index bfa752df8c..60db151302 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -19,92 +19,91 @@ from core.model_runtime.model_providers.xinference.xinference_helper import Xinf class XinferenceText2SpeechModel(TTSModel): - def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'ChatTTS': { - 'all': [ - {'name': 'Alloy', 'value': 'alloy'}, - {'name': 'Echo', 'value': 'echo'}, - {'name': 'Fable', 'value': 'fable'}, - {'name': 'Onyx', 'value': 'onyx'}, - {'name': 'Nova', 'value': 'nova'}, - {'name': 'Shimmer', 'value': 'shimmer'}, + "ChatTTS": { + "all": [ + {"name": "Alloy", "value": "alloy"}, + {"name": "Echo", "value": "echo"}, + {"name": "Fable", "value": "fable"}, + {"name": "Onyx", "value": "onyx"}, + {"name": "Nova", "value": "nova"}, + {"name": "Shimmer", "value": "shimmer"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ try: - if ("/" in credentials['model_uid'] or - "?" in credentials['model_uid'] or - "#" in credentials['model_uid']): + if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'text-to-audio' not in extra_param.model_ability: + if "text-to-audio" not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a text-to-audio model') + "please check model type, the model you want to invoke is not a text-to-audio model" + ) if extra_param.model_family and extra_param.model_family in self.model_voices: - credentials['audio_model_name'] = extra_param.model_family + credentials["audio_model_name"] = extra_param.model_family else: - credentials['audio_model_name'] = '__default' + credentials["audio_model_name"] = "__default" self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -120,18 +119,16 @@ class XinferenceText2SpeechModel(TTSModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -147,35 +144,28 @@ class XinferenceText2SpeechModel(TTSModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = credentials.get('audio_model_name', '__default') + audio_model_name = credentials.get("audio_model_name", "__default") for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] + else: + all_voices = 
[] + for lang, lang_voices in voices.items(): + all_voices.extend(lang_voices) + return all_voices - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] def _get_model_default_voice(self, model: str, credentials: dict) -> any: return "" @@ -189,8 +179,7 @@ class XinferenceText2SpeechModel(TTSModel): def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -200,48 +189,42 @@ class XinferenceText2SpeechModel(TTSModel): :param voice: model timbre :return: text translated to audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] try: - api_key = credentials.get('api_key') - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + api_key = credentials.get("api_key") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} handle = RESTfulAudioModelHandle( - credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + credentials["model_uid"], credentials["server_url"], auth_headers=auth_headers ) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit( - handle.speech, - input=sentences[i], - voice=voice, - response_format="mp3", - speed=1.0, - stream=False - ) - for i in range(len(sentences))] + futures = [ + executor.submit( + handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): response = future.result() for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] else: response = handle.speech( - input=content_text.strip(), - voice=voice, - response_format="mp3", - speed=1.0, - stream=False + input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False ) for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 75161ad376..6ad10e690d 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -18,9 +18,17 @@ class XinferenceModelExtraParameter: support_vision: bool = False model_family: Optional[str] - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, 
support_vision: bool, max_tokens: int, context_length: int, - model_family: Optional[str]) -> None: + def __init__( + self, + model_format: str, + model_handle_type: str, + model_ability: list[str], + support_function_call: bool, + support_vision: bool, + max_tokens: int, + context_length: int, + model_family: Optional[str], + ) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -30,9 +38,11 @@ class XinferenceModelExtraParameter: self.context_length = context_length self.model_family = model_family + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: @@ -40,16 +50,16 @@ class XinferenceHelper: with cache_lock: if model_uid not in cache: cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) + "expires": time() + 300, + "value": XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key), } - return cache[model_uid]['value'] + return cache[model_uid]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -58,55 +68,57 @@ class XinferenceHelper: @staticmethod def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ - get xinference model extra parameter like model_format and model_handle_type + get xinference model extra parameter like model_format and model_handle_type """ if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('model_uid is empty') + raise RuntimeError("model_uid is empty") - url = str(URL(server_url) / 'v1' / 'models' / model_uid) + url = str(URL(server_url) / "v1" / "models" / model_uid) # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) - headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') + raise RuntimeError( + f"get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}" + ) response_json = response.json() - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) - model_family = response_json.get('model_family', None) + model_format = response_json.get("model_format", "ggmlv3") + model_ability = 
response_json.get("model_ability", []) + model_family = response_json.get("model_family", None) - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif response_json.get('model_type') == 'audio': - model_handle_type = 'audio' - if model_family and model_family in ['ChatTTS', 'CosyVoice']: - model_ability.append('text-to-audio') + if response_json.get("model_type") == "embedding": + model_handle_type = "embedding" + elif response_json.get("model_type") == "audio": + model_handle_type = "audio" + if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: + model_ability.append("text-to-audio") else: - model_ability.append('audio-to-text') - elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' + model_ability.append("audio-to-text") + elif model_format == "ggmlv3" and "chatglm" in response_json["model_name"]: + model_handle_type = "chatglm" + elif "generate" in model_ability: + model_handle_type = "generate" + elif "chat" in model_ability: + model_handle_type = "chat" else: - raise NotImplementedError('xinference model handle type is not supported') + raise NotImplementedError("xinference model handle type is not supported") - support_function_call = 'tools' in model_ability - support_vision = 'vision' in model_ability - max_tokens = response_json.get('max_tokens', 512) + support_function_call = "tools" in model_ability + support_vision = "vision" in model_ability + max_tokens = response_json.get("max_tokens", 512) - context_length = response_json.get('context_length', 2048) + context_length = response_json.get("context_length", 2048) return XinferenceModelExtraParameter( model_format=model_format, @@ -116,5 +128,5 @@ class XinferenceHelper: support_vision=support_vision, max_tokens=max_tokens, context_length=context_length, - model_family=model_family + model_family=model_family, ) diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index d33f38333b..5ab7fd126e 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -14,11 +14,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class YiLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) # yi-vl-plus not support system prompt yet. 
@@ -27,7 +33,9 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): for message in prompt_messages: if not isinstance(message, SystemPromptMessage): prompt_message_except_system.append(message) - return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream) + return super()._invoke( + model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream + ) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -36,8 +44,7 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -55,8 +62,9 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -76,10 +84,10 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -110,10 +118,10 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.lingyiwanwu.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.lingyiwanwu.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/yi/yi.py b/api/core/model_runtime/model_providers/yi/yi.py index 691c7aa371..9599acb22b 100644 --- a/api/core/model_runtime/model_providers/yi/yi.py +++ b/api/core/model_runtime/model_providers/yi/yi.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class YiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class YiProvider(ModelProvider): # Use `yi-34b-chat-0205` model for validate, # no matter what model you pass in, text completion model or 
chat model - model_instance.validate_credentials( - model='yi-34b-chat-0205', - credentials=credentials - ) + model_instance.validate_credentials(model="yi-34b-chat-0205", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhinao/llm/llm.py b/api/core/model_runtime/model_providers/zhinao/llm/llm.py index 6930a5ed01..befc3de021 100644 --- a/api/core/model_runtime/model_providers/zhinao/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhinao/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.360.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.360.cn/v1" diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.py b/api/core/model_runtime/model_providers/zhinao/zhinao.py index 44b36c9f51..2a263292f9 100644 --- a/api/core/model_runtime/model_providers/zhinao/zhinao.py +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class ZhinaoProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class ZhinaoProvider(ModelProvider): # Use `360gpt-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='360gpt-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="360gpt-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 3412d8100f..fa95232f71 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,8 +17,7 @@ class _CommonZhipuaiAI: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials.get("zhipuai_api_key"), + 
"api_key": credentials["api_key"] if "api_key" in credentials else credentials.get("zhipuai_api_key"), } return credentials_kwargs @@ -38,5 +37,5 @@ class _CommonZhipuaiAI: InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml index 8f51f80967..fcd5c5ef64 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml @@ -19,15 +19,24 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. 
- name: return_type label: zh_Hans: 回复类型 diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml index 8391278e4f..b1f9b7485c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml @@ -23,20 +23,29 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false
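The stream help added above describes an SSE-style response that terminates with a data:[DONE] sentinel. A rough consumer sketch under that assumption; the lines iterable stands in for a real HTTP response and is not an API from this PR:

def iter_stream_payloads(lines):
    # Yield raw payloads from "data:" lines until the [DONE] sentinel.
    for line in lines:
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":  # end-of-stream marker described above
            return
        yield payload

For example, list(iter_stream_payloads(['data: {"id": 1}', 'data: [DONE]'])) yields just the one JSON chunk.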
- name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 pricing: input: '0.1' output: '0.1' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml index 7caebd3e4b..4e7d5fd3cc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml @@ -23,20 +23,29 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
+ default: false - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 pricing: input: '0.001' output: '0.001' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml index dc123913de..14f17db5d6 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml @@ -23,20 +23,29 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false
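Every stream help block in this diff mentions error code 1301 (content-safety). A sketch of the reaction the help text recommends; the error shape here is an assumption, not the SDK's actual response type:

CONTENT_FLAGGED = 1301

def should_continue(chunk: dict, conversation: list) -> bool:
    # On 1301, discard what was generated and reset the conversation
    # instead of feeding the flagged content back to the model.
    if chunk.get("error", {}).get("code") == CONTENT_FLAGGED:
        conversation.clear()
        return False
    return True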
- name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 pricing: input: '0.01' output: '0.01' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml index 1b1d499ba7..3361474d73 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml @@ -23,22 +23,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false
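Both new rules ship a default (do_sample: true, stream: false) in place of the old required: false. Presumably such defaults backfill parameters the user leaves unset; a sketch under that assumption, treating rules as plain dicts like the YAML above:

def apply_rule_defaults(rules: list[dict], model_parameters: dict) -> dict:
    merged = dict(model_parameters)
    for rule in rules:
        if "default" in rule:
            # Only fill in parameters the user did not set explicitly.
            merged.setdefault(rule["name"], rule["default"])
    return merged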
- name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 pricing: - input: '0.0001' - output: '0.0001' + input: '0' + output: '0' unit: '0.001' currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml index 5bdb442840..bf0135d198 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml @@ -23,17 +23,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false
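glm-4-flash above drops its price to '0', i.e. free. For reference, a pricing block is conventionally read as price per unit tokens in currency; a sketch of that arithmetic, not Dify's actual billing code:

def token_cost(tokens: int, price: str, unit: str = "0.001") -> float:
    # price='0.1' with unit='0.001' means 0.1 RMB per 1,000 tokens:
    # 2_000 tokens -> 2_000 * 0.1 * 0.001 = 0.2 RMB.
    return tokens * float(price) * float(unit)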
- name: max_tokens use_template: max_tokens default: 1024 min: 1 max: 8192 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index 6b5bcc5bcf..ab4b32dd82 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -23,17 +23,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false
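Several models in this diff lower max_tokens from 8192 to 4095, while glm_3_turbo above keeps 8192. Enforcing such rule bounds is typically a clamp; an illustrative helper, not code from this PR:

def clamp_max_tokens(requested: int, minimum: int = 1, maximum: int = 4095) -> int:
    return max(minimum, min(requested, maximum))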
- name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 +pricing: + input: '0.1' + output: '0.1' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml index 9d92e58f6c..d1b01731f5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml @@ -26,8 +26,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false
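A related behavioural quirk is handled in the llm.py hunk further down: stop words are only forwarded when the model is not glm-4v-plus, which otherwise reportedly finishes with finish_reason "network_error". The guard from that hunk, wrapped here as a self-contained sketch:

def build_extra_kwargs(model: str, stop: list[str] | None) -> dict:
    extra_model_kwargs = {}
    # Mirrors the guard added in llm.py below: glm-4v-plus rejects
    # stop words, so they are silently dropped for that model.
    if stop and model != "glm-4v-plus":
        extra_model_kwargs["stop"] = stop
    return extra_model_kwargs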
- name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 4096 + max: 4095 +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml new file mode 100644 index 0000000000..9ede308f18 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml @@ -0,0 +1,53 @@ +model: glm-4-plus +label: + en_US: glm-4-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301).
Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 4095 +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index ddea331c8e..28286580a7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -17,21 +17,35 @@ parameter_rules: en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - name: top_p use_template: top_p - default: 0.7 + default: 0.6 help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false.
If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 1024 +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml new file mode 100644 index 0000000000..4c5fa24034 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -0,0 +1,51 @@ +model: glm-4v-plus +label: + en_US: glm-4v-plus +model_type: llm +model_properties: + mode: chat +features: + - vision +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.6 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true
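For the vision models touched here (glm-4v, and the new glm-4v-plus), message content is a list mixing text and image parts; this mirrors what _construct_glm_4v_messages in llm.py below produces. Shape only; the base64 payload is a placeholder:

vision_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this picture."},
        # _remove_image_header (below) strips any "data:image/...;base64,"
        # prefix, so only the raw base64 body is sent.
        {"type": "image_url", "image_url": {"url": "<base64-image-data>"}},
    ],
}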
+ - name: stream + label: + zh_Hans: 流处理 + en_US: Event Stream + type: boolean + help: + zh_Hans: 使用同步调用时,此参数应当设置为 false 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。 + en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream. Note: During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts. + default: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 +pricing: + input: '0.01' + output: '0.01' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 13d8f5e5c3..484ac088db 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -35,12 +35,17 @@ And you should always end the block with a "```" to indicate the end of the JSON class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -62,9 +67,9 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -94,8 +99,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # content="```JSON\n" # )) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages:
list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -130,16 +140,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): "temperature": 0.5, }, tools=[], - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials_kwargs: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials_kwargs: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -153,15 +169,14 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :return: full response or stream response chunk generator result """ extra_model_kwargs = {} - if stop: - extra_model_kwargs['stop'] = stop + # a request to glm-4v-plus with stop words will always respond with "finish_reason":"network_error" + if stop and model != "glm-4v-plus": + extra_model_kwargs["stop"] = stop - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) if len(prompt_messages) == 0: - raise ValueError('At least one message is required') + raise ValueError("At least one message is required") if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: @@ -174,10 +189,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model != 'glm-4v': + if model not in ("glm-4v", "glm-4v-plus"): # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue @@ -187,8 +202,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ - copy_prompt_message.role == PromptMessageRole.USER: + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.USER + and copy_prompt_message.role == PromptMessageRole.USER + ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: if copy_prompt_message.role == PromptMessageRole.USER: @@ -207,77 +225,66 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v': + if model == "glm-4v" or model == "glm-4v-plus": params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: - params = { - 'model': model, - 'messages': [], - **model_parameters - } + params = {"model": model, "messages": [], **model_parameters} # glm model - if not model.startswith('chatglm'): - + if not model.startswith("chatglm"): for
prompt_message in new_prompt_messages: if prompt_message.role == PromptMessageRole.TOOL: - params['messages'].append({ - 'role': 'tool', - 'content': prompt_message.content, - 'tool_call_id': prompt_message.tool_call_id - }) + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content, - 'tool_calls': [ - { - 'id': tool_call.id, - 'type': tool_call.type, - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments + params["messages"].append( + { + "role": "assistant", + "content": prompt_message.content, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls - ] - }) + for tool_call in prompt_message.tool_calls + ], + } + ) else: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content - }) + params["messages"].append({"role": "assistant", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) else: # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if prompt_message.role == PromptMessageRole.SYSTEM or \ - prompt_message.role == PromptMessageRole.TOOL or \ - prompt_message.role == PromptMessageRole.USER: - if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': - params['messages'][-1]['content'] += "\n\n" + prompt_message.content + if ( + prompt_message.role == PromptMessageRole.SYSTEM + or prompt_message.role == PromptMessageRole.TOOL + or prompt_message.role == PromptMessageRole.USER + ): + if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": + params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: - params['messages'].append({ - 'role': 'user', - 'content': prompt_message.content - }) + params["messages"].append({"role": "user", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) if tools and len(tools) > 0: - params['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] + params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] if stream: response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) @@ -286,47 +293,41 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) - def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], - model_parameters: dict): + def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict): messages = [ - { - 'role': message.role.value, - 'content': 
self._construct_glm_4v_messages(message.content) - } + {"role": message.role.value, "content": self._construct_glm_4v_messages(message.content)} for message in prompt_messages ] - params = { - 'model': model, - 'messages': messages, - **model_parameters - } + params = {"model": model, "messages": messages, **model_parameters} return params - def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]: + def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: if isinstance(prompt_message, str): - return [{'type': 'text', 'text': prompt_message}] + return [{"type": "text", "text": prompt_message}] return [ - {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}} - if item.type == PromptMessageContentType.IMAGE else - {'type': 'text', 'text': item.data} - + {"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}} + if item.type == PromptMessageContentType.IMAGE + else {"type": "text", "text": item.data} for item in prompt_message ] def _remove_image_header(self, image: str) -> str: - if image.startswith('data:image'): - return image.split(',')[1] + if image.startswith("data:image"): + return image.split(",")[1] return image - def _handle_generate_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -335,12 +336,12 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - text = '' + text = "" assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -348,11 +349,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) - text += choice.message.content or '' + text += choice.message.content or "" prompt_usage = response.usage.prompt_tokens completion_usage = response.usage.completion_tokens @@ -364,20 +365,20 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): result = LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text, - tool_calls=assistant_tool_calls - ), + message=AssistantPromptMessage(content=text, tool_calls=assistant_tool_calls), usage=usage, ) return result - def _handle_generate_stream_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - responses: Generator[ChatCompletionChunk, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -386,19 +387,19 @@ class 
ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_assistant_content = '' + full_assistant_content = "" for chunk in responses: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -406,17 +407,16 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if delta.finish_reason is not None and chunk.usage is not None: completion_tokens = chunk.usage.completion_tokens @@ -428,24 +428,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - finish_reason=delta.finish_reason - ) + index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -472,18 +470,16 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. 
""" messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) if tools and len(tools) > 0: text += "\n\nTools:" diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0f9fecfc72..ee20954381 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -14,9 +14,9 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Model class for ZhipuAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -27,16 +27,14 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) return TextEmbeddingResult( embeddings=embeddings, usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + model=model, ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -50,7 +48,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ if len(texts) == 0: return 0 - + total_num_tokens = 0 for text in texts: total_num_tokens += self._get_num_tokens_by_gpt2(text) @@ -68,15 +66,13 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -100,7 +96,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): embedding_used_tokens += response.usage.total_tokens return [list(map(float, e)) for e in embeddings], embedding_used_tokens - + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. @@ -111,8 +107,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Embeddings for the text. 
""" return self.embed_documents([text])[0] - - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -122,10 +118,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -136,7 +129,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py index c517d2dba5..e75aad6eb0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py @@ -19,12 +19,9 @@ class ZhipuaiProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='glm-4', - credentials=credentials - ) + model_instance.validate_credentials(model="glm-4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 4dcd03f551..bf9b093cb3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,4 +1,3 @@ - from .__version__ import __version__ from ._client import ZhipuAI from .core._errors import ( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index eb0ad332ca..659f38d7ff 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1,2 +1 @@ - -__version__ = 'v2.0.1' \ No newline at end of file +__version__ = "v2.0.1" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 6588d1dd68..df9e506095 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -20,14 +20,14 @@ class ZhipuAI(HttpClient): api_key: str def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + max_retries: int 
= ZHIPUAI_DEFAULT_MAX_RETRIES, + http_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if api_key is None: raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables") @@ -38,6 +38,7 @@ class ZhipuAI(HttpClient): if base_url is None: base_url = "https://open.bigmodel.cn/api/paas/v4" from .__version__ import __version__ + super().__init__( version=__version__, base_url=base_url, @@ -58,9 +59,7 @@ class ZhipuAI(HttpClient): return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} def __del__(self) -> None: - if (not hasattr(self, "_has_custom_http_client") - or not hasattr(self, "close") - or not hasattr(self, "_client")): + if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): # if the '__init__' method raised an error, self would not have client attr return diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index dab6dac5fe..1f80119739 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -17,25 +17,24 @@ class AsyncCompletions(BaseAPI): def __init__(self, client: ZhipuAI) -> None: super().__init__(client) - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AsyncTaskStatus: _cast_type = AsyncTaskStatus @@ -57,9 +56,7 @@ class AsyncCompletions(BaseAPI): "tools": tools, "tool_choice": tool_choice, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) @@ -71,16 +68,11 @@ class AsyncCompletions(BaseAPI): 
disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + _cast_type = Union[AsyncCompletion, AsyncTaskStatus] if disable_strict_validation: _cast_type = object return self._get( path=f"/async-result/{id}", cast_type=_cast_type, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout - ) + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), ) - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index 5c4ed4d1ba..ec29f33864 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -20,24 +20,24 @@ class Completions(BaseAPI): super().__init__(client) def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Completion | StreamResponse[ChatCompletionChunk]: _cast_type = Completion _stream_cls = StreamResponse[ChatCompletionChunk] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index 35d54592fd..2308a20451 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -18,16 +18,16 @@ class Embeddings(BaseAPI): super().__init__(client) def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - encoding_format: str | 
NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + input: Union[str, list[str], list[int], list[list[int]]], + model: Union[str], + encoding_format: str | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EmbeddingsResponded: _cast_type = EmbeddingsResponded if disable_strict_validation: @@ -41,9 +41,7 @@ class Embeddings(BaseAPI): "user": user, "sensitive_word_check": sensitive_word_check, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 5deb8d08f3..f2ac74bffa 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -17,17 +17,16 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - file: FileTypes, - purpose: str, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + file: FileTypes, + purpose: str, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FileObject: if not is_file_content(file): prefix = f"Expected file input `{file!r}`" @@ -44,21 +43,19 @@ class Files(BaseAPI): "purpose": purpose, }, files=files, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FileObject, ) def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + purpose: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + after: str | NotGiven = NOT_GIVEN, + order: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFileObject: return self._get( "/files", diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc54a9ca45..dc30bd33ed 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -13,4 +13,3 @@ class FineTuning(BaseAPI): def __init__(self, client: "ZhipuAI") -> None: super().__init__(client) self.jobs = Jobs(client) - diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index b860de192a..3d2e9208a1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -16,21 +16,20 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + training_file: str, + hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, + suffix: Optional[str] | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + validation_file: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._post( "/fine_tuning/jobs", @@ -42,34 +41,30 @@ class Jobs(BaseAPI): "validation_file": validation_file, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFineTuningJob: return self._get( "/fine_tuning/jobs", @@ -93,7 +88,6 @@ class Jobs(BaseAPI): extra_headers: Headers | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJobEvent: - return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}/events", cast_type=FineTuningJobEvent, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 8eae1216d0..2692b093af 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -18,21 +18,21 @@ class Images(BaseAPI): super().__init__(client) def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: 
Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + prompt: str, + model: str | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + quality: Optional[str] | NotGiven = NOT_GIVEN, + response_format: Optional[str] | NotGiven = NOT_GIVEN, + size: Optional[str] | NotGiven = NOT_GIVEN, + style: Optional[str] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ImagesResponded: _cast_type = ImagesResponded if disable_strict_validation: @@ -50,11 +50,7 @@ class Images(BaseAPI): "user": user, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py index a2a438b8f3..1027c1bc5b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -17,7 +17,10 @@ __all__ = [ class ZhipuAIError(Exception): - def __init__(self, message: str, ) -> None: + def __init__( + self, + message: str, + ) -> None: super().__init__(message) @@ -31,24 +34,19 @@ class APIStatusError(Exception): self.status_code = response.status_code -class APIRequestFailedError(APIStatusError): - ... +class APIRequestFailedError(APIStatusError): ... -class APIAuthenticationError(APIStatusError): - ... +class APIAuthenticationError(APIStatusError): ... -class APIReachLimitError(APIStatusError): - ... +class APIReachLimitError(APIStatusError): ... -class APIInternalError(APIStatusError): - ... +class APIInternalError(APIStatusError): ... -class APIServerFlowExceedError(APIStatusError): - ... +class APIServerFlowExceedError(APIStatusError): ... 
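For reference, the one-line "class Foo(Base): ..." form produced above is the Black/Ruff spelling for an empty subclass body; the exception hierarchy itself is unchanged by this patch. Below is a minimal sketch of how such a status-error hierarchy is typically consumed. The (message, response) constructor shape and the status-code mapping are illustrative assumptions — only the status_code assignment is visible in the hunk above, and no dispatcher appears in this patch:

    import httpx

    class APIStatusError(Exception):
        # Assumed constructor; only the status_code assignment is shown in the hunk.
        def __init__(self, message: str, response: httpx.Response) -> None:
            super().__init__(message)
            self.response = response
            self.status_code = response.status_code

    class APIAuthenticationError(APIStatusError): ...  # assumed: e.g. HTTP 401
    class APIReachLimitError(APIStatusError): ...  # assumed: e.g. HTTP 429

    def raise_for_status(response: httpx.Response) -> None:
        # Hypothetical dispatcher routing status codes to the matching subclass.
        if response.status_code == 401:
            raise APIAuthenticationError("invalid API key", response)
        if response.status_code == 429:
            raise APIReachLimitError("rate limit reached", response)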
class APIResponseError(Exception): @@ -67,16 +65,11 @@ class APIResponseValidationError(APIResponseError): status_code: int response: httpx.Response - def __init__( - self, - response: httpx.Response, - json_data: object | None, *, - message: str | None = None - ) -> None: + def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None: super().__init__( message=message or "Data returned by API invalid for expected schema.", request=response.request, - json_data=json_data + json_data=json_data, ) self.response = response self.status_code = response.status_code diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 263fe82990..48eeb37c41 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -48,13 +48,13 @@ class HttpClient: _default_stream_cls: type[StreamResponse[Any]] | None = None def __init__( - self, - *, - version: str, - base_url: URL, - timeout: Union[float, Timeout, None], - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, + self, + *, + version: str, + base_url: URL, + timeout: Union[float, Timeout, None], + custom_httpx_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if timeout is None or isinstance(timeout, NotGiven): if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: @@ -76,7 +76,6 @@ class HttpClient: self._custom_headers = custom_headers or {} def _prepare_url(self, url: str) -> URL: - sub_url = URL(url) if sub_url.is_relative_url: request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") @@ -86,16 +85,15 @@ class HttpClient: @property def _default_headers(self): - return \ - { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, - **self._custom_headers, - } + return { + "Accept": "application/json", + "Content-Type": "application/json; charset=UTF-8", + "ZhipuAI-SDK-Ver": self._version, + "source_type": "zhipu-sdk-python", + "x-request-sdk": "zhipu-sdk-python", + **self._auth_headers, + **self._custom_headers, + } @property def _auth_headers(self): @@ -109,10 +107,7 @@ class HttpClient: return httpx_headers - def _prepare_request( - self, - request_param: ClientRequestParam - ) -> httpx.Request: + def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request: kwargs: dict[str, Any] = {} json_data = request_param.json_data headers = self._prepare_headers(request_param) @@ -135,16 +130,16 @@ class HttpClient: **kwargs, ) - def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: + def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: items = [] if isinstance(value, Mapping): for k, v in value.items(): - items.extend(self._object_to_formfata(f"{key}[{k}]", v)) + items.extend(self._object_to_formdata(f"{key}[{k}]", v)) return items if isinstance(value, list | tuple): for v in value: - items.extend(self._object_to_formfata(key + "[]", v)) + items.extend(self._object_to_formdata(key + "[]", v)) return items def _primitive_value_to_str(val) -> str: @@ -164,8 
+159,7 @@ class HttpClient: return [(key, str_data)] def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - - items = flatten([self._object_to_formfata(k, v) for k, v in data.items()]) + items = flatten([self._object_to_formdata(k, v) for k, v in data.items()]) serialized: dict[str, object] = {} for key, value in items: @@ -175,30 +169,25 @@ class HttpClient: return serialized def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + enable_stream: bool, + request_param: ClientRequestParam, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> HttpResponse: - http_response = HttpResponse( - raw_response=response, - cast_type=cast_type, - client=self, - enable_stream=enable_stream, - stream_cls=stream_cls + raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, stream_cls=stream_cls ) return http_response.parse() def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, + self, + *, + data: object, + cast_type: type[ResponseT], + response: httpx.Response, ) -> ResponseT: if data is None: return cast(ResponseT, None) @@ -225,12 +214,12 @@ class HttpClient: @retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES)) def request( - self, - *, - cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + params: ClientRequestParam, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: request = self._prepare_request(params) @@ -259,81 +248,79 @@ class HttpClient: ) def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - enable_stream: bool = False, + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + enable_stream: bool = False, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="get", url=path, **options) - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream - ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream) def post( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, - **options) - - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream, - stream_cls=stream_cls + opts = ClientRequestParam.construct( + method="post", json_data=body, files=make_httpx_files(files), url=path, **options ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream, stream_cls=stream_cls) + def patch( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - 
options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def put( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), - **options) + opts = ClientRequestParam.construct( + method="put", url=path, json_data=body, files=make_httpx_files(files), **options + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def delete( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def _make_status_error(self, response) -> APIStatusError: @@ -355,11 +342,11 @@ class HttpClient: def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - extra_body: Body | None = None, - query: Query | None = None, + max_retries: int | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + extra_headers: Headers = None, + extra_body: Body | None = None, + query: Query | None = None, ) -> UserRequestInput: options: UserRequestInput = {} @@ -368,7 +355,7 @@ def make_user_request_input( if max_retries is not None: options["max_retries"] = max_retries if not isinstance(timeout, NotGiven): - options['timeout'] = timeout + options["timeout"] = timeout if query is not None: options["params"] = query if extra_body is not None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index a3f49ba846..ac459151fc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -35,17 +35,14 @@ class ClientRequestParam: @classmethod def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : - kwargs: dict[str, Any] = { - key: remove_notgiven_indict(value) for key, value in values.items() - } + cls, + _fields_set: set[str] | None = None, + **values: Unpack[UserRequestInput], + ) -> ClientRequestParam: + kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()} client = cls() client.__dict__.update(kwargs) return client model_construct = construct - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 2f831b6fc9..56e60a7934 100644 
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -26,13 +26,13 @@ class HttpResponse(Generic[R]): http_response: httpx.Response def __init__( - self, - *, - raw_response: httpx.Response, - cast_type: type[R], - client: HttpClient, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + raw_response: httpx.Response, + cast_type: type[R], + client: HttpClient, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: self._cast_type = cast_type self._client = client @@ -52,8 +52,8 @@ class HttpResponse(Generic[R]): self._stream_cls( cast_type=cast(type, get_args(self._stream_cls)[0]), response=self.http_response, - client=self._client - ) + client=self._client, + ), ) return self._parsed cast_type = self._cast_type diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 66afbfd107..3566c6b332 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -16,16 +16,15 @@ if TYPE_CHECKING: class StreamResponse(Generic[ResponseT]): - response: httpx.Response _cast_type: type[ResponseT] def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + client: HttpClient, ) -> None: self.response = response self._cast_type = cast_type @@ -39,7 +38,6 @@ class StreamResponse(Generic[ResponseT]): yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: - sse_line_parser = SSELineParser() iterator = sse_line_parser.iter_lines(self.response.iter_lines()) @@ -63,11 +61,7 @@ class StreamResponse(Generic[ResponseT]): class Event: def __init__( - self, - event: str | None = None, - data: str | None = None, - id: str | None = None, - retry: int | None = None + self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None ): self._event = event self._data = data @@ -76,21 +70,28 @@ class Event: def __repr__(self): data_len = len(self._data) if self._data else 0 - return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + return ( + f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + ) @property - def event(self): return self._event + def event(self): + return self._event @property - def data(self): return self._data + def data(self): + return self._data - def json_data(self): return json.loads(self._data) + def json_data(self): + return json.loads(self._data) @property - def id(self): return self._id + def id(self): + return self._id @property - def retry(self): return self._retry + def retry(self): + return self._retry class SSELineParser: @@ -107,19 +108,11 @@ class SSELineParser: def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: for line in lines: - line = line.rstrip('\n') + line = line.rstrip("\n") if not line: - if self._event is None and \ - not self._data and \ - self._id is None and \ - self._retry is None: + if self._event is None and not self._data and self._id is None and self._retry is None: continue - sse_event = Event( - event=self._event, - 
data='\n'.join(self._data), - id=self._id, - retry=self._retry - ) + sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry) self._event = None self._data = [] self._id = None @@ -134,7 +127,7 @@ class SSELineParser: field, _p, value = line.partition(":") - if value.startswith(' '): + if value.startswith(" "): value = value[1:] if field == "data": self._data.append(value) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index f22f32d251..a0645b0916 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -20,4 +20,4 @@ class AsyncCompletion(BaseModel): model: Optional[str] = None task_status: str choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file + usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index b2a847c50c..4b3a929a2b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -41,5 +41,3 @@ class Completion(BaseModel): request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 917bda7576..75f76fe969 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -6,7 +6,6 @@ __all__ = ["FileObject"] class FileObject(BaseModel): - id: Optional[str] = None bytes: Optional[int] = None created_at: Optional[int] = None @@ -18,7 +17,6 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): - object: Optional[str] = None data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 71c00eaff0..1d3930286b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -2,7 +2,7 @@ from typing import Optional, Union from pydantic import BaseModel -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] class Error(BaseModel): diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index fe705d6943..e4f3541475 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -4,9 +4,9 @@ from core.model_runtime.entities.provider_entities import CredentialFormSchema, class CommonValidator: - def _validate_and_filter_credential_form_schemas(self, - credential_form_schemas: 
list[CredentialFormSchema], - credentials: dict) -> dict: + def _validate_and_filter_credential_form_schemas( + self, credential_form_schemas: list[CredentialFormSchema], credentials: dict + ) -> dict: need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: @@ -36,8 +36,9 @@ class CommonValidator: return validated_credentials - def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ - -> Optional[str]: + def _validate_credential_form_schema( + self, credential_form_schema: CredentialFormSchema, credentials: dict + ) -> Optional[str]: """ Validate credential form schema @@ -49,7 +50,7 @@ class CommonValidator: if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: - raise ValueError(f'Variable {credential_form_schema.variable} is required') + raise ValueError(f"Variable {credential_form_schema.variable} is required") else: # Get the value of default if credential_form_schema.default: @@ -65,23 +66,25 @@ class CommonValidator: # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: - raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + raise ValueError( + f"Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}" + ) # check the type of value if not isinstance(value, str): - raise ValueError(f'Variable {credential_form_schema.variable} should be string') + raise ValueError(f"Variable {credential_form_schema.variable} should be string") if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + raise ValueError(f"Variable {credential_form_schema.variable} is not in options") if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ['true', 'false']: - raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + if value.lower() not in ["true", "false"]: + raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = True if value.lower() == 'true' else False + value = True if value.lower() == "true" else False return value diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index c4786fad5d..7d1644d134 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -4,7 +4,6 @@ from core.model_runtime.schema_validators.common_validator import CommonValidato class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): self.model_type = model_type self.model_credential_schema = model_credential_schema diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py 
b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index c945016534..6dff2428ca 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -3,7 +3,6 @@ from core.model_runtime.schema_validators.common_validator import CommonValidato class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 5078f00bfa..ec1bad5698 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,11 +18,10 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any -) -> Any: +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) + # Taken from Pydantic v1 as is def isoformat(o: Union[datetime.date, datetime.time]) -> str: return o.isoformat() @@ -82,11 +81,9 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]] + type_encoder_map: dict[Any, Callable[[Any], Any]], ) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( - tuple - ) + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) for type_, encoder in type_encoder_map.items(): encoders_by_class_tuples[encoder] += (type_,) return encoders_by_class_tuples @@ -149,17 +146,13 @@ def jsonable_encoder( if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): - return format(obj, 'f') + return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) for key, value in obj.items(): if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) + (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (value is not None or not exclude_none) and key in allowed_keys ): diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index c68a554471..2067092d80 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -3,7 +3,7 @@ from pydantic import BaseModel def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, 'model_dump'): + if hasattr(pydantic, "model_dump"): return pydantic.model_dump(model) else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index f96e2a1c21..094ad78636 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -44,32 +44,29 @@ class ApiModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - params = ModerationInputParams( - app_id=self.app_id, - inputs=inputs, - query=query - ) + if self.config["inputs_config"]["enabled"]: + params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) return ModerationInputsResult(**result) - 
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - params = ModerationOutputParams( - app_id=self.app_id, - text=text - ) + if self.config["outputs_config"]["enabled"]: + params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) return ModerationOutputsResult(**result) - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) @@ -80,9 +77,10 @@ class ApiModeration(Moderation): @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 9a369a9f87..60898d5547 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -8,8 +8,8 @@ from core.extension.extensible import Extensible, ExtensionModule class ModerationAction(Enum): - DIRECT_OUTPUT = 'direct_output' - OVERRIDED = 'overrided' + DIRECT_OUTPUT = "direct_output" + OVERRIDDEN = "overridden" class ModerationInputsResult(BaseModel): @@ -31,6 +31,7 @@ class Moderation(Extensible, ABC): """ The base class of moderation. 
""" + module: ExtensionModule = ExtensionModule.MODERATION def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: @@ -75,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: + def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): @@ -110,5 +111,5 @@ class Moderation(Extensible, ABC): raise ValueError("outputs_config.preset_response must be less than 100 characters") -class ModerationException(Exception): +class ModerationError(Exception): pass diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 8157b300b1..46d3963bd0 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -2,7 +2,7 @@ import logging from typing import Optional from core.app.app_config.entities import AppConfig -from core.moderation.base import ModerationAction, ModerationException +from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask @@ -13,13 +13,14 @@ logger = logging.getLogger(__name__) class InputModeration: def check( - self, app_id: str, + self, + app_id: str, tenant_id: str, app_config: AppConfig, inputs: dict, query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
@@ -39,10 +40,7 @@ class InputModeration: moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( - name=moderation_type, - app_id=app_id, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config.config + name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config ) with measure_time() as timer: @@ -55,7 +53,7 @@ class InputModeration: message_id=message_id, moderation_result=moderation_result, inputs=inputs, - timer=timer + timer=timer, ) ) @@ -63,8 +61,8 @@ class InputModeration: return False, inputs, query if moderation_result.action == ModerationAction.DIRECT_OUTPUT: - raise ModerationException(moderation_result.preset_response) - elif moderation_result.action == ModerationAction.OVERRIDED: + raise ModerationError(moderation_result.preset_response) + elif moderation_result.action == ModerationAction.OVERRIDDEN: inputs = moderation_result.inputs query = moderation_result.query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index ca562ad987..17e48b8fbe 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -25,31 +25,35 @@ class KeywordsModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] flagged = self._is_violated(inputs, keywords_list) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: + if self.config["outputs_config"]["enabled"]: # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = self._is_violated({'text': text}, keywords_list) - preset_response = self.config['outputs_config']['preset_response'] + flagged = self._is_violated({"text": text}, keywords_list) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index fee51007eb..6465de23b9 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -21,37 +21,36 @@ class OpenAIModeration(Moderation): flagged = False preset_response = "" - if 
self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query flagged = self._is_violated(inputs) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - flagged = self._is_violated({'text': text}) - preset_response = self.config['outputs_config']['preset_response'] + if self.config["outputs_config"]["enabled"]: + flagged = self._is_violated({"text": text}) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict): - text = '\n'.join(str(inputs.values())) + text = "\n".join(str(inputs.values())) model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider="openai", - model_type=ModelType.MODERATION, - model="text-moderation-stable" + tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable" ) - openai_moderation = model_instance.invoke_moderation( - text=text - ) + openai_moderation = model_instance.invoke_moderation(text=text) return openai_moderation diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 9a4d8db4e2..d8d794be18 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -29,18 +29,18 @@ class OutputModeration(BaseModel): thread: Optional[threading.Thread] = None thread_running: bool = True - buffer: str = '' + buffer: str = "" is_final_chunk: bool = False final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def should_direct_output(self): + def should_direct_output(self) -> bool: return self.final_output is not None - def get_final_output(self): - return self.final_output + def get_final_output(self) -> str: + return self.final_output or "" - def append_new_token(self, token: str): + def append_new_token(self, token: str) -> None: self.buffer += token if not self.thread: @@ -50,11 +50,7 @@ class OutputModeration(BaseModel): self.buffer = completion self.is_final_chunk = True - result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=completion - ) + result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) if not result or not result.flagged: return completion @@ -65,21 +61,19 @@ class OutputModeration(BaseModel): final_output = result.text if public_event: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) return final_output def start_thread(self) -> threading.Thread: buffer_size = dify_config.MODERATION_BUFFER_SIZE - thread = 
threading.Thread(target=self.worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE - }) + thread = threading.Thread( + target=self.worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, + }, + ) thread.start() @@ -104,9 +98,7 @@ class OutputModeration(BaseModel): current_length = buffer_length result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=moderation_buffer + tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer ) if not result or not result.flagged: @@ -116,16 +108,11 @@ class OutputModeration(BaseModel): final_output = result.preset_response self.final_output = final_output else: - final_output = result.text + self.buffer[len(moderation_buffer):] + final_output = result.text + self.buffer[len(moderation_buffer) :] # trigger replace event if self.thread_running: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) if result.action == ModerationAction.DIRECT_OUTPUT: break @@ -133,10 +120,7 @@ class OutputModeration(BaseModel): def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: try: moderation_factory = ModerationFactory( - name=self.rule.type, - app_id=app_id, - tenant_id=tenant_id, - config=self.rule.config + name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config ) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index c7af8e2963..f7b882fc71 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -23,4 +23,4 @@ class BaseTraceInstance(ABC): Abstract method to trace activities. Subclasses must implement specific tracing logic for activities. """ - ... \ No newline at end of file + ... diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 221e6239ab..5c79867571 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -4,14 +4,15 @@ from pydantic import BaseModel, ValidationInfo, field_validator class TracingProviderEnum(Enum): - LANGFUSE = 'langfuse' - LANGSMITH = 'langsmith' + LANGFUSE = "langfuse" + LANGSMITH = "langsmith" class BaseTracingConfig(BaseModel): """ Base model class for tracing """ + ... @@ -19,16 +20,18 @@ class LangfuseConfig(BaseTracingConfig): """ Model class for Langfuse tracing config. """ + public_key: str secret_key: str - host: str = 'https://api.langfuse.com' + host: str = "https://api.langfuse.com" @field_validator("host") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.langfuse.com' - if not v.startswith('https://') and not v.startswith('http://'): - raise ValueError('host must start with https:// or http://') + v = "https://api.langfuse.com" + if not v.startswith("https://") and not v.startswith("http://"): + raise ValueError("host must start with https:// or http://") return v @@ -37,15 +40,17 @@ class LangSmithConfig(BaseTracingConfig): """ Model class for Langsmith tracing config. 
""" + api_key: str project: str - endpoint: str = 'https://api.smith.langchain.com' + endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.smith.langchain.com' - if not v.startswith('https://'): - raise ValueError('endpoint must start with https://') + v = "https://api.smith.langchain.com" + if not v.startswith("https://"): + raise ValueError("endpoint must start with https://") return v diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a1443f0691..f27a0af6e0 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel): metadata: dict[str, Any] @field_validator("inputs", "outputs") + @classmethod def ensure_type(cls, v): if v is None: return None @@ -23,6 +24,7 @@ class BaseTraceInfo(BaseModel): else: return "" + class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any conversation_id: Optional[str] = None @@ -98,23 +100,24 @@ class GenerateNameTraceInfo(BaseTraceInfo): conversation_id: Optional[str] = None tenant_id: str + trace_info_info_map = { - 'WorkflowTraceInfo': WorkflowTraceInfo, - 'MessageTraceInfo': MessageTraceInfo, - 'ModerationTraceInfo': ModerationTraceInfo, - 'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo, - 'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo, - 'ToolTraceInfo': ToolTraceInfo, - 'GenerateNameTraceInfo': GenerateNameTraceInfo, + "WorkflowTraceInfo": WorkflowTraceInfo, + "MessageTraceInfo": MessageTraceInfo, + "ModerationTraceInfo": ModerationTraceInfo, + "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo, + "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, + "ToolTraceInfo": ToolTraceInfo, + "GenerateNameTraceInfo": GenerateNameTraceInfo, } class TraceTaskName(str, Enum): - CONVERSATION_TRACE = 'conversation' - WORKFLOW_TRACE = 'workflow' - MESSAGE_TRACE = 'message' - MODERATION_TRACE = 'moderation' - SUGGESTED_QUESTION_TRACE = 'suggested_question' - DATASET_RETRIEVAL_TRACE = 'dataset_retrieval' - TOOL_TRACE = 'tool' - GENERATE_NAME_TRACE = 'generate_conversation_name' + CONVERSATION_TRACE = "conversation" + WORKFLOW_TRACE = "workflow" + MESSAGE_TRACE = "message" + MODERATION_TRACE = "moderation" + SUGGESTED_QUESTION_TRACE = "suggested_question" + DATASET_RETRIEVAL_TRACE = "dataset_retrieval" + TOOL_TRACE = "tool" + GENERATE_NAME_TRACE = "generate_conversation_name" diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index af7661f0af..447b799f1f 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -196,6 +198,7 @@ class GenerationUsage(BaseModel): totalCost: Optional[float] = None @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -273,6 
+276,7 @@ class LangfuseGeneration(BaseModel): model_config = ConfigDict(protected_namespaces=()) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index a21c67ed50..a0f3ac7f86 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -204,6 +204,7 @@ class LangFuseDataTrace(BaseTraceInstance): node_generation_data = LangfuseGeneration( name="llm", trace_id=trace_id, + model=process_data.get("model_name"), parent_observation_id=node_execution_id, start_time=created_at, end_time=finished_at, diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index f3fc46d99a..05c932fb99 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -35,49 +35,32 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field( - None, description="Serialized data of the run" - ) + serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field( - None, description="Session ID associated with the run" - ) - session_name: Optional[str] = Field( - None, description="Session name associated with the run" - ) - reference_example_id: Optional[str] = Field( - None, description="Reference example ID associated with the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + session_id: Optional[str] = Field(None, description="Session ID associated with the run") + session_name: Optional[str] = Field(None, description="Session name associated with the run") + reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: 
Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name values = info.data if v == {} or v is None: return v usage_metadata = { - "input_tokens": values.get('input_tokens', 0), - "output_tokens": values.get('output_tokens', 0), - "total_tokens": values.get('total_tokens', 0), + "input_tokens": values.get("input_tokens", 0), + "output_tokens": values.get("output_tokens", 0), + "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) if isinstance(v, str): @@ -133,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): return v return v @field_validator("start_time", "end_time") + @classmethod def format_time(cls, v, info: ValidationInfo): if not isinstance(v, datetime): @@ -143,25 +127,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") end_time: Optional[datetime | str] = Field(None, description="End time of the run") error: Optional[str] = Field(None, description="Error message of the run") inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
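The @classmethod additions above follow the Pydantic v2 validator idiom: @field_validator stays the outermost decorator, with @classmethod directly beneath it (the format_time validator is adjusted accordingly for consistency with the rest of the patch). A minimal, self-contained sketch of the pattern; the model and field names here are illustrative, not from the codebase:

```python
from pydantic import BaseModel, field_validator


class EndpointConfig(BaseModel):
    # Illustrative stand-in for the config/entity models touched above.
    endpoint: str = "https://api.smith.langchain.com"

    @field_validator("endpoint")  # outermost, as in the changes above
    @classmethod
    def normalize_endpoint(cls, v: str) -> str:
        # Fall back to the default host and insist on HTTPS, mirroring
        # the LangsmithConfig validator earlier in this patch.
        if not v:
            return "https://api.smith.langchain.com"
        if not v.startswith("https://"):
            raise ValueError("endpoint must start with https://")
        return v


print(EndpointConfig(endpoint="").endpoint)  # -> https://api.smith.langchain.com
```

diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index fde8a06c61..eea7bb3535 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -1,9 +1,11 @@ import json import logging import os +import uuid from datetime import datetime, timedelta from langsmith import Client +from langsmith.schemas import RunBase from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -139,8 +141,7 @@ class LangSmithDataTrace(BaseTraceInstance): json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} ) node_total_tokens = execution_metadata.get("total_tokens", 0) - - metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + metadata =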
execution_metadata.copy() metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -156,6 +157,12 @@ class LangSmithDataTrace(BaseTraceInstance): process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm + metadata.update( + { + "ls_provider": process_data.get("model_provider", ""), + "ls_model_name": process_data.get("model_name", ""), + } + ) elif node_type == "knowledge-retrieval": run_type = LangSmithRunType.retriever else: @@ -366,3 +373,22 @@ class LangSmithDataTrace(BaseTraceInstance): except Exception as e: logger.debug(f"LangSmith API check failed: {str(e)}") raise ValueError(f"LangSmith API check failed: {str(e)}") + + def get_project_url(self): + try: + run_data = RunBase( + id=uuid.uuid4(), + name="tool", + inputs={"input": "test"}, + outputs={"output": "test"}, + run_type=LangSmithRunType.tool, + start_time=datetime.now(), + ) + + project_url = self.langsmith_client.get_run_url( + run=run_data, project_id=self.project_id, project_name=self.project_name + ) + return project_url.split("/r/")[0] + except Exception as e: + logger.debug(f"LangSmith get run url failed: {str(e)}") + raise ValueError(f"LangSmith get run url failed: {str(e)}")
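The new get_project_url works around the lack of a direct project-URL accessor in the LangSmith client: it builds a throwaway RunBase, asks the client for that run's URL, and keeps everything before the per-run suffix. A rough standalone sketch of the same trick, assuming the API key and project name are placeholders you would supply yourself:

```python
import uuid
from datetime import datetime

from langsmith import Client
from langsmith.schemas import RunBase

client = Client(api_key="lsv2_...", api_url="https://api.smith.langchain.com")  # placeholder key
dummy_run = RunBase(
    id=uuid.uuid4(),  # the run never has to exist server-side
    name="tool",
    inputs={"input": "test"},
    outputs={"output": "test"},
    run_type="tool",
    start_time=datetime.now(),
)
run_url = client.get_run_url(run=dummy_run, project_name="my-project")  # placeholder project
print(run_url.split("/r/")[0])  # everything before "/r/<run-id>" is the project URL
```

diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 1416d6bd2d..d6156e479a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -36,17 +36,17 @@ from tasks.ops_trace_task import process_trace_tasks provider_config_map = { TracingProviderEnum.LANGFUSE.value: { - 'config_class': LangfuseConfig, - 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host', 'project_key'], - 'trace_instance': LangFuseDataTrace + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, }, TracingProviderEnum.LANGSMITH.value: { - 'config_class': LangSmithConfig, - 'secret_keys': ['api_key'], - 'other_keys': ['project', 'endpoint'], - 'trace_instance': LangSmithDataTrace - } + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + }, } @@ -64,14 +64,17 @@ class OpsTraceManager: :return: encrypted tracing configuration """ # Get the configuration class and the keys that require encryption - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} # Encrypt necessary keys for key in secret_keys: if key in tracing_config: - if '*' in tracing_config[key]: + if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config new_config[key] = current_trace_config.get(key, tracing_config[key]) else: @@ -94,8 +97,11 @@ class OpsTraceManager: :param tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( +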
provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in tracing_config: @@ -114,8 +120,11 @@ class OpsTraceManager: :param decrypt_tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in decrypt_tracing_config: @@ -133,9 +142,11 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None @@ -164,21 +175,21 @@ class OpsTraceManager: if app_id is None: return None - app: App = db.session.query(App).filter( - App.id == app_id - ).first() + app: App = db.session.query(App).filter(App.id == app_id).first() app_ops_trace_config = json.loads(app.tracing) if app.tracing else None if app_ops_trace_config is not None: - tracing_provider = app_ops_trace_config.get('tracing_provider') + tracing_provider = app_ops_trace_config.get("tracing_provider") else: return None # decrypt_token decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) - if app_ops_trace_config.get('enabled'): - trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \ - provider_config_map[tracing_provider]['config_class'] + if app_ops_trace_config.get("enabled"): + trace_instance, config_class = ( + provider_config_map[tracing_provider]["trace_instance"], + provider_config_map[tracing_provider]["config_class"], + ) tracing_instance = trace_instance(config_class(**decrypt_trace_config)) return tracing_instance @@ -192,9 +203,11 @@ class OpsTraceManager: conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if conversation_data.app_model_config_id: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation_data.app_model_config_id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .first() + ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs @@ -231,10 +244,7 @@ class OpsTraceManager: """ app: App = db.session.query(App).filter(App.id == app_id).first() if not app.tracing: - return { - "enabled": False, - "tracing_provider": None - } + return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) return app_trace_config @@ -246,8 +256,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - 
provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() @@ -259,11 +271,28 @@ class OpsTraceManager: :param tracing_config: tracing config :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_key() + @staticmethod + def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): + """ + get project url of the trace config + :param tracing_config: tracing config + :param tracing_provider: tracing provider + :return: + """ + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) + tracing_config = config_type(**tracing_config) + return trace_instance(tracing_config).get_project_url() + class TraceTask: def __init__( @@ -274,7 +303,7 @@ class TraceTask: conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, - **kwargs + **kwargs, ): self.trace_type = trace_type self.message_id = message_id @@ -297,9 +326,7 @@ class TraceTask: self.workflow_run, self.conversation_id, self.user_id ), TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( - self.message_id, self.timer, **self.kwargs - ), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( self.message_id, self.timer, **self.kwargs ), @@ -324,12 +351,8 @@ class TraceTask: workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = ( - json.loads(workflow_run.inputs) if workflow_run.inputs else {} - ) - workflow_run_outputs = ( - json.loads(workflow_run.outputs) if workflow_run.outputs else {} - ) + workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} + workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} workflow_run_version = workflow_run.version error = workflow_run.error if workflow_run.error else "" @@ -339,17 +362,18 @@ class TraceTask: query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - tenant_id=tenant_id, - app_id=workflow_run.app_id, - workflow_run_id=workflow_run.id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog) + .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) + .first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None # get message_id - message_data = db.session.query(Message.id).filter_by( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id - ).first() + message_data = ( + db.session.query(Message.id) + .filter_by(conversation_id=conversation_id,
workflow_run_id=workflow_run_id) + .first() + ) message_id = str(message_data.id) if message_data else None metadata = { @@ -457,9 +481,9 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( @@ -497,9 +521,9 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( @@ -556,9 +580,9 @@ class TraceTask: return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get('tool_name') - tool_inputs = kwargs.get('tool_inputs') - tool_outputs = kwargs.get('tool_outputs') + tool_name = kwargs.get("tool_name") + tool_inputs = kwargs.get("tool_inputs") + tool_outputs = kwargs.get("tool_outputs") message_data = get_message_data(message_id) if not message_data: return {} @@ -573,11 +597,11 @@ class TraceTask: if tool_name in agent_thought.tools: created_time = agent_thought.created_at tool_meta_data = agent_thought.tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - time_cost = tool_meta_data.get('time_cost', 0) + tool_config = tool_meta_data.get("tool_config", {}) + time_cost = tool_meta_data.get("time_cost", 0) end_time = created_time + timedelta(seconds=time_cost) - error = tool_meta_data.get('error', "") - tool_parameters = tool_meta_data.get('tool_parameters', {}) + error = tool_meta_data.get("error", "") + tool_parameters = tool_meta_data.get("tool_parameters", {}) metadata = { "message_id": message_id, "tool_name": tool_name, @@ -702,9 +726,7 @@ class TraceQueueManager: def start_timer(self): global trace_manager_timer if trace_manager_timer is None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer( - trace_manager_interval, self.run - ) + trace_manager_timer = threading.Timer(trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 3b2e04abb7..498685b342 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -20,19 +20,19 @@ def get_message_data(message_id): @contextmanager def measure_time(): - timing_info = {'start': datetime.now(), 'end': None} + timing_info = {"start": datetime.now(), "end": None} try: yield timing_info finally: - timing_info['end'] = datetime.now() + timing_info["end"] = datetime.now() def replace_text_with_content(data): if isinstance(data, dict): new_data = {} for key, value in data.items(): - if key == 'text': - new_data['content'] = value + if key == "text": + new_data["content"] = value else: new_data[key] = replace_text_with_content(value) return new_data
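measure_time above is the timing helper the trace tasks lean on: it yields a dict and fills in "end" inside a finally block, so the end timestamp is recorded even when the wrapped call raises. Usage looks like the following sketch, where the timed work is only a stand-in:

```python
from contextlib import contextmanager
from datetime import datetime


@contextmanager
def measure_time():
    # Same shape as api/core/ops/utils.py: "end" is set in finally, so it is
    # populated even if the body raises.
    timing_info = {"start": datetime.now(), "end": None}
    try:
        yield timing_info
    finally:
        timing_info["end"] = datetime.now()


with measure_time() as timer:
    _ = sum(range(1_000_000))  # stand-in for the traced call

print((timer["end"] - timer["start"]).total_seconds())
```

diff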
--git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 22420fea2c..ce8038d14e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -22,18 +22,22 @@ class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. """ + def __init__(self, with_variable_tmpl: bool = False) -> None: self.with_variable_tmpl = with_variable_tmpl - def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], - inputs: dict, - query: str, - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def get_prompt( + self, + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], + inputs: dict, + query: str, + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: inputs = {key: str(value) for key, value in inputs.items()} prompt_messages = [] @@ -48,7 +52,7 @@ class AdvancedPromptTransform(PromptTransform): context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( @@ -60,20 +64,22 @@ class AdvancedPromptTransform(PromptTransform): context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _get_completion_model_prompt_messages(self, - prompt_template: CompletionModelPromptTemplate, - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _get_completion_model_prompt_messages( + self, + prompt_template: CompletionModelPromptTemplate, + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: """ Get completion model prompt messages. 
""" @@ -81,7 +87,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] - if prompt_template.edition_type == 'basic' or not prompt_template.edition_type: + if prompt_template.edition_type == "basic" or not prompt_template.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -96,15 +102,13 @@ class AdvancedPromptTransform(PromptTransform): role_prefix=role_prefix, prompt_template=prompt_template, prompt_inputs=prompt_inputs, - model_config=model_config + model_config=model_config, ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) + prompt = prompt_template.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -122,16 +126,18 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages - def _get_chat_model_prompt_messages(self, - prompt_template: list[ChatModelMessage], - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def _get_chat_model_prompt_messages( + self, + prompt_template: list[ChatModelMessage], + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: """ Get chat model prompt messages. 
""" @@ -142,22 +148,20 @@ class AdvancedPromptTransform(PromptTransform): for prompt_item in raw_prompt_list: raw_prompt = prompt_item.text - if prompt_item.edition_type == 'basic' or not prompt_item.edition_type: + if prompt_item.edition_type == "basic" or not prompt_item.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) - elif prompt_item.edition_type == 'jinja2': + prompt = prompt_template.format(prompt_inputs) + elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs prompt = Jinja2Formatter.format(prompt, prompt_inputs) else: - raise ValueError(f'Invalid edition type: {prompt_item.edition_type}') + raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") if prompt_item.role == PromptMessageRole.USER: prompt_messages.append(UserPromptMessage(content=prompt)) @@ -168,17 +172,14 @@ class AdvancedPromptTransform(PromptTransform): if query and query_prompt_template: prompt_template = PromptTemplateParser( - template=query_prompt_template, - with_variable_tmpl=self.with_variable_tmpl + template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt_inputs['#sys.query#'] = query + prompt_inputs["#sys.query#"] = query prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - query = prompt_template.format( - prompt_inputs - ) + query = prompt_template.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) @@ -203,7 +204,7 @@ class AdvancedPromptTransform(PromptTransform): last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: prompt_message_contents.append(file.prompt_message_content) @@ -220,38 +221,39 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#context#' in prompt_template.variable_keys: + if "#context#" in prompt_template.variable_keys: if context: - prompt_inputs['#context#'] = context + prompt_inputs["#context#"] = context else: - prompt_inputs['#context#'] = '' + prompt_inputs["#context#"] = "" return prompt_inputs def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#query#' in prompt_template.variable_keys: + if "#query#" in prompt_template.variable_keys: if query: - prompt_inputs['#query#'] = query + prompt_inputs["#query#"] = query else: - prompt_inputs['#query#'] = '' + prompt_inputs["#query#"] = "" return prompt_inputs - def _set_histories_variable(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - raw_prompt: str, - role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigWithCredentialsEntity) -> dict: - if '#histories#' in prompt_template.variable_keys: + def _set_histories_variable( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + 
raw_prompt: str, + role_prefix: MemoryConfig.RolePrefix, + prompt_template: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigWithCredentialsEntity, + ) -> dict: + if "#histories#" in prompt_template.variable_keys: if memory: - inputs = {'#histories#': '', **prompt_inputs} + inputs = {"#histories#": "", **prompt_inputs} prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = UserPromptMessage( - content=prompt_template.format(prompt_inputs) - ) + tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) @@ -260,10 +262,10 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant + ai_prefix=role_prefix.assistant, ) - prompt_inputs['#histories#'] = histories + prompt_inputs["#histories#"] = histories else: - prompt_inputs['#histories#'] = '' + prompt_inputs["#histories#"] = "" return prompt_inputs diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index af0075ea91..caa1793ea8 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform): """ History Prompt Transform for Agent App """ - def __init__(self, - model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage], - history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, - ): + + def __init__( + self, + model_config: ModelConfigWithCredentialsEntity, + prompt_messages: list[PromptMessage], + history_messages: list[PromptMessage], + memory: Optional[TokenBufferMemory] = None, + ): self.model_config = model_config self.prompt_messages = prompt_messages self.history_messages = history_messages @@ -45,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform): model_type_instance = cast(LargeLanguageModel, model_type_instance) curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - self.history_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages ) if curr_message_tokens <= max_token_limit: return self.history_messages @@ -63,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform): # a message is start with UserPromptMessage if isinstance(prompt_message, UserPromptMessage): curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - prompt_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages ) # if current message token is overflow, drop all the prompts in current message and break if curr_message_tokens > max_token_limit: diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 61df69163c..c8e7b414df 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel): """ Chat Message.
""" + text: str role: PromptMessageRole - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class CompletionModelPromptTemplate(BaseModel): """ Completion Model Prompt Template. """ + text: str - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class MemoryConfig(BaseModel): """ Memory Config. """ + class RolePrefix(BaseModel): """ Role Prefix. """ + user: str assistant: str @@ -37,6 +41,7 @@ class MemoryConfig(BaseModel): """ Window Config. """ + enabled: bool size: Optional[int] = None diff --git a/api/core/prompt/prompt_templates/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py index da40534d99..e4b3a61cb4 100644 --- a/api/core/prompt/prompt_templates/advanced_prompt_templates.py +++ b/api/core/prompt/prompt_templates/advanced_prompt_templates.py @@ -7,39 +7,18 @@ CHAT_APP_COMPLETION_PROMPT_CONFIG = { "prompt": { "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " }, - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - } + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, }, - "stop": ["Human:"] + "stop": ["Human:"], } -CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } -} +CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}} -COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } -} +COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}} COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["Human:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["Human:"], } BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { @@ -47,37 +26,20 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { "prompt": { "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" }, - "conversation_histories_role": { - "user_prefix": "用户", - "assistant_prefix": "助手" - } + "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, }, - "stop": ["用户:"] + "stop": ["用户:"], } -BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } + "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["用户:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["用户:"], } diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index b86d3fa815..87acdb3c49 100644 --- a/api/core/prompt/prompt_transform.py +++
b/api/core/prompt/prompt_transform.py @@ -9,75 +9,78 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: - def _append_chat_histories(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _append_chat_histories( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: + def _get_history_messages_from_memory( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None, + ) -> str: """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } + kwargs = {"max_token_limit": max_token_limit} if human_prefix: - kwargs['human_prefix'] = human_prefix + kwargs["human_prefix"] = human_prefix if ai_prefix: - kwargs['ai_prefix'] = ai_prefix + kwargs["ai_prefix"] = ai_prefix if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: - kwargs['message_limit'] = memory_config.window.size + kwargs["message_limit"] = memory_config.window.size - return memory.get_history_prompt_text( - **kwargs - ) + return memory.get_history_prompt_text(**kwargs) - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int) -> list[PromptMessage]: + def 
_get_history_messages_list_from_memory( + self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int + ) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit, message_limit=memory_config.window.size - if (memory_config.window.enabled - and memory_config.window.size is not None - and memory_config.window.size > 0) - else None + if ( + memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + ) + else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fd7ed0181b..13e5c5253e 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -22,11 +22,11 @@ if TYPE_CHECKING: class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' + COMPLETION = "completion" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'ModelMode': + def value_of(cls, value: str) -> "ModelMode": """ Get value of given mode. @@ -36,7 +36,7 @@ class ModelMode(enum.Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") prompt_file_contents = {} @@ -47,16 +47,17 @@ class SimplePromptTransform(PromptTransform): Simple Prompt Transform for Chatbot App Basic Mode. """ - def get_prompt(self, - app_mode: AppMode, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: + def get_prompt( + self, + app_mode: AppMode, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode.value_of(model_config.mode) @@ -69,7 +70,7 @@ class SimplePromptTransform(PromptTransform): files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -80,19 +81,21 @@ class SimplePromptTransform(PromptTransform): files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages, stops - def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: ModelConfigWithCredentialsEntity, - pre_prompt: str, - inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, - ) -> tuple[str, dict]: + def get_prompt_str_and_rules( + self, + app_mode: AppMode, + model_config: ModelConfigWithCredentialsEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -101,74 +104,75 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, has_context=context is not None, query_in_prompt=query is not None, - with_memory_prompt=histories is not None + with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in 
prompt_template_config['custom_variable_keys'] if k in inputs} + variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} - for v in prompt_template_config['special_variable_keys']: + for v in prompt_template_config["special_variable_keys"]: # support #context#, #query# and #histories# - if v == '#context#': - variables['#context#'] = context if context else '' - elif v == '#query#': - variables['#query#'] = query if query else '' - elif v == '#histories#': - variables['#histories#'] = histories if histories else '' + if v == "#context#": + variables["#context#"] = context if context else "" + elif v == "#query#": + variables["#query#"] = query if query else "" + elif v == "#histories#": + variables["#histories#"] = histories if histories else "" - prompt_template = prompt_template_config['prompt_template'] + prompt_template = prompt_template_config["prompt_template"] prompt = prompt_template.format(variables) - return prompt, prompt_template_config['prompt_rules'] + return prompt, prompt_template_config["prompt_rules"] - def get_prompt_template(self, app_mode: AppMode, - provider: str, - model: str, - pre_prompt: str, - has_context: bool, - query_in_prompt: bool, - with_memory_prompt: bool = False) -> dict: - prompt_rules = self._get_prompt_rule( - app_mode=app_mode, - provider=provider, - model=model - ) + def get_prompt_template( + self, + app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + with_memory_prompt: bool = False, + ) -> dict: + prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys = [] special_variable_keys = [] - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt' and has_context: - prompt += prompt_rules['context_prompt'] - special_variable_keys.append('#context#') - elif order == 'pre_prompt' and pre_prompt: - prompt += pre_prompt + '\n' + prompt = "" + for order in prompt_rules["system_prompt_orders"]: + if order == "context_prompt" and has_context: + prompt += prompt_rules["context_prompt"] + special_variable_keys.append("#context#") + elif order == "pre_prompt" and pre_prompt: + prompt += pre_prompt + "\n" pre_prompt_template = PromptTemplateParser(template=pre_prompt) custom_variable_keys = pre_prompt_template.variable_keys - elif order == 'histories_prompt' and with_memory_prompt: - prompt += prompt_rules['histories_prompt'] - special_variable_keys.append('#histories#') + elif order == "histories_prompt" and with_memory_prompt: + prompt += prompt_rules["histories_prompt"] + special_variable_keys.append("#histories#") if query_in_prompt: - prompt += prompt_rules.get('query_prompt', '{{#query#}}') - special_variable_keys.append('#query#') + prompt += prompt_rules.get("query_prompt", "{{#query#}}") + special_variable_keys.append("#query#") return { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, - "prompt_rules": prompt_rules + "prompt_rules": prompt_rules, } - def _get_chat_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_chat_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + 
query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -178,7 +182,7 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, inputs=inputs, query=None, - context=context + context=context, ) if prompt and query: @@ -193,7 +197,7 @@ class SimplePromptTransform(PromptTransform): ) ), prompt_messages=prompt_messages, - model_config=model_config + model_config=model_config, ) if query: @@ -203,15 +207,17 @@ class SimplePromptTransform(PromptTransform): return prompt_messages, None - def _get_completion_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_completion_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=app_mode, @@ -219,13 +225,11 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, inputs=inputs, query=query, - context=context + context=context, ) if memory: - tmp_human_message = UserPromptMessage( - content=prompt - ) + tmp_human_message = UserPromptMessage(content=prompt) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( @@ -236,8 +240,8 @@ class SimplePromptTransform(PromptTransform): ) ), max_token_limit=rest_tokens, - human_prefix=prompt_rules.get('human_prefix', 'Human'), - ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), ) # get prompt @@ -248,10 +252,10 @@ class SimplePromptTransform(PromptTransform): inputs=inputs, query=query, context=context, - histories=histories + histories=histories, ) - stops = prompt_rules.get('stops') + stops = prompt_rules.get("stops") if stops is not None and len(stops) == 0: stops = None @@ -277,22 +281,18 @@ class SimplePromptTransform(PromptTransform): :param model: model name :return: """ - prompt_file_name = self._prompt_file_name( - app_mode=app_mode, - provider=provider, - model=model - ) + prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model) # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') - json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") + json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json") # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: + with open(json_file_path, encoding="utf-8") as json_file: content = json.load(json_file) # Store the content of the prompt file @@ -303,21 +303,21 @@ class 
SimplePromptTransform(PromptTransform): def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False - if provider == 'baichuan': + if provider == "baichuan": is_baichuan = True else: baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + if provider in baichuan_supported_providers and "baichuan" in model.lower(): is_baichuan = True if is_baichuan: if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' + return "baichuan_completion" else: - return 'baichuan_chat' + return "baichuan_chat" # common if app_mode == AppMode.COMPLETION: - return 'common_completion' + return "common_completion" else: - return 'common_chat' + return "common_chat" diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index befdceeda5..29494db221 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -25,26 +25,29 @@ class PromptMessageUtil: tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: - role = 'user' + role = "user" elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' + role = "assistant" if isinstance(prompt_message, AssistantPromptMessage): - tool_calls = [{ - 'id': tool_call.id, - 'type': 'function', - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments, + tool_calls = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls] + for tool_call in prompt_message.tool_calls + ] elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' + role = "system" elif prompt_message.role == PromptMessageRole.TOOL: - role = 'tool' + role = "tool" else: continue - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -53,27 +56,25 @@ class PromptMessageUtil: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content - prompt = { - "role": role, - "text": text, - "files": files - } - + prompt = {"role": role, "text": text, "files": files} + if tool_calls: - prompt['tool_calls'] = tool_calls + prompt["tool_calls"] = tool_calls prompts.append(prompt) else: prompt_message = prompt_messages[0] - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -82,21 +83,23 @@ class PromptMessageUtil: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." 
+ content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content params = { - "role": 'user', + "role": "user", "text": text, } if files: - params['files'] = files + params["files"] = files prompts.append(params) diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 3e68492df2..8111559675 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -38,8 +38,8 @@ class PromptTemplateParser: return value prompt = re.sub(self.regex, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False): - return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text) + return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text)
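PromptTemplateParser, touched just above, is what expands the {{#variable#}} placeholders used throughout the prompt templates in this patch. A toy renderer capturing the same idea; the real parser's regex and special-variable rules live in api/core/prompt/utils/prompt_template_parser.py, and this sketch is only an approximation of them:

```python
import re

# Simplified placeholder pattern; the production regex is stricter.
PLACEHOLDER = re.compile(r"\{\{(#?[a-zA-Z0-9_.]+#?)\}\}")


def render(template: str, variables: dict[str, str]) -> str:
    # Unknown keys render as empty strings, matching how #context# and
    # #histories# fall back to "" in the transforms above.
    return PLACEHOLDER.sub(lambda m: str(variables.get(m.group(1), "")), template)


text = "{{#pre_prompt#}}\n\nHuman: {{#query#}}\n\nAssistant: "
print(render(text, {"#pre_prompt#": "You are a helpful assistant.", "#query#": "Hi!"}))
```

diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 67eee2c294..3a1fe300df 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -90,8 +90,7 @@ class ProviderManager: # Initialize trial provider records if not exist provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, - provider_name_to_provider_records_dict + tenant_id, provider_name_to_provider_records_dict ) # Get all provider model records of the workspace @@ -107,22 +106,20 @@ class ProviderManager: provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) # Get All load balancing configs - provider_name_to_provider_load_balancing_model_configs_dict \ = self._get_all_provider_load_balancing_configs(tenant_id) - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id + provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( + tenant_id ) + provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: - # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, - data=provider_entity, - name_func=lambda x: x.provider, + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, ): continue @@ -132,18 +129,11 @@ class ProviderManager: # Convert to custom configuration custom_configuration = self._to_custom_configuration( - tenant_id, - provider_entity, - provider_records, - provider_model_records + tenant_id, provider_entity, provider_records, provider_model_records ) # Convert to system configuration - system_configuration = self._to_system_configuration( - tenant_id, - provider_entity, - provider_records - ) + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) # Get preferred provider type preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) @@ -173,14 +163,15 @@ class ProviderManager: provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) # Get provider load balancing configs - provider_load_balancing_configs \ = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name) + provider_load_balancing_configs =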
provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_name + ) # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, provider_model_settings=provider_model_settings, - load_balancing_model_configs=provider_load_balancing_configs + load_balancing_model_configs=provider_load_balancing_configs, ) provider_configuration = ProviderConfiguration( @@ -190,7 +181,7 @@ class ProviderManager: using_provider_type=using_provider_type, system_configuration=system_configuration, custom_configuration=custom_configuration, - model_settings=model_settings + model_settings=model_settings, ) provider_configurations[provider_name] = provider_configuration @@ -219,7 +210,7 @@ class ProviderManager: return ProviderModelBundle( configuration=provider_configuration, provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: @@ -231,11 +222,14 @@ class ProviderManager: :return: """ # Get the corresponding TenantDefaultModel record - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -244,20 +238,18 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next((model for model in available_models if model.model == "gpt-4"), - available_models[0]) + available_model = next( + (model for model in available_models if model.model == "gpt-4"), available_models[0] + ) default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), provider_name=available_model.provider.provider, - model_name=available_model.model + model_name=available_model.model, ) db.session.add(default_model) db.session.commit() @@ -276,8 +268,8 @@ class ProviderManager: label=provider_schema.label, icon_small=provider_schema.icon_small, icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) + supported_model_types=provider_schema.supported_model_types, + ), ) def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: @@ -291,15 +283,13 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - all_models = provider_configurations.get_models( - model_type=model_type, - only_active=False - ) + all_models = provider_configurations.get_models(model_type=model_type, only_active=False) return all_models[0].provider.provider, all_models[0].model - def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ - -> TenantDefaultModel: + def update_default_model_record( + self, tenant_id: str, 
model_type: ModelType, provider: str, model: str + ) -> TenantDefaultModel: """ Update default model record. @@ -314,10 +304,7 @@ class ProviderManager: raise ValueError(f"Provider {provider} does not exist.") # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) # check if the model is exist in available models model_names = [model.model for model in available_models] @@ -325,11 +312,14 @@ class ProviderManager: raise ValueError(f"Model {model} does not exist.") # Get the list of available models from get_configurations and check if it is LLM - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # create or update TenantDefaultModel record if default_model: @@ -358,11 +348,7 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - providers = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.is_valid == True - ).all() + providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() provider_name_to_provider_records_dict = defaultdict(list) for provider in providers: @@ -379,11 +365,11 @@ class ProviderManager: :return: """ # Get all provider model records of the workspace - provider_models = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == tenant_id, - ProviderModel.is_valid == True - ).all() + provider_models = ( + db.session.query(ProviderModel) + .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + .all() + ) provider_name_to_provider_model_records_dict = defaultdict(list) for provider_model in provider_models: @@ -399,10 +385,11 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == tenant_id - ).all() + preferred_provider_types = ( + db.session.query(TenantPreferredModelProvider) + .filter(TenantPreferredModelProvider.tenant_id == tenant_id) + .all() + ) provider_name_to_preferred_provider_type_records_dict = { preferred_provider_type.provider_name: preferred_provider_type @@ -419,15 +406,17 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = db.session.query(ProviderModelSetting) \ - .filter( - ProviderModelSetting.tenant_id == tenant_id - ).all() + provider_model_settings = ( + db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() + ) provider_name_to_provider_model_settings_dict = defaultdict(list) for provider_model_setting in provider_model_settings: - (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] - .append(provider_model_setting)) + ( + provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( + provider_model_setting + ) + ) return provider_name_to_provider_model_settings_dict @@ -445,27 +434,30 @@ class ProviderManager: model_load_balancing_enabled = 
FeatureService.get_features(tenant_id).model_load_balancing_enabled redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) else: - cache_result = cache_result.decode('utf-8') - model_load_balancing_enabled = cache_result == 'True' + cache_result = cache_result.decode("utf-8") + model_load_balancing_enabled = cache_result == "True" if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ - .filter( - LoadBalancingModelConfig.tenant_id == tenant_id - ).all() + provider_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() + ) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) for provider_load_balancing_config in provider_load_balancing_configs: - (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] - .append(provider_load_balancing_config)) + ( + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) + ) return provider_name_to_provider_load_balancing_model_configs_dict @staticmethod - def _init_trial_provider_records(tenant_id: str, - provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: + def _init_trial_provider_records( + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] + ) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -489,8 +481,9 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) for quota in configuration.quotas: if quota.quota_type == ProviderQuotaType.TRIAL: @@ -504,19 +497,22 @@ class ProviderManager: quota_type=ProviderQuotaType.TRIAL.value, quota_limit=quota.quota_limit, quota_used=0, - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() except IntegrityError: db.session.rollback() - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value - ).first() + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, + ) + .first() + ) if provider_record and not provider_record.is_valid: provider_record.is_valid = True @@ -526,11 +522,13 @@ class ProviderManager: return provider_name_to_provider_records_dict - def _to_custom_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider], - provider_model_records: list[ProviderModel]) -> CustomConfiguration: + def _to_custom_configuration( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_records: list[Provider], + provider_model_records: list[ProviderModel], + ) -> CustomConfiguration: """ Convert to custom configuration. 
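The load-balancing hunk above reflows a small Redis-backed feature-flag cache: the boolean is written as the string "True"/"False" with a 120-second TTL via setex, and reads decode the returned bytes and compare against "True". Below is a minimal runnable sketch of that round trip; the key format, the Redis connection, and lookup_feature_flag are illustrative assumptions standing in for Dify's FeatureService, not the project's actual code.

import redis

redis_client = redis.Redis()  # assumption: a local Redis instance, as in redis-py defaults


def lookup_feature_flag(tenant_id: str) -> bool:
    # Hypothetical stand-in for FeatureService.get_features(tenant_id).model_load_balancing_enabled.
    return True


def is_model_load_balancing_enabled(tenant_id: str) -> bool:
    cache_key = f"features:{tenant_id}:model_load_balancing_enabled"  # hypothetical key format
    cached = redis_client.get(cache_key)
    if cached is None:
        # Cache miss: compute the flag and store it for 120 seconds, as the hunk above does.
        enabled = lookup_feature_flag(tenant_id)
        redis_client.setex(cache_key, 120, str(enabled))
        return enabled
    # Cache hit: redis-py returns bytes; the value was written via str(bool), so compare to "True".
    return cached.decode("utf-8") == "True"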
@@ -543,7 +541,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get custom provider record @@ -563,7 +562,7 @@ class ProviderManager: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -572,11 +571,11 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } + if ( + custom_provider_record.encrypted_config + and not custom_provider_record.encrypted_config.startswith("{") + ): + provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) except JSONDecodeError: @@ -590,28 +589,23 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass # cache provider credentials - provider_credentials_cache.set( - credentials=provider_credentials - ) + provider_credentials_cache.set(credentials=provider_credentials) else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials - ) + custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) # Get custom provider model credentials @@ -621,9 +615,7 @@ class ProviderManager: continue provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL ) # Get cached provider model credentials @@ -645,15 +637,13 @@ class ProviderManager: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials @@ -661,19 +651,15 @@ class ProviderManager: CustomModelConfiguration( model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), - credentials=provider_model_credentials + 
credentials=provider_model_credentials, ) ) - return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations - ) + return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) - def _to_system_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider]) -> SystemConfiguration: + def _to_system_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> SystemConfiguration: """ Convert to system configuration. @@ -685,11 +671,11 @@ class ProviderManager: # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if provider_entity.provider not in hosting_configuration.provider_map \ - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled: - return SystemConfiguration( - enabled=False - ) + if ( + provider_entity.provider not in hosting_configuration.provider_map + or not hosting_configuration.provider_map.get(provider_entity.provider).enabled + ): + return SystemConfiguration(enabled=False) provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) @@ -699,8 +685,9 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: @@ -712,7 +699,7 @@ class ProviderManager: quota_used=0, quota_limit=0, is_valid=False, - restrict_models=provider_quota.restrict_models + restrict_models=provider_quota.restrict_models, ) else: continue @@ -724,16 +711,15 @@ class ProviderManager: quota_unit=provider_hosting_configuration.quota_unit, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, ) quota_configurations.append(quota_configuration) if len(quota_configurations) == 0: - return SystemConfiguration( - enabled=False - ) + return SystemConfiguration(enabled=False) current_quota_type = self._choice_current_using_quota_type(quota_configurations) @@ -745,7 +731,7 @@ class ProviderManager: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -760,7 +746,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get decoding rsa key and cipher for decrypting credentials @@ -771,9 +758,7 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - 
self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass @@ -781,9 +766,7 @@ class ProviderManager: current_using_credentials = provider_credentials # cache provider credentials - provider_credentials_cache.set( - credentials=current_using_credentials - ) + provider_credentials_cache.set(credentials=current_using_credentials) else: current_using_credentials = cached_provider_credentials else: @@ -794,7 +777,7 @@ class ProviderManager: enabled=True, current_quota_type=current_quota_type, quota_configurations=quota_configurations, - credentials=current_using_credentials + credentials=current_using_credentials, ) @staticmethod @@ -809,8 +792,7 @@ class ProviderManager: """ # convert to dict quota_type_to_quota_configuration_dict = { - quota_configuration.quota_type: quota_configuration - for quota_configuration in quota_configurations + quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations } last_quota_configuration = None @@ -823,7 +805,7 @@ class ProviderManager: if last_quota_configuration: return last_quota_configuration.quota_type - raise ValueError('No quota type available') + raise ValueError("No quota type available") @staticmethod def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: @@ -840,10 +822,12 @@ class ProviderManager: return secret_input_form_variables - def _to_model_settings(self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ - -> list[ModelSettings]: + def _to_model_settings( + self, + provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + ) -> list[ModelSettings]: """ Convert to model settings. 
:param provider_entity: provider entity @@ -854,7 +838,8 @@ class ProviderManager: # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) model_settings = [] @@ -865,24 +850,28 @@ class ProviderManager: load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: for load_balancing_model_config in load_balancing_model_configs: - if (load_balancing_model_config.model_name == provider_model_setting.model_name - and load_balancing_model_config.model_type == provider_model_setting.model_type): + if ( + load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type + ): if not load_balancing_model_config.enabled: continue if not load_balancing_model_config.encrypted_config: if load_balancing_model_config.name == "__inherit__": - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={} - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + ) + ) continue provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=load_balancing_model_config.tenant_id, identity_id=load_balancing_model_config.id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) # Get cached provider model credentials @@ -897,7 +886,8 @@ class ProviderManager: # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( - load_balancing_model_config.tenant_id) + load_balancing_model_config.tenant_id + ) for variable in model_credential_secret_variables: if variable in provider_model_credentials: @@ -905,30 +895,30 @@ class ProviderManager: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials=provider_model_credentials - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials, + ) + ) model_settings.append( ModelSettings( model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, - load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [] + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git 
a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index eaad0e0f4c..3c6ab2e4cf 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -2,37 +2,35 @@ import re class CleanProcessor: - @classmethod def clean(cls, text: str, process_rule: dict) -> str: # default clean # remove invalid symbol - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) - rules = process_rule['rules'] if process_rule else None - if 'pre_processing_rules' in rules: + rules = process_rule["rules"] if process_rule else None + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text def filter_string(self, text): - return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py index 523bd904f2..d3bc2f765e 100644 --- a/api/core/rag/cleaner/cleaner_base.py +++ b/api/core/rag/cleaner/cleaner_base.py @@ -1,12 +1,11 @@ """Abstract interface for document cleaner implementations.""" + from abc import ABC, abstractmethod class BaseCleaner(ABC): - """Interface for clean chunk content. 
- """ + """Interface for clean chunk content.""" @abstractmethod def clean(self, content: str): raise NotImplementedError - diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py index 6a0b8c9046..167a919e69 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_extra_whitespace diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py index 6fc3a408da..9c682d29db 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" import re diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py index ca1ae8dfd1..0cdbb171e1 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -1,12 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_non_ascii_chars - # Returns "This text containsnon-ascii characters!" + # Returns "This text contains non-ascii characters!" 
return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py index 974a28fef1..9f42044a2d 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -1,11 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """Replaces unicode quote characters, such as the \x91 character in a string.""" from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py index dfaf3a2787..32ae7217e8 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.translate import translate_text diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index ad9ee4f7cf..b1d6f93cff 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -12,17 +12,27 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner class DataPostProcessor: - """Interface for data post-processing document. 
- """ + """Interface for data post-processing document.""" - def __init__(self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - reorder_enabled: bool = False): + def __init__( + self, + tenant_id: str, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reorder_enabled: bool = False, + ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) self.reorder_runner = self._get_reorder_runner(reorder_enabled) - def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def invoke( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -31,21 +41,26 @@ class DataPostProcessor: return documents - def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]: + def _get_rerank_runner( + self, + reranking_mode: str, + tenant_id: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + ) -> Optional[RerankModelRunner | WeightRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: return WeightRerankRunner( tenant_id, Weights( vector_setting=VectorSetting( - vector_weight=weights['vector_setting']['vector_weight'], - embedding_provider_name=weights['vector_setting']['embedding_provider_name'], - embedding_model_name=weights['vector_setting']['embedding_model_name'], + vector_weight=weights["vector_setting"]["vector_weight"], + embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], + embedding_model_name=weights["vector_setting"]["embedding_model_name"], ), keyword_setting=KeywordSetting( - keyword_weight=weights['keyword_setting']['keyword_weight'], - ) - ) + keyword_weight=weights["keyword_setting"]["keyword_weight"], + ), + ), ) elif reranking_mode == RerankMode.RERANKING_MODEL.value: if reranking_model: @@ -53,9 +68,9 @@ class DataPostProcessor: model_manager = ModelManager() rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model['reranking_provider_name'], + provider=reranking_model["reranking_provider_name"], model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] + model=reranking_model["reranking_model_name"], ) except InvokeAuthorizationError: return None @@ -67,5 +82,3 @@ class DataPostProcessor: if reorder_enabled: return ReorderRunner() return None - - diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py index 71297588a4..a9a0885241 100644 --- a/api/core/rag/data_post_processor/reorder.py +++ b/api/core/rag/data_post_processor/reorder.py @@ -2,7 +2,6 @@ from core.rag.models.document import Document class ReorderRunner: - def run(self, documents: list[Document]) -> list[Document]: # Retrieve elements from odd indices (0, 2, 4, etc.) 
of the documents list odd_elements = documents[::2] diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a3714c2fd3..3073100746 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -24,37 +24,42 @@ class Jieba(BaseKeyword): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) return self def add_texts(self, texts: list[Document], **kwargs): - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get('keywords_list', None) + keywords_list = kwargs.get("keywords_list", None) for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - keywords = keyword_table_handler.extract_keywords(text.page_content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) else: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) @@ -63,97 +68,91 @@ class Jieba(BaseKeyword): return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - k = 
kwargs.get('top_k', 4) + k = kwargs.get("top_k", 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) documents = [] for chunk_index in sorted_chunk_indices: - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self.dataset.id, - DocumentSegment.index_node_id == chunk_index - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) + .first() + ) if segment: - - documents.append(Document( - page_content=segment.content, - metadata={ - "doc_id": chunk_index, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - )) + documents.append( + Document( + page_content=segment.content, + metadata={ + "doc_id": chunk_index, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + ) return documents def delete(self) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: db.session.delete(dataset_keyword_table) db.session.commit() - if dataset_keyword_table.data_source_type != 'database': - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + if dataset_keyword_table.data_source_type != "database": + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": keyword_table - } + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type - if keyword_data_source_type == 'database': + if keyword_data_source_type == "database": dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) db.session.commit() else: - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8')) + storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict['__data__']['table'] + return keyword_table_dict["__data__"]["table"] else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( dataset_id=self.dataset.id, - keyword_table='', + keyword_table="", data_source_type=keyword_data_source_type, ) - if keyword_data_source_type == 'database': - dataset_keyword_table.keyword_table = json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) + if keyword_data_source_type == "database": + 
dataset_keyword_table.keyword_table = json.dumps( + { + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, + }, + cls=SetEncoder, + ) db.session.add(dataset_keyword_table) db.session.commit() @@ -174,9 +173,7 @@ class Jieba(BaseKeyword): keywords_to_delete = set() for keyword, node_idxs in keyword_table.items(): if node_idxs_to_delete.intersection(node_idxs): - keyword_table[keyword] = node_idxs.difference( - node_idxs_to_delete - ) + keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete) if not keyword_table[keyword]: keywords_to_delete.add(keyword) @@ -202,13 +199,14 @@ class Jieba(BaseKeyword): reverse=True, ) - return sorted_chunk_indices[: k] + return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id == node_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .first() + ) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) @@ -224,14 +222,14 @@ class Jieba(BaseKeyword): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] - if pre_segment_data['keywords']: - segment.keywords = pre_segment_data['keywords'] - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, - pre_segment_data['keywords']) + segment = pre_segment_data["segment"] + if pre_segment_data["keywords"]: + segment.keywords = pre_segment_data["keywords"] + keyword_table = self._add_text_to_keyword_table( + keyword_table, segment.index_node_id, pre_segment_data["keywords"] + ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ad669ef515..4b1ade8e3f 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS class JiebaKeywordTableHandler: - def __init__(self): default_tfidf.stop_words = STOPWORDS @@ -30,4 +29,4 @@ class JiebaKeywordTableHandler: if len(sub_tokens) > 1: results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - return results \ No newline at end of file + return results diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index c616a15cf0..9abe78d6ef 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -1,90 +1,1380 @@ STOPWORDS = { - "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've", - "ours", "who", "where", "ourselves", "has", "to", "was", 
"didn't", "themselves", "if", "against", "through", "her", - "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t", - "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven", - "for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now", - "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which", - "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't", - "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves", - "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because", - "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't", - "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after", - "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there", - "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below", - "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本", - "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但", - "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着", - "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当", - "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是", - "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各", - "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈", - "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之", - "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既", - "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着", - "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可", - "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外", - "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪", - "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么", - "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦", - "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但", - "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何", - "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候", - "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着", - "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同", - "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎", - "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然", - "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", 
"一切", "一样", "一则", "依", "依照", - "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由", - "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是", - "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着", - "者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样", - "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从", - "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令", - "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":", - "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_", - "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…", - " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二", - "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[", - "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时", - "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加", - "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没", - "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭", - "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了", - "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要", - "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止", - "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去", - "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心", - "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去", - "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹", - "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速", - "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从", - "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事", - "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场", - "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿", - "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独", - "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等", - "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得", - "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢", - "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为", - "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底", - "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须", - "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然", - "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆", - "即将", "即刻", "即是说", "几度", 
"几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之", - "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此", - "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量", - "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外", - "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非", - "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲", - "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻", - "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次", - "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢", - "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪", - "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通", - "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧", - "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自", - "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然", - "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述", - "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n" + "during", + "when", + "but", + "then", + "further", + "isn", + "mustn't", + "until", + "own", + "i", + "couldn", + "y", + "only", + "you've", + "ours", + "who", + "where", + "ourselves", + "has", + "to", + "was", + "didn't", + "themselves", + "if", + "against", + "through", + "her", + "an", + "your", + "can", + "those", + "didn", + "about", + "aren't", + "shan't", + "be", + "not", + "these", + "again", + "so", + "t", + "theirs", + "weren", + "won't", + "won", + "itself", + "just", + "same", + "while", + "why", + "doesn", + "aren", + "him", + "haven", + "for", + "you'll", + "that", + "we", + "am", + "d", + "by", + "having", + "wasn't", + "than", + "weren't", + "out", + "from", + "now", + "their", + "too", + "hadn", + "o", + "needn", + "most", + "it", + "under", + "needn't", + "any", + "some", + "few", + "ll", + "hers", + "which", + "m", + "you're", + "off", + "other", + "had", + "she", + "you'd", + "do", + "you", + "does", + "s", + "will", + "each", + "wouldn't", + "hasn't", + "such", + "more", + "whom", + "she's", + "my", + "yours", + "yourself", + "of", + "on", + "very", + "hadn't", + "with", + "yourselves", + "been", + "ma", + "them", + "mightn't", + "shan", + "mustn", + "they", + "what", + "both", + "that'll", + "how", + "is", + "he", + "because", + "down", + "haven't", + "are", + "no", + "it's", + "our", + "being", + "the", + "or", + "above", + "myself", + "once", + "don't", + "doesn't", + "as", + "nor", + "here", + "herself", + "hasn", + "mightn", + "have", + "its", + "all", + "were", + "ain", + "this", + "at", + "after", + "over", + "shouldn't", + "into", + "before", + "don", + "wouldn", + "re", + "couldn't", + "wasn", + "in", + "should", + "there", + "himself", + "isn't", + "should've", + "doing", + "ve", + "shouldn", + "a", + "did", + "and", + "his", + "between", + "me", + "up", + "below", + "人民", + "末##末", + "啊", + "阿", + "哎", + "哎呀", + "哎哟", + "唉", + "俺", + "俺们", + "按", + "按照", + "吧", + "吧哒", + "把", + "罢了", + "被", + "本", + "本着", + "比", + "比方", + "比如", + "鄙人", + "彼", + "彼此", + "边", + "别", + "别的", + "别说", 
+ "并", + "并且", + "不比", + "不成", + "不单", + "不但", + "不独", + "不管", + "不光", + "不过", + "不仅", + "不拘", + "不论", + "不怕", + "不然", + "不如", + "不特", + "不惟", + "不问", + "不只", + "朝", + "朝着", + "趁", + "趁着", + "乘", + "冲", + "除", + "除此之外", + "除非", + "除了", + "此", + "此间", + "此外", + "从", + "从而", + "打", + "待", + "但", + "但是", + "当", + "当着", + "到", + "得", + "的", + "的话", + "等", + "等等", + "地", + "第", + "叮咚", + "对", + "对于", + "多", + "多少", + "而", + "而况", + "而且", + "而是", + "而外", + "而言", + "而已", + "尔后", + "反过来", + "反过来说", + "反之", + "非但", + "非徒", + "否则", + "嘎", + "嘎登", + "该", + "赶", + "个", + "各", + "各个", + "各位", + "各种", + "各自", + "给", + "根据", + "跟", + "故", + "故此", + "固然", + "关于", + "管", + "归", + "果然", + "果真", + "过", + "哈", + "哈哈", + "呵", + "和", + "何", + "何处", + "何况", + "何时", + "嘿", + "哼", + "哼唷", + "呼哧", + "乎", + "哗", + "还是", + "还有", + "换句话说", + "换言之", + "或", + "或是", + "或者", + "极了", + "及", + "及其", + "及至", + "即", + "即便", + "即或", + "即令", + "即若", + "即使", + "几", + "几时", + "己", + "既", + "既然", + "既是", + "继而", + "加之", + "假如", + "假若", + "假使", + "鉴于", + "将", + "较", + "较之", + "叫", + "接着", + "结果", + "借", + "紧接着", + "进而", + "尽", + "尽管", + "经", + "经过", + "就", + "就是", + "就是说", + "据", + "具体地说", + "具体说来", + "开始", + "开外", + "靠", + "咳", + "可", + "可见", + "可是", + "可以", + "况且", + "啦", + "来", + "来着", + "离", + "例如", + "哩", + "连", + "连同", + "两者", + "了", + "临", + "另", + "另外", + "另一方面", + "论", + "嘛", + "吗", + "慢说", + "漫说", + "冒", + "么", + "每", + "每当", + "们", + "莫若", + "某", + "某个", + "某些", + "拿", + "哪", + "哪边", + "哪儿", + "哪个", + "哪里", + "哪年", + "哪怕", + "哪天", + "哪些", + "哪样", + "那", + "那边", + "那儿", + "那个", + "那会儿", + "那里", + "那么", + "那么些", + "那么样", + "那时", + "那些", + "那样", + "乃", + "乃至", + "呢", + "能", + "你", + "你们", + "您", + "宁", + "宁可", + "宁肯", + "宁愿", + "哦", + "呕", + "啪达", + "旁人", + "呸", + "凭", + "凭借", + "其", + "其次", + "其二", + "其他", + "其它", + "其一", + "其余", + "其中", + "起", + "起见", + "岂但", + "恰恰相反", + "前后", + "前者", + "且", + "然而", + "然后", + "然则", + "让", + "人家", + "任", + "任何", + "任凭", + "如", + "如此", + "如果", + "如何", + "如其", + "如若", + "如上所述", + "若", + "若非", + "若是", + "啥", + "上下", + "尚且", + "设若", + "设使", + "甚而", + "甚么", + "甚至", + "省得", + "时候", + "什么", + "什么样", + "使得", + "是", + "是的", + "首先", + "谁", + "谁知", + "顺", + "顺着", + "似的", + "虽", + "虽然", + "虽说", + "虽则", + "随", + "随着", + "所", + "所以", + "他", + "他们", + "他人", + "它", + "它们", + "她", + "她们", + "倘", + "倘或", + "倘然", + "倘若", + "倘使", + "腾", + "替", + "通过", + "同", + "同时", + "哇", + "万一", + "往", + "望", + "为", + "为何", + "为了", + "为什么", + "为着", + "喂", + "嗡嗡", + "我", + "我们", + "呜", + "呜呼", + "乌乎", + "无论", + "无宁", + "毋宁", + "嘻", + "吓", + "相对而言", + "像", + "向", + "向着", + "嘘", + "呀", + "焉", + "沿", + "沿着", + "要", + "要不", + "要不然", + "要不是", + "要么", + "要是", + "也", + "也罢", + "也好", + "一", + "一般", + "一旦", + "一方面", + "一来", + "一切", + "一样", + "一则", + "依", + "依照", + "矣", + "以", + "以便", + "以及", + "以免", + "以至", + "以至于", + "以致", + "抑或", + "因", + "因此", + "因而", + "因为", + "哟", + "用", + "由", + "由此可见", + "由于", + "有", + "有的", + "有关", + "有些", + "又", + "于", + "于是", + "于是乎", + "与", + "与此同时", + "与否", + "与其", + "越是", + "云云", + "哉", + "再说", + "再者", + "在", + "在下", + "咱", + "咱们", + "则", + "怎", + "怎么", + "怎么办", + "怎么样", + "怎样", + "咋", + "照", + "照着", + "者", + "这", + "这边", + "这儿", + "这个", + "这会儿", + "这就是说", + "这里", + "这么", + "这么点儿", + "这么些", + "这么样", + "这时", + "这些", + "这样", + "正如", + "吱", + "之", + "之类", + "之所以", + "之一", + "只是", + "只限", + "只要", + "只有", + "至", + "至于", + "诸位", + "着", + "着呢", + "自", + "自从", + "自个儿", + "自各儿", + "自己", + "自家", + "自身", + "综上所述", + "总的来看", + "总的来说", + "总的说来", + "总而言之", + "总之", + "纵", + "纵令", + "纵然", + "纵使", + "遵照", + "作为", + "兮", 
+ "呃", + "呗", + "咚", + "咦", + "喏", + "啐", + "喔唷", + "嗬", + "嗯", + "嗳", + "~", + "!", + ".", + ":", + '"', + "'", + "(", + ")", + "*", + "A", + "白", + "社会主义", + "--", + "..", + ">>", + " [", + " ]", + "", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "+", + "=", + "&", + "^", + "%", + "#", + "@", + "`", + ";", + "$", + "(", + ")", + "——", + "—", + "¥", + "·", + "...", + "‘", + "’", + "〉", + "〈", + "…", + " ", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "零", + ">", + "<", + "@", + "#", + "$", + "%", + "︿", + "&", + "*", + "+", + "~", + "|", + "[", + "]", + "{", + "}", + "啊哈", + "啊呀", + "啊哟", + "挨次", + "挨个", + "挨家挨户", + "挨门挨户", + "挨门逐户", + "挨着", + "按理", + "按期", + "按时", + "按说", + "暗地里", + "暗中", + "暗自", + "昂然", + "八成", + "白白", + "半", + "梆", + "保管", + "保险", + "饱", + "背地里", + "背靠背", + "倍感", + "倍加", + "本人", + "本身", + "甭", + "比起", + "比如说", + "比照", + "毕竟", + "必", + "必定", + "必将", + "必须", + "便", + "别人", + "并非", + "并肩", + "并没", + "并没有", + "并排", + "并无", + "勃然", + "不", + "不必", + "不常", + "不大", + "不但...而且", + "不得", + "不得不", + "不得了", + "不得已", + "不迭", + "不定", + "不对", + "不妨", + "不管怎样", + "不会", + "不仅...而且", + "不仅仅", + "不仅仅是", + "不经意", + "不可开交", + "不可抗拒", + "不力", + "不了", + "不料", + "不满", + "不免", + "不能不", + "不起", + "不巧", + "不然的话", + "不日", + "不少", + "不胜", + "不时", + "不是", + "不同", + "不能", + "不要", + "不外", + "不外乎", + "不下", + "不限", + "不消", + "不已", + "不亦乐乎", + "不由得", + "不再", + "不择手段", + "不怎么", + "不曾", + "不知不觉", + "不止", + "不止一次", + "不至于", + "才", + "才能", + "策略地", + "差不多", + "差一点", + "常", + "常常", + "常言道", + "常言说", + "常言说得好", + "长此下去", + "长话短说", + "长期以来", + "长线", + "敞开儿", + "彻夜", + "陈年", + "趁便", + "趁机", + "趁热", + "趁势", + "趁早", + "成年", + "成年累月", + "成心", + "乘机", + "乘胜", + "乘势", + "乘隙", + "乘虚", + "诚然", + "迟早", + "充分", + "充其极", + "充其量", + "抽冷子", + "臭", + "初", + "出", + "出来", + "出去", + "除此", + "除此而外", + "除此以外", + "除开", + "除去", + "除却", + "除外", + "处处", + "川流不息", + "传", + "传说", + "传闻", + "串行", + "纯", + "纯粹", + "此后", + "此中", + "次第", + "匆匆", + "从不", + "从此", + "从此以后", + "从古到今", + "从古至今", + "从今以后", + "从宽", + "从来", + "从轻", + "从速", + "从头", + "从未", + "从无到有", + "从小", + "从新", + "从严", + "从优", + "从早到晚", + "从中", + "从重", + "凑巧", + "粗", + "存心", + "达旦", + "打从", + "打开天窗说亮话", + "大", + "大不了", + "大大", + "大抵", + "大都", + "大多", + "大凡", + "大概", + "大家", + "大举", + "大略", + "大面儿上", + "大事", + "大体", + "大体上", + "大约", + "大张旗鼓", + "大致", + "呆呆地", + "带", + "殆", + "待到", + "单", + "单纯", + "单单", + "但愿", + "弹指之间", + "当场", + "当儿", + "当即", + "当口儿", + "当然", + "当庭", + "当头", + "当下", + "当真", + "当中", + "倒不如", + "倒不如说", + "倒是", + "到处", + "到底", + "到了儿", + "到目前为止", + "到头", + "到头来", + "得起", + "得天独厚", + "的确", + "等到", + "叮当", + "顶多", + "定", + "动不动", + "动辄", + "陡然", + "都", + "独", + "独自", + "断然", + "顿时", + "多次", + "多多", + "多多少少", + "多多益善", + "多亏", + "多年来", + "多年前", + "而后", + "而论", + "而又", + "尔等", + "二话不说", + "二话没说", + "反倒", + "反倒是", + "反而", + "反手", + "反之亦然", + "反之则", + "方", + "方才", + "方能", + "放量", + "非常", + "非得", + "分期", + "分期分批", + "分头", + "奋勇", + "愤然", + "风雨无阻", + "逢", + "弗", + "甫", + "嘎嘎", + "该当", + "概", + "赶快", + "赶早不赶晚", + "敢", + "敢情", + "敢于", + "刚", + "刚才", + "刚好", + "刚巧", + "高低", + "格外", + "隔日", + "隔夜", + "个人", + "各式", + "更", + "更加", + "更进一步", + "更为", + "公然", + "共", + "共总", + "够瞧的", + "姑且", + "古来", + "故而", + "故意", + "固", + "怪", + "怪不得", + "惯常", + "光", + "光是", + "归根到底", + "归根结底", + "过于", + "毫不", + "毫无", + "毫无保留地", + "毫无例外", + "好在", + "何必", + "何尝", + "何妨", + "何苦", + "何乐而不为", + "何须", + "何止", + "很", + "很多", + "很少", + "轰然", + 
"后来", + "呼啦", + "忽地", + "忽然", + "互", + "互相", + "哗啦", + "话说", + "还", + "恍然", + "会", + "豁然", + "活", + "伙同", + "或多或少", + "或许", + "基本", + "基本上", + "基于", + "极", + "极大", + "极度", + "极端", + "极力", + "极其", + "极为", + "急匆匆", + "即将", + "即刻", + "即是说", + "几度", + "几番", + "几乎", + "几经", + "既...又", + "继之", + "加上", + "加以", + "间或", + "简而言之", + "简言之", + "简直", + "见", + "将才", + "将近", + "将要", + "交口", + "较比", + "较为", + "接连不断", + "接下来", + "皆可", + "截然", + "截至", + "藉以", + "借此", + "借以", + "届时", + "仅", + "仅仅", + "谨", + "进来", + "进去", + "近", + "近几年来", + "近来", + "近年来", + "尽管如此", + "尽可能", + "尽快", + "尽量", + "尽然", + "尽如人意", + "尽心竭力", + "尽心尽力", + "尽早", + "精光", + "经常", + "竟", + "竟然", + "究竟", + "就此", + "就地", + "就算", + "居然", + "局外", + "举凡", + "据称", + "据此", + "据实", + "据说", + "据我所知", + "据悉", + "具体来说", + "决不", + "决非", + "绝", + "绝不", + "绝顶", + "绝对", + "绝非", + "均", + "喀", + "看", + "看来", + "看起来", + "看上去", + "看样子", + "可好", + "可能", + "恐怕", + "快", + "快要", + "来不及", + "来得及", + "来讲", + "来看", + "拦腰", + "牢牢", + "老", + "老大", + "老老实实", + "老是", + "累次", + "累年", + "理当", + "理该", + "理应", + "历", + "立", + "立地", + "立刻", + "立马", + "立时", + "联袂", + "连连", + "连日", + "连日来", + "连声", + "连袂", + "临到", + "另方面", + "另行", + "另一个", + "路经", + "屡", + "屡次", + "屡次三番", + "屡屡", + "缕缕", + "率尔", + "率然", + "略", + "略加", + "略微", + "略为", + "论说", + "马上", + "蛮", + "满", + "没", + "没有", + "每逢", + "每每", + "每时每刻", + "猛然", + "猛然间", + "莫", + "莫不", + "莫非", + "莫如", + "默默地", + "默然", + "呐", + "那末", + "奈", + "难道", + "难得", + "难怪", + "难说", + "内", + "年复一年", + "凝神", + "偶而", + "偶尔", + "怕", + "砰", + "碰巧", + "譬如", + "偏偏", + "乒", + "平素", + "颇", + "迫于", + "扑通", + "其后", + "其实", + "奇", + "齐", + "起初", + "起来", + "起首", + "起头", + "起先", + "岂", + "岂非", + "岂止", + "迄", + "恰逢", + "恰好", + "恰恰", + "恰巧", + "恰如", + "恰似", + "千", + "千万", + "千万千万", + "切", + "切不可", + "切莫", + "切切", + "切勿", + "窃", + "亲口", + "亲身", + "亲手", + "亲眼", + "亲自", + "顷", + "顷刻", + "顷刻间", + "顷刻之间", + "请勿", + "穷年累月", + "取道", + "去", + "权时", + "全都", + "全力", + "全年", + "全然", + "全身心", + "然", + "人人", + "仍", + "仍旧", + "仍然", + "日复一日", + "日见", + "日渐", + "日益", + "日臻", + "如常", + "如此等等", + "如次", + "如今", + "如期", + "如前所述", + "如上", + "如下", + "汝", + "三番两次", + "三番五次", + "三天两头", + "瑟瑟", + "沙沙", + "上", + "上来", + "上去", + "一个", + "月", + "日", + "\n", } diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index b77c6562b2..27e4f383ad 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -8,7 +8,6 @@ from models.dataset import Dataset class BaseKeyword(ABC): - def __init__(self, dataset: Dataset): self.dataset = dataset @@ -31,15 +30,12 @@ class BaseKeyword(ABC): def delete(self) -> None: raise NotImplementedError - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -47,4 +43,4 @@ class BaseKeyword(ABC): return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 6ac610f82b..3c99f33be6 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ 
b/api/core/rag/datasource/keyword/keyword_factory.py @@ -20,9 +20,7 @@ class Keyword: raise ValueError("Keyword store must be specified.") if keyword_type == "jieba": - return Jieba( - dataset=self._dataset - ) + return Jieba(dataset=self._dataset) else: raise ValueError(f"Keyword store {keyword_type} is not supported.") @@ -41,10 +39,7 @@ class Keyword: def delete(self) -> None: self._keyword_processor.delete() - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: return self._keyword_processor.search(query, **kwargs) def __getattr__(self, name): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 3932e90042..afac1bf300 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -7,78 +7,88 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.rerank.constants.rerank_mode import RerankMode -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } class RetrievalService: - @classmethod - def retrieve(cls, retrival_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def retrieve( + cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = 0.0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = "reranking_model", + weights: Optional[dict] = None, + ): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword - if retrival_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if retrieval_method == "keyword_search": + keyword_thread = threading.Thread( + target=RetrievalService.keyword_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(keyword_thread) keyword_thread.start() # retrieval_model source with semantic - if RetrievalMethod.is_support_semantic_search(retrival_method): - 
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrival_method': retrival_method, - 'exceptions': exceptions, - }) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + embedding_thread = threading.Thread( + target=RetrievalService.embedding_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "score_threshold": score_threshold, + "reranking_model": reranking_model, + "all_documents": all_documents, + "retrieval_method": retrieval_method, + "exceptions": exceptions, + }, + ) threads.append(embedding_thread) embedding_thread.start() # retrieval source with full text - if RetrievalMethod.is_support_fulltext_search(retrival_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrival_method': retrival_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if RetrievalMethod.is_support_fulltext_search(retrieval_method): + full_text_index_thread = threading.Thread( + target=RetrievalService.full_text_index_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "retrieval_method": retrieval_method, + "score_threshold": score_threshold, + "top_k": top_k, + "reranking_model": reranking_model, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -86,110 +96,117 @@ class RetrievalService: thread.join() if exceptions: - exception_message = ';\n'.join(exceptions) + exception_message = ";\n".join(exceptions) raise Exception(exception_message) - if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) + if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) return all_documents @classmethod - def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list, exceptions: list): + def keyword_search( + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword(dataset=dataset) - documents = keyword.search( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod 
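# embedding_search below follows the same fan-out pattern as keyword_search:
# it runs under flask_app.app_context() on a worker thread, appends its hits
# to the shared all_documents list, and records failures in the exceptions
# list instead of raising, so retrieve() can join every thread before turning
# any collected errors into a single combined exception.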
- def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str, exceptions: list): + def embedding_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) documents = vector.search_by_vector( cls.escape_query_for_search(query), - search_type='similarity_score_threshold', + search_type="similarity_score_threshold", top_k=top_k, score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + filter={"group_id": [dataset.id]}, ) if documents: - if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str, exceptions: list): + def full_text_index_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() vector_processor = Vector( dataset=dataset, ) - documents = vector_processor.search_by_full_text( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: - if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + ): + data_post_processor = DataPostProcessor( + 
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: @@ -197,4 +214,4 @@ class RetrievalService: @staticmethod def escape_query_for_search(query: str) -> str: - return query.replace('"', '\\"') \ No newline at end of file + return query.replace('"', '\\"') diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index b78e2a59b1..a9c0eefb78 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel): namespace_password: str = (None,) metrics: str = ("cosine",) read_timeout: int = 60000 + def to_analyticdb_client_params(self): return { "access_key_id": self.access_key_id, @@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel): "read_timeout": self.read_timeout, } + class AnalyticdbVector(BaseVector): _instance = None _init = False @@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector): except: raise ImportError(_import_err_msg) self.config = config - self._client_config = open_api_models.Config( - user_agent="dify", **config.to_analyticdb_client_params() - ) + self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) self._client = Client(self._client_config) self._initialize() AnalyticdbVector._init = True @@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector): def _initialize_vector_database(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector): def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + try: request = gpdb_20160503_models.DescribeNamespaceRequest( dbinstance_id=self.config.instance_id, @@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector): ) self._client.create_namespace(request) else: - raise ValueError( - f"failed to create namespace {self.config.namespace}: {e}" - ) + raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector): ) self._client.create_collection(request) else: - raise ValueError( - f"failed to create collection {self._collection_name}: {e}" - ) + raise ValueError(f"failed to create collection {self._collection_name}: {e}") redis_client.set(collection_exist_cache_key, 1, ex=3600) def get_type(self) -> str: @@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector): self._create_collection_if_not_exists(dimension) self.add_texts(texts, embeddings) - def add_texts( - self, documents: list[Document], embeddings: list[list[float]], **kwargs - ): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): from alibabacloud_gpdb20160503 import models 
as gpdb_20160503_models + rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): metadata = { @@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector): def text_exists(self, id: str) -> bool: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector): vector=None, content=None, top_k=1, - filter=f"ref_doc_id='{id}'" + filter=f"ref_doc_id='{id}'", ) response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + ids_str = ",".join(f"'{id}'" for id in ids) ids_str = f"({ids_str})" request = gpdb_20160503_models.DeleteCollectionDataRequest( @@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector): ) self._client.delete_collection_data(request) - def search_by_vector( - self, query_vector: list[float], **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector): def delete(self) -> None: try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionRequest( collection=self._collection_name, dbinstance_id=self.config.instance_id, @@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector): except Exception as e: raise e + class AnalyticdbVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"][ - "class_prefix" - ] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) - ) + dataset.index_struct = 
json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) # handle optional params if dify_config.ANALYTICDB_KEY_ID is None: diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 3629887b44..cb38cf94a9 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -27,21 +27,20 @@ class ChromaConfig(BaseModel): settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, - chroma_client_auth_credentials=self.auth_credentials + chroma_client_auth_credentials=self.auth_credentials, ) return { - 'host': self.host, - 'port': self.port, - 'ssl': False, - 'tenant': self.tenant, - 'database': self.database, - 'settings': settings, + "host": self.host, + "port": self.port, + "ssl": False, + "tenant": self.tenant, + "database": self.database, + "settings": settings, } class ChromaVector(BaseVector): - def __init__(self, collection_name: str, config: ChromaConfig): super().__init__(collection_name) self._client_config = config @@ -58,9 +57,9 @@ class ChromaVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return self._client.get_or_create_collection(collection_name) @@ -76,7 +75,7 @@ class ChromaVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {'$eq': value}}) + collection.delete(where={key: {"$eq": value}}) def delete(self): self._client.delete_collection(self._collection_name) @@ -93,26 +92,26 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 - ids: list[str] = results['ids'][0] - documents: list[str] = results['documents'][0] - metadatas: dict[str, Any] = results['metadatas'][0] - distances: list[float] = results['distances'][0] + ids: list[str] = results["ids"][0] + documents: list[str] = results["documents"][0] + metadatas: dict[str, Any] = results["metadatas"][0] + distances: list[float] = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] metadata = metadatas[index] if distance >= score_threshold: - metadata['score'] = distance + metadata["score"] = distance doc = Document( page_content=documents[index], metadata=metadata, ) docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> 
list[Document]: @@ -123,15 +122,12 @@ class ChromaVector(BaseVector): class ChromaVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = { - "type": VectorType.CHROMA, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) return ChromaVector( diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 233539756f..f13723b51f 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -26,15 +26,16 @@ class ElasticSearchConfig(BaseModel): username: str password: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PORT is required") - if not values['username']: + if not values["username"]: raise ValueError("config USERNAME is required") - if not values['password']: + if not values["password"]: raise ValueError("config PASSWORD is required") return values @@ -50,10 +51,10 @@ class ElasticSearchVector(BaseVector): def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ['http', 'https']: - hosts = f'{config.host}:{config.port}' + if parsed_url.scheme in ["http", "https"]: + hosts = f"{config.host}:{config.port}" else: - hosts = f'http://{config.host}:{config.port}' + hosts = f"http://{config.host}:{config.port}" client = Elasticsearch( hosts=hosts, basic_auth=(config.username, config.password), @@ -68,25 +69,27 @@ class ElasticSearchVector(BaseVector): def _get_version(self) -> str: info = self._client.info() - return info['version']['number'] + return info["version"]["number"] def _check_version(self): - if self._version < '8.0.0': + if self._version < "8.0.0": raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: - return 'elasticsearch' + return "elasticsearch" def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) for i in range(len(documents)): - self._client.index(index=self._collection_name, - id=uuids[i], - document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] if embeddings[i] else None, - Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {} - }) + self._client.index( + index=self._collection_name, + id=uuids[i], + document={ + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] if embeddings[i] else None, + Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}, + }, + ) self._client.indices.refresh(index=self._collection_name) return uuids 
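For reference, the ElasticSearch hunks above and below keep the same Elasticsearch 8.x kNN flow while normalizing quoting and call layout. A minimal standalone sketch of that flow follows; the host, credentials, index name, and field names here are illustrative assumptions, not values taken from this patch:

from elasticsearch import Elasticsearch

# Illustrative connection values; in ElasticSearchVector they come from ElasticSearchConfig.
client = Elasticsearch(hosts="http://localhost:9200", basic_auth=("elastic", "changeme"))

# Same shape as search_by_vector: a top-level knn clause with field, query_vector, and k.
knn = {"field": "vector", "query_vector": [0.1, 0.2, 0.3], "k": 4}
results = client.search(index="example_collection", knn=knn, size=4)

for hit in results["hits"]["hits"]:
    # Hits whose _score clears the caller's score_threshold are kept, as in the patch.
    if hit["_score"] > 0.0:
        print(hit["_score"], hit["_source"])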
@@ -98,15 +101,9 @@ class ElasticSearchVector(BaseVector): self._client.delete(index=self._collection_name, id=id) def delete_by_metadata_field(self, key: str, value: str) -> None: - query_str = { - 'query': { - 'match': { - f'metadata.{key}': f'{value}' - } - } - } + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit['_id'] for hit in results['hits']['hits']] + ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) @@ -115,44 +112,44 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 10) - knn = { - "field": Field.VECTOR.value, - "query_vector": query_vector, - "k": top_k - } + knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k} results = self._client.search(index=self._collection_name, knn=knn, size=top_k) docs_and_scores = [] - for hit in results['hits']['hits']: + for hit in results["hits"]["hits"]: docs_and_scores.append( - (Document(page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score'])) + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = { - "match": { - Field.CONTENT_KEY.value: query - } - } + query_str = {"match": {Field.CONTENT_KEY.value: query}} results = self._client.search(index=self._collection_name, query=query_str) docs = [] - for hit in results['hits']['hits']: - docs.append(Document( - page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value], - )) + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) return docs @@ -162,11 +159,11 @@ class ElasticSearchVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name}' + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name} already exists.") return @@ -179,14 +176,14 @@ class ElasticSearchVector(BaseVector): Field.VECTOR.value: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, - 
"similarity": "cosine" + "similarity": "cosine", }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } } self._client.indices.create(index=self._collection_name, mappings=mappings) @@ -197,22 +194,21 @@ class ElasticSearchVector(BaseVector): class ElasticSearchVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get('ELASTICSEARCH_HOST'), - port=config.get('ELASTICSEARCH_PORT'), - username=config.get('ELASTICSEARCH_USERNAME'), - password=config.get('ELASTICSEARCH_PASSWORD'), + host=config.get("ELASTICSEARCH_HOST"), + port=config.get("ELASTICSEARCH_PORT"), + username=config.get("ELASTICSEARCH_USERNAME"), + password=config.get("ELASTICSEARCH_PASSWORD"), ), - attributes=[] + attributes=[], ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index cfc533ed33..d6d7136282 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,10 +1,9 @@ import json import logging from typing import Any, Optional -from uuid import uuid4 from pydantic import BaseModel, model_validator -from pymilvus import MilvusClient, MilvusException, connections +from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config @@ -21,55 +20,47 @@ logger = logging.getLogger(__name__) class MilvusConfig(BaseModel): - host: str - port: int + uri: str + token: Optional[str] = None user: str password: str - secure: bool = False batch_size: int = 100 database: str = "default" - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values.get('host'): - raise ValueError("config MILVUS_HOST is required") - if not values.get('port'): - raise ValueError("config MILVUS_PORT is required") - if not values.get('user'): + if not values.get("uri"): + raise ValueError("config MILVUS_URI is required") + if not values.get("user"): raise ValueError("config MILVUS_USER is required") - if not values.get('password'): + if not values.get("password"): raise ValueError("config MILVUS_PASSWORD is required") return values def to_milvus_params(self): return { - 'host': self.host, - 'port': self.port, - 'user': self.user, - 'password': self.password, - 'secure': self.secure, - 'db_name': self.database, + "uri": self.uri, + "token": self.token, + "user": self.user, + "password": self.password, + "db_name": self.database, } class MilvusVector(BaseVector): - def __init__(self, collection_name: str, config: MilvusConfig): super().__init__(collection_name) self._client_config = config self._client = self._init_client(config) - self._consistency_level = 
'Session' + self._consistency_level = "Session" self._fields = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } + index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} metadatas = [d.metadata for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) @@ -80,7 +71,7 @@ class MilvusVector(BaseVector): insert_dict = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata + Field.METADATA_KEY.value: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -89,111 +80,70 @@ class MilvusVector(BaseVector): pks: list[str] = [] for i in range(0, total_count, 1000): - batch_insert_list = insert_dict_list[i:i + 1000] + batch_insert_list = insert_dict_list[i : i + 1000] # Insert into the collection. try: ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) + logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) raise e return pks def get_ids_by_metadata_field(self, key: str, value: str): - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["{key}"] == "{value}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] + ) if result: return [item["id"] for item in result] else: return None def delete_by_metadata_field(self, key: str, value: str): - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if utility.has_collection(self._collection_name, using=alias): - + if self._client.has_collection(self._collection_name): ids = self.get_ids_by_metadata_field(key, value) if ids: self._client.delete(collection_name=self._collection_name, pks=ids) def delete_by_ids(self, ids: list[str]) -> None: - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if utility.has_collection(self._collection_name, using=alias): - - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] in {ids}', - output_fields=["id"]) + if self._client.has_collection(self._collection_name): + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] + ) if result: ids = [item["id"] for item in result] self._client.delete(collection_name=self._collection_name, 
pks=ids) def delete(self) -> None: - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if utility.has_collection(self._collection_name, using=alias): - utility.drop_collection(self._collection_name, None, using=alias) + if self._client.has_collection(self._collection_name): + self._client.drop_collection(self._collection_name, None) def text_exists(self, id: str) -> bool: - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if not utility.has_collection(self._collection_name, using=alias): + if not self._client.has_collection(self._collection_name): return False - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] == "{id}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"] + ) return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Set search parameters. - results = self._client.search(collection_name=self._collection_name, - data=[query_vector], - limit=kwargs.get('top_k', 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], - ) + results = self._client.search( + collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get("top_k", 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) # Organize results. 
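# A note on the result shape unpacked below (assuming pymilvus MilvusClient
# semantics): search() returns one hit list per query vector, so results[0]
# holds the hits for the single query above; each hit is a dict carrying
# "id", "distance", and an "entity" dict with the requested output_fields.
# Because the collection is indexed with the IP metric, a larger distance
# means a closer match, which is why hits are kept only when distance
# exceeds score_threshold.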
docs = [] for result in results[0]: - metadata = result['entity'].get(Field.METADATA_KEY.value) - metadata['score'] = result['distance'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if result['distance'] > score_threshold: - doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), - metadata=metadata) + metadata = result["entity"].get(Field.METADATA_KEY.value) + metadata["score"] = result["distance"] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + if result["distance"] > score_threshold: + doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -202,23 +152,15 @@ class MilvusVector(BaseVector): return [] def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return # Grab the existing collection if it exists - from pymilvus import utility - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, - password=self._client_config.password, db_name=self._client_config.database) - if not utility.has_collection(self._collection_name, using=alias): + if not self._client.has_collection(self._collection_name): from pymilvus import CollectionSchema, DataType, FieldSchema from pymilvus.orm.types import infer_dtype_bydata @@ -229,19 +171,11 @@ class MilvusVector(BaseVector): fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) # Create the text field - fields.append( - FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # Create the primary key field - fields.append( - FieldSchema( - Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True - ) - ) + fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create the schema for the collection schema = CollectionSchema(fields) @@ -257,39 +191,36 @@ class MilvusVector(BaseVector): # Create the collection collection_name = self._collection_name - self._client.create_collection(collection_name=collection_name, - schema=schema, index_params=index_params_obj, - consistency_level=self._consistency_level) + self._client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params_obj, + consistency_level=self._consistency_level, + ) redis_client.set(collection_exist_cache_key, 
1, ex=3600) def _init_client(self, config) -> MilvusClient: - if config.secure: - uri = "https://" + str(config.host) + ":" + str(config.port) - else: - uri = "http://" + str(config.host) + ":" + str(config.port) - client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database) + client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) return client class MilvusVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) return MilvusVector( collection_name=collection_name, config=MilvusConfig( - host=dify_config.MILVUS_HOST, - port=dify_config.MILVUS_PORT, + uri=dify_config.MILVUS_URI, + token=dify_config.MILVUS_TOKEN, user=dify_config.MILVUS_USER, password=dify_config.MILVUS_PASSWORD, - secure=dify_config.MILVUS_SECURE, database=dify_config.MILVUS_DATABASE, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 05e75effef..90464ac42a 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -31,7 +31,6 @@ class SortOrder(Enum): class MyScaleVector(BaseVector): - def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"): super().__init__(collection_name) self._config = config @@ -80,7 +79,7 @@ class MyScaleVector(BaseVector): doc_id, self.escape_str(doc.page_content), embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {} + json.dumps(doc.metadata) if doc.metadata else {}, ) values.append(str(row)) ids.append(doc_id) @@ -101,7 +100,8 @@ class MyScaleVector(BaseVector): def delete_by_ids(self, ids: list[str]) -> None: self._client.command( - f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}") + f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" + ) def get_ids_by_metadata_field(self, key: str, value: str): rows = self._client.query( @@ -122,9 +122,12 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get('score_threshold') or 0.0 - where_str = f"WHERE dist < {1 - score_threshold}" if \ - self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" + score_threshold = kwargs.get("score_threshold") or 0.0 + where_str = ( + f"WHERE dist < {1 - score_threshold}" + if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 + else "" + ) sql = f""" SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} @@ -133,7 +136,7 @@ class MyScaleVector(BaseVector): return [ Document( page_content=r["text"], - vector=r['vector'], + vector=r["vector"], metadata=r["metadata"], ) for r 
in self._client.query(sql).named_results() @@ -149,13 +152,12 @@ class MyScaleVector(BaseVector): class MyScaleVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) return MyScaleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index c95d202173..7c0f620956 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -28,11 +28,12 @@ class OpenSearchConfig(BaseModel): password: Optional[str] = None secure: bool = False - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values.get('host'): + if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") - if not values.get('port'): + if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") return values @@ -44,19 +45,18 @@ class OpenSearchConfig(BaseModel): def to_opensearch_params(self) -> dict[str, Any]: params = { - 'hosts': [{'host': self.host, 'port': self.port}], - 'use_ssl': self.secure, - 'verify_certs': self.secure, + "hosts": [{"host": self.host, "port": self.port}], + "use_ssl": self.secure, + "verify_certs": self.secure, } if self.user and self.password: - params['http_auth'] = (self.user, self.password) + params["http_auth"] = (self.user, self.password) if self.secure: - params['ssl_context'] = self.create_ssl_context() + params["ssl_context"] = self.create_ssl_context() return params class OpenSearchVector(BaseVector): - def __init__(self, collection_name: str, config: OpenSearchConfig): super().__init__(collection_name) self._client_config = config @@ -81,7 +81,7 @@ class OpenSearchVector(BaseVector): Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, - } + }, } actions.append(action) @@ -90,8 +90,8 @@ class OpenSearchVector(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) - if response['hits']['hits']: - return [hit['_id'] for hit in response['hits']['hits']] + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] else: return None @@ -110,7 +110,7 @@ class OpenSearchVector(BaseVector): actual_ids = [] for doc_id in ids: - es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) + es_ids = self.get_ids_by_metadata_field("doc_id", doc_id) if es_ids: actual_ids.extend(es_ids) else: @@ -122,9 +122,9 @@ class OpenSearchVector(BaseVector): helpers.bulk(self._client, actions) except BulkIndexError as e: for error in e.errors: - delete_error = error.get('delete', 
{}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') + delete_error = error.get("delete", {}) + status = delete_error.get("status") + doc_id = delete_error.get("_id") if status == 404: logger.warning(f"Document not found for deletion: {doc_id}") @@ -151,15 +151,8 @@ class OpenSearchVector(BaseVector): raise ValueError("All elements in query_vector should be floats") query = { - "size": kwargs.get('top_k', 4), - "query": { - "knn": { - Field.VECTOR.value: { - Field.VECTOR.value: query_vector, - "k": kwargs.get('top_k', 4) - } - } - } + "size": kwargs.get("top_k", 4), + "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, } try: @@ -169,17 +162,17 @@ class OpenSearchVector(BaseVector): raise docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value, {}) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) # Make sure metadata is a dictionary if metadata is None: metadata = {} - metadata['score'] = hit['_score'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if hit['_score'] > score_threshold: - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + metadata["score"] = hit["_score"] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + if hit["_score"] > score_threshold: + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -190,32 +183,28 @@ class OpenSearchVector(BaseVector): response = self._client.search(index=self._collection_name.lower(), body=full_text_query) docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value) - vector = hit['_source'].get(Field.VECTOR.value) - page_content = hit['_source'].get(Field.CONTENT_KEY.value) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value) + vector = hit["_source"].get(Field.VECTOR.value) + page_content = hit["_source"].get(Field.CONTENT_KEY.value) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name.lower()}' + lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name.lower()} already exists.") return if not self._client.indices.exists(index=self._collection_name.lower()): index_body = { - "settings": { - "index": { - "knn": True - } - }, + "settings": {"index": {"knn": True}}, "mappings": { "properties": { Field.CONTENT_KEY.value: {"type": "text"}, @@ -226,20 +215,17 @@ class OpenSearchVector(BaseVector): "name": "hnsw", "space_type": "l2", "engine": "faiss", - "parameters": { - "ef_construction": 64, - "m": 8 - } - } + "parameters": {"ef_construction": 64, "m": 8}, + }, }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": 
{"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } - } + }, } self._client.indices.create(index=self._collection_name.lower(), body=index_body) @@ -248,17 +234,14 @@ class OpenSearchVector(BaseVector): class OpenSearchVectorFactory(AbstractVectorFactory): - def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST, @@ -268,7 +251,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory): secure=dify_config.OPENSEARCH_SECURE, ) - return OpenSearchVector( - collection_name=collection_name, - config=open_search_config - ) + return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index aa2c6171c3..06c20ceb5f 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -31,7 +31,8 @@ class OracleVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config ORACLE_HOST is required") @@ -103,9 +104,16 @@ class OracleVector(BaseVector): arraysize=cursor.arraysize, outconverter=self.numpy_converter_out, ) - def _create_connection_pool(self, config: OracleVectorConfig): - return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1) + def _create_connection_pool(self, config: OracleVectorConfig): + return oracledb.create_pool( + user=config.user, + password=config.password, + dsn="{}:{}/{}".format(config.host, config.port, config.database), + min=1, + max=50, + increment=1, + ) @contextmanager def _get_cursor(self): @@ -136,13 +144,15 @@ class OracleVector(BaseVector): doc_id, doc.page_content, json.dumps(doc.metadata), - #array.array("f", embeddings[i]), + # array.array("f", embeddings[i]), numpy.array(embeddings[i]), ) ) - #print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") + # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: - cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values) + cur.executemany( + f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values + ) return pks def text_exists(self, id: str) -> bool: @@ -157,7 +167,8 @@ class OracleVector(BaseVector): for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) return docs - #def get_ids_by_metadata_field(self, key: str, value: str): + + # def get_ids_by_metadata_field(self, key: str, value: str): # with self._get_cursor() as cur: # cur.execute(f"SELECT id FROM 
{self.table_name} d WHERE d.meta.{key}='{value}'" ) # idss = [] @@ -184,7 +195,8 @@ class OracleVector(BaseVector): top_k = kwargs.get("top_k", 5) with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)] + f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only", + [numpy.array(query_vector)], ) docs = [] score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 @@ -202,7 +214,7 @@ class OracleVector(BaseVector): score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if len(query) > 0: # Check which language the query is in - zh_pattern = re.compile('[\u4e00-\u9fa5]+') + zh_pattern = re.compile("[\u4e00-\u9fa5]+") match = zh_pattern.search(query) entities = [] # match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split @@ -210,7 +222,15 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名 + if ( + pos == "nr" + or pos == "Ng" + or pos == "eng" + or pos == "nz" + or pos == "n" + or pos == "ORG" + or pos == "v" + ): # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: @@ -220,22 +240,22 @@ class OracleVector(BaseVector): entities.append(current_entity) else: try: - nltk.data.find('tokenizers/punkt') - nltk.data.find('corpora/stopwords') + nltk.data.find("tokenizers/punkt") + nltk.data.find("corpora/stopwords") except LookupError: - nltk.download('punkt') - nltk.download('stopwords') + nltk.download("punkt") + nltk.download("stopwords") print("run download") - e_str = re.sub(r'[^\w ]', '', query) + e_str = re.sub(r"[^\w ]", "", query) all_tokens = nltk.word_tokenize(e_str) - stop_words = stopwords.words('english') + stop_words = stopwords.words("english") for token in all_tokens: if token not in stop_words: entities.append(token) with self._get_cursor() as cur: cur.execute( f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", - [" ACCUM ".join(entities)] + [" ACCUM ".join(entities)], ) docs = [] for record in cur: @@ -273,8 +293,7 @@ class OracleVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) return OracleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a48224070f..24b391d63a 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -31,27 +31,29 @@ class PgvectoRSConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") - if not values['port']: + if not values["port"]: raise 
ValueError("config PGVECTO_RS_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config PGVECTO_RS_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config PGVECTO_RS_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config PGVECTO_RS_DATABASE is required") return values class PGVectoRS(BaseVector): - def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): super().__init__(collection_name) self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self._client = create_engine(self._url) with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) @@ -80,9 +82,9 @@ class PGVectoRS(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -133,9 +135,7 @@ class PGVectoRS(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: - select_statement = sql_text( - f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; " - ) + select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ") result = session.execute(select_statement).fetchall() if result: return [item[0] for item in result] @@ -143,12 +143,11 @@ class PGVectoRS(BaseVector): return None def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete_by_ids(self, ids: list[str]) -> None: @@ -156,13 +155,13 @@ class PGVectoRS(BaseVector): select_statement = sql_text( f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " ) - result = session.execute(select_statement, {'doc_ids': ids}).fetchall() + result = session.execute(select_statement, {"doc_ids": ids}).fetchall() if result: ids = [item[0] for item in result] if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete(self) -> None: @@ -187,7 +186,7 @@ class PGVectoRS(BaseVector): query_vector, ).label("distance"), ) - .limit(kwargs.get('top_k', 2)) + .limit(kwargs.get("top_k", 2)) .order_by("distance") ) res = session.execute(stmt) @@ -198,11 +197,10 @@ class PGVectoRS(BaseVector): for record, dis in results: metadata = record.meta score = 1 - dis - metadata['score'] = score - score_threshold = kwargs.get('score_threshold') if 
kwargs.get('score_threshold') else 0.0 + metadata["score"] = score + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if score > score_threshold: - doc = Document(page_content=record.text, - metadata=metadata) + doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) return docs @@ -225,13 +223,12 @@ class PGVectoRS(BaseVector): class PGVectoRSFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( @@ -243,5 +240,5 @@ class PGVectoRSFactory(AbstractVectorFactory): password=dify_config.PGVECTO_RS_PASSWORD, database=dify_config.PGVECTO_RS_DATABASE, ), - dim=dim + dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index c9f2f35af0..38dfd24b56 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -24,7 +24,8 @@ class PGVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") @@ -201,8 +202,7 @@ class PGVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) return PGVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 297bff928e..83d561819c 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -48,28 +48,25 @@ class QdrantConfig(BaseModel): prefer_grpc: bool = False def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = self.endpoint.replace('path:', '') + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") if not os.path.isabs(path): path = os.path.join(self.root_path, path) - return { - 'path': path - } + return {"path": path} else: return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout, - 'verify': self.endpoint.startswith('https'), - 'grpc_port': self.grpc_port, - 'prefer_grpc': self.prefer_grpc + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": self.endpoint.startswith("https"), + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, } class QdrantVector(BaseVector): - - def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + def __init__(self,
collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) @@ -80,10 +77,7 @@ class QdrantVector(BaseVector): return VectorType.QDRANT def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -97,9 +91,9 @@ class QdrantVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex @@ -110,12 +104,19 @@ class QdrantVector(BaseVector): all_collection_name.append(collection.name) if collection_name not in all_collection_name: from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[self._distance_func], ) - hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) self._client.recreate_collection( collection_name=collection_name, vectors_config=vectors_config, @@ -124,21 +125,24 @@ class QdrantVector(BaseVector): ) # create group_id payload index - self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) # create doc_id payload index - self._client.create_payload_index(collection_name, Field.DOC_ID.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) # create full text index text_index_params = TextIndexParams( type=TextIndexType.TEXT, tokenizer=TokenizerType.MULTILINGUAL, min_token_len=2, max_token_len=20, - lowercase=True + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params ) - self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, - field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -147,26 +151,23 @@ class QdrantVector(BaseVector): metadatas = [d.metadata for d in documents] added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, embeddings, metadatas, uuids, 64, self._group_id - ): - self._client.upsert( - collection_name=self._collection_name, points=points - ) + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + 
self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) return added_ids def _generate_rest_batches( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest + texts_iterator = iter(texts) embeddings_iterator = iter(embeddings) metadatas_iterator = iter(metadatas or []) @@ -203,13 +204,13 @@ class QdrantVector(BaseVector): @classmethod def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, ) -> list[dict]: payloads = [] for i, text in enumerate(texts): @@ -219,18 +220,11 @@ class QdrantVector(BaseVector): "calling .from_texts or .add_texts on Qdrant instance." ) metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) return payloads def delete_by_metadata_field(self, key: str, value: str): - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -248,9 +242,7 @@ class QdrantVector(BaseVector): self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -275,9 +267,7 @@ class QdrantVector(BaseVector): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -288,7 +278,6 @@ class QdrantVector(BaseVector): raise e def delete_by_ids(self, ids: list[str]) -> None: - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -304,9 +293,7 @@ class QdrantVector(BaseVector): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -324,15 +311,13 @@ class QdrantVector(BaseVector): all_collection_name.append(collection.name) if self._collection_name not in all_collection_name: return False - response = self._client.retrieve( - collection_name=self._collection_name, - ids=[id] - ) + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) return len(response) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + filter = models.Filter( must=[ models.FieldCondition( @@ -348,22 +333,22 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 
4), with_payload=True, with_vectors=True, - score_threshold=kwargs.get("score_threshold", .0) + score_threshold=kwargs.get("score_threshold", 0.0), ) docs = [] for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 if result.score > score_threshold: - metadata['score'] = result.score + metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -372,6 +357,7 @@ class QdrantVector(BaseVector): List of documents most similar to the query text and distance for each. """ from qdrant_client.http import models + scroll_filter = models.Filter( must=[ models.FieldCondition( @@ -381,24 +367,21 @@ class QdrantVector(BaseVector): models.FieldCondition( key="page_content", match=models.MatchText(text=query), - ) + ), ] ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, - limit=kwargs.get('top_k', 2), + limit=kwargs.get("top_k", 2), with_payload=True, - with_vectors=True - + with_vectors=True, ) results = response[0] documents = [] for result in results: if result: - document = self._document_from_scored_point( - result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value - ) + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) documents.append(document) return documents @@ -410,10 +393,10 @@ class QdrantVector(BaseVector): @classmethod def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), @@ -425,24 +408,25 @@ class QdrantVector(BaseVector): class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). 
\ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Binding does not exist!") else: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) if not dataset.index_struct_dict: - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) config = current_app.config return QdrantVector( @@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory): root_path=config.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED - ) + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ), ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 63ad0682d7..0c9d3b343d 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -33,28 +33,30 @@ class RelytConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config RELYT_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config RELYT_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config RELYT_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config RELYT_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config RELYT_DATABASE is required") return values class RelytVector(BaseVector): - def __init__(self, collection_name: str, config: RelytConfig, group_id: str): super().__init__(collection_name) self.embedding_dimension = 1536 self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self.client = create_engine(self._url) self._fields = [] self._group_id = group_id @@ -70,9 +72,9 @@ class RelytVector(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -110,7 +112,7 @@ class RelytVector(BaseVector): ids = [str(uuid.uuid1()) for _ in documents] metadatas = 
[d.metadata for d in documents] for metadata in metadatas: - metadata['group_id'] = self._group_id + metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] # Define the table schema @@ -127,9 +129,7 @@ class RelytVector(BaseVector): chunks_table_data = [] with self.client.connect() as conn: with conn.begin(): - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): + for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): chunks_table_data.append( { "id": chunk_id, @@ -196,15 +196,13 @@ class RelytVector(BaseVector): return False def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_uuids(ids) def delete_by_ids(self, ids: list[str]) -> None: - with Session(self.client) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ ) @@ -228,38 +226,34 @@ class RelytVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get('top_k')), - embedding=query_vector, - filter=kwargs.get('filter') + k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter") ) # Organize results. docs = [] for document, score in results: - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if 1 - score > score_threshold: docs.append(document) return docs def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: list[float], + k: int = 4, + filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided try: from sqlalchemy.engine import Row except ImportError: - raise ImportError( - "Could not import Row from sqlalchemy.engine. " - "Please 'pip install sqlalchemy>=1.4'." - ) + raise ImportError("Could not import Row from sqlalchemy.engine. 
" "Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: conditions = [ - f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1 + f"metadata->>{key!r} in ({', '.join(map(repr, value))})" + if len(value) > 1 else f"metadata->>{key!r} = {value[0]!r}" for key, value in filter.items() ] @@ -305,13 +299,12 @@ class RelytVector(BaseVector): class RelytVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.RELYT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name)) return RelytVector( collection_name=collection_name, @@ -322,5 +315,5 @@ class RelytVectorFactory(AbstractVectorFactory): password=dify_config.RELYT_PASSWORD, database=dify_config.RELYT_DATABASE, ), - group_id=dataset.id + group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3325a1028e..ada0c5cf46 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -25,16 +25,11 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = 1, - replicas: int = 2, + shard: int = 1 + replicas: int = 2 def to_tencent_params(self): - return { - 'url': self.url, - 'username': self.username, - 'key': self.api_key, - 'timeout': self.timeout - } + return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} class TencentVector(BaseVector): @@ -61,13 +56,10 @@ class TencentVector(BaseVector): return self._client.create_database(database_name=self._client_config.database) def get_type(self) -> str: - return 'tencent' + return "tencent" def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def _has_collection(self) -> bool: collections = self._db.list_collections() @@ -77,9 +69,9 @@ class TencentVector(BaseVector): return False def _create_collection(self, dimension: int) -> None: - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return @@ -101,9 +93,7 @@ class TencentVector(BaseVector): raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams(m=16, efconstruction=200) index = vdb_index.Index( - vdb_index.FilterIndex( - self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY - ), + vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), vdb_index.VectorIndex( self.field_vector, dimension, @@ -111,12 +101,8 @@ class
TencentVector(BaseVector): metric_type, params, ), - vdb_index.FilterIndex( - self.field_text, enum.FieldType.String, enum.IndexType.FILTER - ), - vdb_index.FilterIndex( - self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER - ), + vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), + vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), ) self._db.create_collection( @@ -163,15 +149,14 @@ class TencentVector(BaseVector): self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - - res = self._db.collection(self._collection_name).search(vectors=[query_vector], - params=document.HNSWSearchParams( - ef=kwargs.get("ef", 10)), - retrieve_vector=False, - limit=kwargs.get('top_k', 4), - timeout=self._client_config.timeout, - ) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + res = self._db.collection(self._collection_name).search( + vectors=[query_vector], + params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), + retrieve_vector=False, + limit=kwargs.get("top_k", 4), + timeout=self._client_config.timeout, + ) + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -200,15 +185,13 @@ class TencentVector(BaseVector): class TencentVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) return TencentVector( collection_name=collection_name, @@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory): database=dify_config.TENCENT_VECTOR_DB_DATABASE, shard=dify_config.TENCENT_VECTOR_DB_SHARD, replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index d3685c0991..e1ac9d596c 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -28,47 +28,57 @@ class TiDBVectorConfig(BaseModel): database: str program_name: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config TIDB_VECTOR_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config TIDB_VECTOR_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config TIDB_VECTOR_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config 
TIDB_VECTOR_DATABASE is required") - if not values['program_name']: + if not values["program_name"]: raise ValueError("config APPLICATION_NAME is required") return values class TiDBVector(BaseVector): - def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: from tidb_vector.sqlalchemy import VectorType + return Table( self._collection_name, self._orm_base.metadata, - Column('id', String(36), primary_key=True, nullable=False), - Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("id", String(36), primary_key=True, nullable=False), + Column( + "vector", + VectorType(dim), + nullable=False, + comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})", + ), Column("text", TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), - Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), - extend_existing=True + Column( + "update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), + extend_existing=True, ) - def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"): super().__init__(collection_name) self._client_config = config - self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" - f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}") + self._url = ( + f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" 
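            # Illustration only (all values assumed, not taken from this patch): the two
            # f-string fragments above and below concatenate into a DSN of the form
            #   mysql+pymysql://user:pass@tidb.example.com:4000/dify?ssl_verify_cert=true&ssl_verify_identity=true&program_name=dify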
+ f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}" + ) self._distance_func = distance_func.lower() self._engine = create_engine(self._url) self._orm_base = declarative_base() @@ -83,9 +93,9 @@ class TiDBVector(BaseVector): def _create_collection(self, dimension: int): logger.info("_create_collection, collection_name " + self._collection_name) - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return with Session(self._engine) as session: @@ -116,9 +126,7 @@ class TiDBVector(BaseVector): chunks_table_data = [] with self._engine.connect() as conn: with conn.begin(): - for id, text, meta, embedding in zip( - ids, texts, metas, embeddings - ): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) # Execute the batch insert when the batch size is reached @@ -133,12 +141,12 @@ class TiDBVector(BaseVector): return ids def text_exists(self, id: str) -> bool: - result = self.get_ids_by_metadata_field('doc_id', id) + result = self.get_ids_by_metadata_field("doc_id", id) return bool(result) def delete_by_ids(self, ids: list[str]) -> None: with Session(self._engine) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ ) @@ -180,20 +188,22 @@ class TiDBVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 - filter = kwargs.get('filter') + filter = kwargs.get("filter") distance = 1 - score_threshold query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = "[" + query_vector_str + "]" - logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + logger.debug( + f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" + ) docs = [] - if self._distance_func == 'l2': - tidb_func = 'Vec_l2_distance' - elif self._distance_func == 'cosine': - tidb_func = 'Vec_Cosine_distance' + if self._distance_func == "l2": + tidb_func = "Vec_l2_distance" + elif self._distance_func == "cosine": + tidb_func = "Vec_Cosine_distance" else: - tidb_func = 'Vec_Cosine_distance' + tidb_func = "Vec_Cosine_distance" with Session(self._engine) as session: select_statement = sql_text( @@ -208,7 +218,7 @@ class TiDBVector(BaseVector): results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: metadata = json.loads(meta) - metadata['score'] = 1 - distance + metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -224,15 +234,13 @@ class TiDBVector(BaseVector): class TiDBVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: - if dataset.index_struct_dict: - class_prefix: str = 
dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) return TiDBVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 3f70e8b608..fb80cdec87 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -7,7 +7,6 @@ from core.rag.models.document import Document class BaseVector(ABC): - def __init__(self, collection_name: str): self._collection_name = collection_name @@ -39,18 +38,11 @@ class BaseVector(ABC): raise NotImplementedError @abstractmethod - def search_by_vector( - self, - query_vector: list[float], - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: raise NotImplementedError @abstractmethod - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def delete(self) -> None: @@ -58,7 +50,7 @@ class BaseVector(ABC): def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -66,7 +58,7 @@ class BaseVector(ABC): return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 3e9ca8e1fe..bb24143a41 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC): @staticmethod def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: - index_struct_dict = { - "type": vector_type, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} return index_struct_dict class Vector: def __init__(self, dataset: Dataset, attributes: list = None): if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -39,7 +36,7 @@ class Vector: def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] + vector_type = self._dataset.index_struct_dict["type"] if not vector_type: raise ValueError("Vector store must be specified.") @@ -52,45 +49,59 @@ class Vector: match vector_type: case VectorType.CHROMA: from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory + return ChromaVectorFactory case VectorType.MILVUS: from 
core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory + return MilvusVectorFactory case VectorType.MYSCALE: from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory + return MyScaleVectorFactory case VectorType.PGVECTOR: from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory + return PGVectorFactory case VectorType.PGVECTO_RS: from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory + return PGVectoRSFactory case VectorType.QDRANT: from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory + return QdrantVectorFactory case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory + return RelytVectorFactory case VectorType.ELASTICSEARCH: from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory + return TiDBVectorFactory case VectorType.WEAVIATE: from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory + return WeaviateVectorFactory case VectorType.TENCENT: from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory + return TencentVectorFactory case VectorType.ORACLE: from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory + return OracleVectorFactory case VectorType.OPENSEARCH: from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory + return OpenSearchVectorFactory case VectorType.ANALYTICDB: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory + return AnalyticdbVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -98,21 +109,14 @@ class Vector: def create(self, texts: list = None, **kwargs): if texts: embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) - self._vector_processor.create( - texts=texts, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs) def add_texts(self, documents: list[Document], **kwargs): - if kwargs.get('duplicate_check', False): + if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) + embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) - self._vector_processor.create( - texts=documents, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs) def text_exists(self, id: str) -> bool: return self._vector_processor.text_exists(id) @@ -123,24 +127,18 @@ class Vector: def delete_by_metadata_field(self, key: str, value: str) -> None: self._vector_processor.delete_by_metadata_field(key, value) - def search_by_vector( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) def delete(self) -> None: self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: - 
collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: @@ -150,14 +148,13 @@ class Vector: tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model - + model=self._dataset.embedding_model, ) return CacheEmbedding(embedding_model) def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 317ca6abc8..ba04ea879d 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -2,17 +2,17 @@ from enum import Enum class VectorType(str, Enum): - ANALYTICDB = 'analyticdb' - CHROMA = 'chroma' - MILVUS = 'milvus' - MYSCALE = 'myscale' - PGVECTOR = 'pgvector' - PGVECTO_RS = 'pgvecto-rs' - QDRANT = 'qdrant' - RELYT = 'relyt' - TIDB_VECTOR = 'tidb_vector' - WEAVIATE = 'weaviate' - OPENSEARCH = 'opensearch' - TENCENT = 'tencent' - ORACLE = 'oracle' - ELASTICSEARCH = 'elasticsearch' + ANALYTICDB = "analyticdb" + CHROMA = "chroma" + MILVUS = "milvus" + MYSCALE = "myscale" + PGVECTOR = "pgvector" + PGVECTO_RS = "pgvecto-rs" + QDRANT = "qdrant" + RELYT = "relyt" + TIDB_VECTOR = "tidb_vector" + WEAVIATE = "weaviate" + OPENSEARCH = "opensearch" + TENCENT = "tencent" + ORACLE = "oracle" + ELASTICSEARCH = "elasticsearch" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 205fe850c3..ca1123c6a0 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -22,15 +22,15 @@ class WeaviateConfig(BaseModel): api_key: Optional[str] = None batch_size: int = 100 - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: + if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") return values class WeaviateVector(BaseVector): - def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): super().__init__(collection_name) self._client = self._init_client(config) @@ -43,10 +43,7 @@ class WeaviateVector(BaseVector): try: client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None + url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) except requests.exceptions.ConnectionError: raise ConnectionError("Vector database connection error") @@ -68,10 +65,10 @@ class WeaviateVector(BaseVector): def get_collection_name(self, dataset: Dataset) -> str: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + if not class_prefix.endswith("_Node"): # original class_prefix - class_prefix += '_Node' + class_prefix += "_Node" return class_prefix @@ -79,10 +76,7 @@ class 
WeaviateVector(BaseVector): return Dataset.gen_collection_name_by_id(dataset_id) def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): # create collection @@ -91,9 +85,9 @@ class WeaviateVector(BaseVector): self.add_texts(texts, embeddings) def _create_collection(self): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return schema = self._default_schema(self._collection_name) @@ -129,17 +123,9 @@ class WeaviateVector(BaseVector): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): - where_filter = { - "operator": "Equal", - "path": [key], - "valueText": value - } + where_filter = {"operator": "Equal", "path": [key], "valueText": value} - self._client.batch.delete_objects( - class_name=self._collection_name, - where=where_filter, - output='minimal' - ) + self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal") def delete(self): # check whether the index already exists @@ -154,11 +140,19 @@ class WeaviateVector(BaseVector): # check whether the index already exists if not self._client.schema.contains(schema): return False - result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": id, - }).with_limit(1).do() + result = ( + self._client.query.get(collection_name) + .with_additional(["id"]) + .with_where( + { + "path": ["doc_id"], + "operator": "Equal", + "valueText": id, + } + ) + .with_limit(1) + .do() + ) if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -211,13 +205,13 @@ class WeaviateVector(BaseVector): docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 # check score threshold if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -240,15 +234,15 @@ class WeaviateVector(BaseVector): if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_additional(["vector"]) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + properties = ["text"] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: 
text = res.pop(Field.TEXT_KEY.value) - additional = res.pop('_additional') - docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) + additional = res.pop("_additional") + docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: @@ -271,20 +265,19 @@ class WeaviateVector(BaseVector): class WeaviateVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( endpoint=dify_config.WEAVIATE_ENDPOINT, api_key=dify_config.WEAVIATE_API_KEY, - batch_size=dify_config.WEAVIATE_BATCH_SIZE + batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), - attributes=attributes + attributes=attributes, ) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 96a15be742..0d4dff5b89 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -12,10 +12,10 @@ from models.dataset import Dataset, DocumentSegment class DatasetDocumentStore: def __init__( - self, - dataset: Dataset, - user_id: str, - document_id: Optional[str] = None, + self, + dataset: Dataset, + user_id: str, + document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id @@ -41,9 +41,9 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: - document_segments = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id - ).all() + document_segments = ( + db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + ) output = {} for document_segment in document_segments: @@ -55,48 +55,45 @@ class DatasetDocumentStore: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) return output - def add_documents( - self, docs: Sequence[Document], allow_update: bool = True - ) -> None: - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == self._document_id - ).scalar() + def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == self._document_id) + .scalar() + ) if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == 'high_quality': + if self._dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model + model=self._dataset.embedding_model, ) for doc in docs: if not isinstance(doc, Document): raise ValueError("doc 
must be a Document") - segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: raise ValueError( - f"doc_id {doc.metadata['doc_id']} already exists. " - "Set allow_update to True to overwrite." + f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite." ) # calc embedding use tokens if embedding_model: - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[doc.page_content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content]) else: tokens = 0 @@ -107,8 +104,8 @@ class DatasetDocumentStore: tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, - index_node_id=doc.metadata['doc_id'], - index_node_hash=doc.metadata['doc_hash'], + index_node_id=doc.metadata["doc_id"], + index_node_hash=doc.metadata["doc_hash"], position=max_position, content=doc.page_content, word_count=len(doc.page_content), @@ -116,15 +113,15 @@ class DatasetDocumentStore: enabled=False, created_by=self._user_id, ) - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) else: segment_document.content = doc.page_content - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') - segment_document.index_node_hash = doc.metadata['doc_hash'] + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") + segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens @@ -135,9 +132,7 @@ class DatasetDocumentStore: result = self.get_document_segment(doc_id) return result is not None - def get_document( - self, doc_id: str, raise_error: bool = True - ) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -153,7 +148,7 @@ class DatasetDocumentStore: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) def delete_document(self, doc_id: str, raise_error: bool = True) -> None: @@ -188,9 +183,10 @@ class DatasetDocumentStore: return document_segment.index_node_hash def get_document_segment(self, doc_id: str) -> DocumentSegment: - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id, - DocumentSegment.index_node_id == doc_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .first() + ) return document_segment diff --git a/api/core/rag/extractor/blod/blod.py b/api/core/rag/extractor/blob/blob.py similarity index 99% rename from api/core/rag/extractor/blod/blod.py rename to api/core/rag/extractor/blob/blob.py index abfdafcfa2..f4c7b4b5f7 100644 --- a/api/core/rag/extractor/blod/blod.py +++ b/api/core/rag/extractor/blob/blob.py @@ -4,6 +4,7 @@ The goal is to facilitate decoupling of content loading from content parsing cod In addition, content loading code should provide a lazy loading interface by default. 
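A minimal sketch of that lazy-loading idea (the names below are illustrative
only and are not part of this module's API):

    from pathlib import Path

    class LazyBlobSketch:
        def __init__(self, path: str):
            self.path = path  # keep only the reference; no I/O at construction

        def as_string(self) -> str:
            # the payload is read on demand, the first time a consumer asks for it
            return Path(self.path).read_text(encoding="utf-8")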
""" + from __future__ import annotations import contextlib diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 0470569f39..5b67403902 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import csv from typing import Optional @@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, ): """Initialize with file path.""" self._file_path = file_path @@ -57,7 +58,7 @@ class CSVExtractor(BaseExtractor): docs = [] try: # load csv file into pandas dataframe - df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args) + df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args) # check source column exists if self.source_column and self.source_column not in df.columns: @@ -67,7 +68,7 @@ class CSVExtractor(BaseExtractor): for i, row in df.iterrows(): content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns) - source = row[self.source_column] if self.source_column else '' + source = row[self.source_column] if self.source_column else "" metadata = {"source": source, "row": i} doc = Document(page_content=content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 7479b1d97b..3692b5d19d 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,6 +10,7 @@ class NotionInfo(BaseModel): """ Notion import info. """ + notion_workspace_id: str notion_obj_id: str notion_page_type: str @@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel): """ website import info. """ + provider: str job_id: str url: str @@ -43,6 +45,7 @@ class ExtractSetting(BaseModel): """ Model class for provider response. """ + datasource_type: str upload_file: Optional[UploadFile] = None notion_info: Optional[NotionInfo] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index f0c302a619..fc33165719 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import os from typing import Optional @@ -17,59 +18,60 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: - """ Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" + """Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" documents = [] file_extension = os.path.splitext(self._file_path)[-1].lower() - if file_extension == '.xlsx': + if file_extension == ".xlsx": wb = load_workbook(self._file_path, data_only=True) for sheet_name in wb.sheetnames: sheet = wb[sheet_name] data = sheet.values - cols = next(data) + try: + cols = next(data) + except StopIteration: + continue df = pd.DataFrame(data, columns=cols) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for index, row in df.iterrows(): page_content = [] for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): - cell = sheet.cell(row=index + 2, - column=col_index + 1) # +2 to account for header and 1-based index + cell = sheet.cell( + row=index + 2, column=col_index + 1 + ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" page_content.append(f'"{k}":"{value}"') else: page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) - elif file_extension == '.xls': - excel_file = pd.ExcelFile(self._file_path, engine='xlrd') + elif file_extension == ".xls": + excel_file = pd.ExcelFile(self._file_path, engine="xlrd") for sheet_name in excel_file.sheet_names: df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for _, row in df.iterrows(): page_content = [] for k, v in row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) else: raise ValueError(f"Unsupported file extension: {file_extension}") diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index f7a08135f5..a00b3cba53 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -29,61 +29,60 @@ from core.rag.models.document import Document from extensions.ext_storage import storage from models.model import UploadFile -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json'] +SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"] USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" class ExtractProcessor: @classmethod - def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ - -> Union[list[Document], str]: + def load_from_upload_file( + cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False + ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", - 
upload_file=upload_file, - document_model='text_model' + datasource_type="upload_file", upload_file=upload_file, document_model="text_model" ) if return_text: - delimiter = '\n' + delimiter = "\n" return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) else: return cls.extract(extract_setting, is_automatic) @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = ssrf_proxy.get(url, headers={ - "User-Agent": USER_AGENT - }) + response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT}) with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix - if not suffix and suffix != '.': + if not suffix and suffix != ".": # get content-type - if response.headers.get('Content-Type'): - suffix = '.' + response.headers.get('Content-Type').split('/')[-1] + if response.headers.get("Content-Type"): + suffix = "." + response.headers.get("Content-Type").split("/")[-1] else: - content_disposition = response.headers.get('Content-Disposition') + content_disposition = response.headers.get("Content-Disposition") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = '.' + re.search(r'\.(\w+)$', filename).group(1) + suffix = "." + re.search(r"\.(\w+)$", filename).group(1) file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: + with open(file_path, "wb") as file: file.write(response.content) - extract_setting = ExtractSetting( - datasource_type="upload_file", - document_model='text_model' - ) + extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: - delimiter = '\n' - return delimiter.join([document.page_content for document in cls.extract( - extract_setting=extract_setting, file_path=file_path)]) + delimiter = "\n" + return delimiter.join( + [ + document.page_content + for document in cls.extract(extract_setting=extract_setting, file_path=file_path) + ] + ) else: return cls.extract(extract_setting=extract_setting, file_path=file_path) @classmethod - def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, - file_path: str = None) -> list[Document]: + def extract( + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None + ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: @@ -96,50 +95,56 @@ class ExtractProcessor: etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - if etl_type == 'Unstructured': - if file_extension == '.xlsx' or file_extension == '.xls': + if etl_type == "Unstructured": + if file_extension == ".xlsx" or file_extension == ".xls": extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: - extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + elif file_extension in [".md", ".markdown"]: + extractor = ( + UnstructuredMarkdownExtractor(file_path, unstructured_api_url) + if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + ) + elif file_extension in [".htm", ".html"]: extractor = 
HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension in [".docx"]: extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == '.msg': + elif file_extension == ".msg": extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) - elif file_extension == '.eml': + elif file_extension == ".eml": extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) - elif file_extension == '.ppt': + elif file_extension == ".ppt": extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) - elif file_extension == '.pptx': + elif file_extension == ".pptx": extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) - elif file_extension == '.xml': + elif file_extension == ".xml": extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url) else: # txt - extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + extractor = ( + UnstructuredTextExtractor(file_path, unstructured_api_url) + if is_automatic else TextExtractor(file_path, autodetect_encoding=True) + ) else: - if file_extension == '.xlsx' or file_extension == '.xls': + if file_extension == ".xlsx" or file_extension == ".xls": extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: + elif file_extension in [".md", ".markdown"]: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + elif file_extension in [".htm", ".html"]: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension in [".docx"]: extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path) else: # txt @@ -155,13 +160,13 @@ class ExtractProcessor: ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: - if extract_setting.website_info.provider == 'firecrawl': + if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, job_id=extract_setting.website_info.job_id, tenant_id=extract_setting.website_info.tenant_id, mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content + only_main_content=extract_setting.website_info.only_main_content, ) return extractor.extract() else: diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py index c490e59332..582eca94df 100644 --- a/api/core/rag/extractor/extractor_base.py +++ b/api/core/rag/extractor/extractor_base.py @@ -1,12 +1,11 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod class BaseExtractor(ABC): - """Interface for extract files. 
- """ + """Interface for extract files.""" @abstractmethod def extract(self): raise NotImplementedError - diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 2b85ad9739..054ce5f4b2 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -9,108 +9,98 @@ from extensions.ext_storage import storage class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' - if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': - raise ValueError('No API key provided') + self.base_url = base_url or "https://api.firecrawl.dev" + if self.api_key is None and self.base_url == "https://api.firecrawl.dev": + raise ValueError("No API key provided") def scrape_url(self, url, params=None) -> dict: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } - json_data = {'url': url} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + json_data = {"url": url} if params: json_data.update(params) - response = requests.post( - f'{self.base_url}/v0/scrape', - headers=headers, - json=json_data - ) + response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: response = response.json() - if response['success'] == True: - data = response['data'] + if response["success"] == True: + data = response["data"] return { - 'title': data.get('metadata').get('title'), - 'description': data.get('metadata').get('description'), - 'source_url': data.get('metadata').get('sourceURL'), - 'markdown': data.get('markdown') + "title": data.get("metadata").get("title"), + "description": data.get("metadata").get("description"), + "source_url": data.get("metadata").get("sourceURL"), + "markdown": data.get("markdown"), } else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') elif response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') + raise Exception(f"Failed to scrape URL. 
Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: headers = self._prepare_headers() - json_data = {'url': url} + json_data = {"url": url} if params: json_data.update(params) - response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: - job_id = response.json().get('jobId') + job_id = response.json().get("jobId") return job_id else: - self._handle_error(response, 'start crawl job') + self._handle_error(response, "start crawl job") def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() - response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) + response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers) if response.status_code == 200: crawl_status_response = response.json() - if crawl_status_response.get('status') == 'completed': - total = crawl_status_response.get('total', 0) + if crawl_status_response.get("status") == "completed": + total = crawl_status_response.get("total", 0) if total == 0: - raise Exception('Failed to check crawl status. Error: No page found') - data = crawl_status_response.get('data', []) + raise Exception("Failed to check crawl status. Error: No page found") + data = crawl_status_response.get("data", []) url_data_list = [] for item in data: - if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: + if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - 'title': item.get('metadata').get('title'), - 'description': item.get('metadata').get('description'), - 'source_url': item.get('metadata').get('sourceURL'), - 'markdown': item.get('markdown') + "title": item.get("metadata").get("title"), + "description": item.get("metadata").get("description"), + "source_url": item.get("metadata").get("sourceURL"), + "markdown": item.get("markdown"), } url_data_list.append(url_data) if url_data_list: - file_key = 'website_files/' + job_id + '.txt' + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) + storage.save(file_key, json.dumps(url_data_list).encode("utf-8")) return { - 'status': 'completed', - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': url_data_list + "status": "completed", + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": url_data_list, } else: return { - 'status': crawl_status_response.get('status'), - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': [] + "status": crawl_status_response.get("status"), + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": [], } else: - self._handle_error(response, 'check crawl status') + self._handle_error(response, "check crawl status") def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.post(url, headers=headers, json=data) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor 
* (2**attempt)) else: return response return response @@ -119,13 +109,11 @@ class FirecrawlApp: for attempt in range(retries): response = requests.get(url, headers=headers) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response def _handle_error(self, response, action): - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') - - + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py index 8e2f107e5e..b33ce167c2 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py @@ -5,7 +5,7 @@ from services.website_service import WebsiteService class FirecrawlWebExtractor(BaseExtractor): """ - Crawl and scrape websites and return content in clean llm-ready markdown. + Crawl and scrape websites and return content in clean llm-ready markdown. Args: @@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor): mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'. """ - def __init__( - self, - url: str, - job_id: str, - tenant_id: str, - mode: str = 'crawl', - only_main_content: bool = False - ): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id @@ -33,28 +26,31 @@ class FirecrawlWebExtractor(BaseExtractor): def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] - if self.mode == 'crawl': - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) + if self.mode == "crawl": + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id) if crawl_data is None: return [] - document = Document(page_content=crawl_data.get('markdown', ''), - metadata={ - 'source_url': crawl_data.get('source_url'), - 'description': crawl_data.get('description'), - 'title': crawl_data.get('title') - } - ) + document = Document( + page_content=crawl_data.get("markdown", ""), + metadata={ + "source_url": crawl_data.get("source_url"), + "description": crawl_data.get("description"), + "title": crawl_data.get("title"), + }, + ) documents.append(document) - elif self.mode == 'scrape': - scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, - self.only_main_content) + elif self.mode == "scrape": + scrape_data = WebsiteService.get_scrape_url_data( + "firecrawl", self._url, self.tenant_id, self.only_main_content + ) - document = Document(page_content=scrape_data.get('markdown', ''), - metadata={ - 'source_url': scrape_data.get('source_url'), - 'description': scrape_data.get('description'), - 'title': scrape_data.get('title') - } - ) + document = Document( + page_content=scrape_data.get("markdown", ""), + metadata={ + "source_url": scrape_data.get("source_url"), + "description": scrape_data.get("description"), + "title": scrape_data.get("title"), + }, + ) documents.append(document) return documents diff --git 
a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 0c17a47b32..9a21d4272a 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -37,9 +37,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding try: encodings = future.result(timeout=timeout) except concurrent.futures.TimeoutError: - raise TimeoutError( - f"Timeout reached while detecting encoding for {file_path}" - ) + raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") if all(encoding["encoding"] is None for encoding in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index ceb5306255..560c2d1d84 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from bs4 import BeautifulSoup from core.rag.extractor.extractor_base import BaseExtractor @@ -6,7 +7,6 @@ from core.rag.models.document import Document class HtmlExtractor(BaseExtractor): - """ Load html files. @@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str - ): + def __init__(self, file_path: str): """Initialize with file path.""" self._file_path = file_path @@ -27,8 +24,8 @@ class HtmlExtractor(BaseExtractor): def _load_as_text(self) -> str: with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') + soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() - text = text.strip() if text else '' + text = text.strip() if text else "" - return text \ No newline at end of file + return text diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index b24cf2e170..ca125ecf55 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import re from typing import Optional, cast @@ -16,12 +17,12 @@ class MarkdownExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = False, - remove_images: bool = False, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = False, + remove_images: bool = False, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -78,13 +79,10 @@ class MarkdownExtractor(BaseExtractor): if current_header is not None: # pass linting, assert keys are defined markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) - for key, value in markdown_tups + (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] else: - markdown_tups = [ - (key, re.sub("\n", "", value)) for key, value in markdown_tups - ] + markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] return markdown_tups diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 9535455909..b02e30de62 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -21,22 +21,21 @@ RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = 
"https://api.notion.com/v1/databases/{database_id}" # if user want split by headings, use the corresponding splitter HEADING_SPLITTER = { - 'heading_1': '# ', - 'heading_2': '## ', - 'heading_3': '### ', + "heading_1": "# ", + "heading_2": "## ", + "heading_3": "### ", } + class NotionExtractor(BaseExtractor): - def __init__( - self, - notion_workspace_id: str, - notion_obj_id: str, - notion_page_type: str, - tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - + self, + notion_workspace_id: str, + notion_obj_id: str, + notion_page_type: str, + tenant_id: str, + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None, ): self._notion_access_token = None self._document_model = document_model @@ -46,46 +45,38 @@ class NotionExtractor(BaseExtractor): if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(tenant_id, - self._notion_workspace_id) + self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." + "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`." ) self._notion_access_token = integration_token def extract(self) -> list[Document]: - self.update_last_edited_time( - self._document_model - ) + self.update_last_edited_time(self._document_model) text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) return text_docs - def _load_data_as_documents( - self, notion_obj_id: str, notion_page_type: str - ) -> list[Document]: + def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]: docs = [] - if notion_page_type == 'database': + if notion_page_type == "database": # get all the pages in the database page_text_documents = self._get_notion_database_data(notion_obj_id) docs.extend(page_text_documents) - elif notion_page_type == 'page': + elif notion_page_type == "page": page_text_list = self._get_notion_block_data(notion_obj_id) - docs.append(Document(page_content='\n'.join(page_text_list))) + docs.append(Document(page_content="\n".join(page_text_list))) else: raise ValueError("notion page type not supported") return docs - def _get_notion_database_data( - self, database_id: str, query_dict: dict[str, Any] = {} - ) -> list[Document]: + def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -100,50 +91,50 @@ class NotionExtractor(BaseExtractor): data = res.json() database_content = [] - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: return [] for result in data["results"]: - properties = result['properties'] + properties = result["properties"] data = {} for property_name, property_value in properties.items(): - type = property_value['type'] - if type == 'multi_select': + type = property_value["type"] + if type == "multi_select": value = [] multi_select_list = property_value[type] for multi_select in multi_select_list: - value.append(multi_select['name']) - elif type == 'rich_text' or type == 'title': + 
value.append(multi_select["name"]) + elif type == "rich_text" or type == "title": if len(property_value[type]) > 0: - value = property_value[type][0]['plain_text'] + value = property_value[type][0]["plain_text"] else: - value = '' - elif type == 'select' or type == 'status': + value = "" + elif type == "select" or type == "status": if property_value[type]: - value = property_value[type]['name'] + value = property_value[type]["name"] else: - value = '' + value = "" else: value = property_value[type] data[property_name] = value row_dict = {k: v for k, v in data.items() if v} - row_content = '' + row_content = "" for key, value in row_dict.items(): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} - value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items()) - row_content = row_content + f'{key}:{value_content}\n' + value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) + row_content = row_content + f"{key}:{value_content}\n" else: - row_content = row_content + f'{key}:{value}\n' + row_content = row_content + f"{key}:{value}\n" database_content.append(row_content) - return [Document(page_content='\n'.join(database_content))] + return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", block_url, @@ -152,14 +143,14 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) text += "\n\n" @@ -175,17 +166,15 @@ class NotionExtractor(BaseExtractor): result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -199,7 +188,7 @@ class NotionExtractor(BaseExtractor): start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -209,16 +198,16 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: break for result in data["results"]: 
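# Each result is one Notion block: result["type"] names the key holding its
# payload, "table" blocks are expanded via _read_table_rows, and blocks with
# has_children (other than child_page) recurse through _read_block below.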
result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) result_lines_arr.append(text) @@ -233,17 +222,15 @@ class NotionExtractor(BaseExtractor): result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=num_tabs + 1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: - result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -260,7 +247,7 @@ class NotionExtractor(BaseExtractor): start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while not done: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -270,31 +257,36 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() # get table headers text table_header_cell_texts = [] - tabel_header_cells = data["results"][0]['table_row']['cells'] - for tabel_header_cell in tabel_header_cells: - if tabel_header_cell: - for table_header_cell_text in tabel_header_cell: + table_header_cells = data["results"][0]["table_row"]["cells"] + for table_header_cell in table_header_cells: + if table_header_cell: + for table_header_cell_text in table_header_cell: text = table_header_cell_text["text"]["content"] table_header_cell_texts.append(text) - # get table columns text and format + else: + table_header_cell_texts.append("") + # Initialize Markdown table with headers + markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n" + + # Process data to format each row in Markdown table format results = data["results"] for i in range(len(results) - 1): column_texts = [] - tabel_column_cells = data["results"][i + 1]['table_row']['cells'] - for j in range(len(tabel_column_cells)): - if tabel_column_cells[j]: - for table_column_cell_text in tabel_column_cells[j]: + table_column_cells = data["results"][i + 1]["table_row"]["cells"] + for j in range(len(table_column_cells)): + if table_column_cells[j]: + for table_column_cell_text in table_column_cells[j]: column_text = table_column_cell_text["text"]["content"] - column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') - - cur_result_text = "\n".join(column_texts) - result_lines_arr.append(cur_result_text) - + column_texts.append(column_text) + # Add row to Markdown table + markdown_table += "| " + " | ".join(column_texts) + " |\n" + result_lines_arr.append(markdown_table) if data["next_cursor"] is None: done = True break @@ -310,10 +302,8 @@ class NotionExtractor(BaseExtractor): last_edited_time = self.get_notion_last_edited_time() data_source_info = 
document_model.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { - DocumentModel.data_source_info: json.dumps(data_source_info) - } + data_source_info["last_edited_time"] = last_edited_time + update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} DocumentModel.query.filter_by(id=document_model.id).update(update_params) db.session.commit() @@ -321,7 +311,7 @@ class NotionExtractor(BaseExtractor): def get_notion_last_edited_time(self) -> str: obj_id = self._notion_obj_id page_type = self._notion_page_type - if page_type == 'database': + if page_type == "database": retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) @@ -336,7 +326,7 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - json=query_dict + json=query_dict, ) data = res.json() @@ -347,14 +337,16 @@ class NotionExtractor(BaseExtractor): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) ).first() if not data_source_binding: - raise Exception(f'No notion data source binding found for tenant {tenant_id} ' - f'and notion workspace {notion_workspace_id}') + raise Exception( + f"No notion data source binding found for tenant {tenant_id} " + f"and notion workspace {notion_workspace_id}" + ) return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index cbb2655390..57cb9610ba 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,8 +1,9 @@ """Abstract interface for document loader implementations.""" + from collections.abc import Iterator from typing import Optional -from core.rag.extractor.blod.blod import Blob +from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_storage import storage @@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. 
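The extract() method that follows uses a plain cache-aside read for parsed PDF text. Schematically (storage is the ext_storage facade imported in this file; the cache key and parse callback are stand-ins):

```python
from extensions.ext_storage import storage

CACHE_KEY = "plain_text/example.txt"  # hypothetical cache key


def load_text_with_cache(parse_pdf) -> str:
    """Return cached plaintext when present; otherwise parse once and backfill."""
    try:
        return storage.load(CACHE_KEY).decode("utf-8")
    except FileNotFoundError:
        text = parse_pdf()  # stand-in for the real PDF parse
        storage.save(CACHE_KEY, text.encode("utf-8"))
        return text
```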
""" - def __init__( - self, - file_path: str, - file_cache_key: Optional[str] = None - ): + def __init__(self, file_path: str, file_cache_key: Optional[str] = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key def extract(self) -> list[Document]: - plaintext_file_key = '' + plaintext_file_key = "" plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode('utf-8') + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -43,12 +40,12 @@ class PdfExtractor(BaseExtractor): # save plaintext file for caching if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) + storage.save(plaintext_file_key, text.encode("utf-8")) return documents def load( - self, + self, ) -> Iterator[Document]: """Lazy load given path as pages.""" blob = Blob.from_path(self._file_path) diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index ac5d0920cf..ed0ae41f51 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor @@ -14,12 +15,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 0323b14a4a..a525c9e9e3 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -8,13 +8,12 @@ logger = logging.getLogger(__name__) class UnstructuredWordExtractor(BaseExtractor): - """Loader that uses unstructured to load word documents. 
- """ + """Loader that uses unstructured to load word documents.""" def __init__( - self, - file_path: str, - api_url: str, + self, + file_path: str, + api_url: str, ): """Initialize with file path.""" self._file_path = file_path @@ -24,9 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor): from unstructured.__version__ import __version__ as __unstructured_version__ from unstructured.file_utils.filetype import FileType, detect_filetype - unstructured_version = tuple( - int(x) for x in __unstructured_version__.split(".") - ) + unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: import magic # noqa: F401 @@ -53,6 +50,7 @@ class UnstructuredWordExtractor(BaseExtractor): elements = partition_docx(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2e704f187d..34c6811b67 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -26,6 +26,7 @@ class UnstructuredEmailExtractor(BaseExtractor): def extract(self) -> list[Document]: from unstructured.partition.email import partition_email + elements = partition_email(filename=self._file_path) # noinspection PyBroadException @@ -34,15 +35,16 @@ class UnstructuredEmailExtractor(BaseExtractor): element_text = element.text.strip() padding_needed = 4 - len(element_text) % 4 - element_text += '=' * padding_needed + element_text += "=" * padding_needed element_decode = base64.b64decode(element_text) - soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser') + soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") element.text = soup.get_text() except Exception: pass from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 44cf958ea2..fa50fa76b2 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -28,6 +28,7 @@ class UnstructuredEpubExtractor(BaseExtractor): elements = partition_epub(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 144b4e0c1d..fc3ff10693 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -38,6 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): elements = partition_md(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git 
a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index ad09b79eb0..8091e83e85 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredMsgExtractor(BaseExtractor): elements = partition_msg(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index d354b593ed..b69394b3b1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str, - api_key: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 6fcbb5feb9..6ed4a0dfb3 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py index f4a4adbc16..22dfdd2075 100644 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor): file_path: Path to the file to load. 
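All of these unstructured-backed extractors reduce to the same two-step pipeline, partition then title-based chunking. A condensed sketch with the parameters used throughout this diff (the file path is a placeholder, and the chunk objects are assumed to expose .text as unstructured elements do):

```python
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.text import partition_text

from core.rag.models.document import Document

# Step 1: partition the raw file into typed elements (titles, narrative text, ...).
elements = partition_text(filename="/tmp/sample.txt")  # hypothetical file
# Step 2: regroup elements under their nearest title, capping chunks at 2000 characters.
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = [Document(page_content=chunk.text) for chunk in chunks]
```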
""" - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredTextExtractor(BaseExtractor): elements = partition_text(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 6aef8e0f7e..3bffc01fbf 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredXmlExtractor(BaseExtractor): elements = partition_xml(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c3f0b75cfb..2db00d161b 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import datetime import logging import mimetypes @@ -21,6 +22,7 @@ from models.model import UploadFile logger = logging.getLogger(__name__) + class WordExtractor(BaseExtractor): """Load docx files. @@ -43,9 +45,7 @@ class WordExtractor(BaseExtractor): r = requests.get(self.file_path) if r.status_code != 200: - raise ValueError( - f"Check the url of your file; returned status code {r.status_code}" - ) + raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path self.temp_file = tempfile.NamedTemporaryFile() @@ -60,11 +60,13 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, 'storage') - return [Document( - page_content=content, - metadata={"source": self.file_path}, - )] + content = self.parse_docx(self.file_path, "storage") + return [ + Document( + page_content=content, + metadata={"source": self.file_path}, + ) + ] @staticmethod def _is_valid_url(url: str) -> bool: @@ -84,18 +86,18 @@ class WordExtractor(BaseExtractor): url = rel.reltype response = requests.get(url, stream=True) if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers['Content-Type']) + image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) else: continue else: - image_ext = rel.target_ref.split('.')[-1] + image_ext = rel.target_ref.split(".")[-1] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) @@ -112,12 +114,14 @@ class WordExtractor(BaseExtractor): created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = ( + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + ) return image_map @@ -167,9 +171,11 @@ class WordExtractor(BaseExtractor): def _parse_cell_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + if not image_id: + continue image_part = paragraph.part.rels[image_id].target_part if image_part in image_map: @@ -182,16 +188,16 @@ class WordExtractor(BaseExtractor): def _parse_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): - embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if embed_id: rel_target = run.part.rels[embed_id].target_ref if rel_target in image_map: paragraph_content.append(image_map[rel_target]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ' '.join(paragraph_content) if paragraph_content else '' + return " ".join(paragraph_content) if paragraph_content else "" def parse_docx(self, docx_path, image_folder): doc = DocxDocument(docx_path) @@ -202,60 +208,59 @@ class WordExtractor(BaseExtractor): image_map = self._extract_images_from_docx(doc, image_folder) hyperlinks_url = None - url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+') + url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") for para in doc.paragraphs: for run in para.runs: if run.text and hyperlinks_url: - result = f' [{run.text}]({hyperlinks_url}) ' + result = f" [{run.text}]({hyperlinks_url}) " run.text = result hyperlinks_url = None - if 'HYPERLINK' in run.element.xml: + if "HYPERLINK" in run.element.xml: try: xml = ET.XML(run.element.xml) x_child = [c for c in xml.iter() if c is not None] for x in x_child: if x_child is None: continue - if x.tag.endswith('instrText'): + if x.tag.endswith("instrText"): for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: logger.error(e) - - - def parse_paragraph(paragraph): paragraph_content = [] for run in 
paragraph.runs: - if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'): + if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): drawing_elements = run.element.findall( - './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" + ) for drawing in drawing_elements: blip_elements = drawing.findall( - './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip') + ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" + ) for blip in blip_elements: embed_id = blip.get( - '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ''.join(paragraph_content) if paragraph_content else '' + return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() tables = doc.tables.copy() for element in doc.element.body: - if hasattr(element, 'tag'): - if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph + if hasattr(element, "tag"): + if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) if parsed_paragraph: content.append(parsed_paragraph) - elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table + elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) - content.append(self._table_to_markdown(table,image_map)) + content.append(self._table_to_markdown(table, image_map)) - return '\n'.join(content) + return "\n".join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 176d0c1ed6..be857bd122 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod from typing import Optional @@ -15,8 +16,7 @@ from models.dataset import Dataset, DatasetProcessRule class BaseIndexProcessor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -34,18 +34,24 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
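For context, the custom branch below expects a processing_rule shaped roughly as follows (the key names come from the code; the values are illustrative):

```python
# Illustrative rule: split on a literal blank line, 512-token chunks, 50-token overlap.
processing_rule = {
    "mode": "custom",
    "rules": {
        "segmentation": {
            "separator": "\\n\\n",  # unescaped to "\n\n" by _get_splitter
            "max_tokens": 512,
            "chunk_overlap": 50,
        }
    },
}
```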
""" - if processing_rule['mode'] == "custom": + if processing_rule["mode"] == "custom": # The user-defined segmentation rule - rules = processing_rule['rules'] + rules = processing_rule["rules"] segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: @@ -53,22 +59,22 @@ class BaseIndexProcessor(ABC): separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0) or 0, + chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index df43a64910..9b855ece2c 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -7,8 +7,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess class IndexProcessorFactory: - """IndexProcessorInit. - """ + """IndexProcessorInit.""" def __init__(self, index_type: str): self._index_type = index_type @@ -22,7 +21,6 @@ class IndexProcessorFactory: if self._index_type == IndexType.PARAGRAPH_INDEX.value: return ParagraphIndexProcessor() elif self._index_type == IndexType.QA_INDEX.value: - return QAIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5fbc319fd6..ed5712220f 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import uuid from typing import Optional @@ -15,34 +16,33 @@ from models.dataset import Dataset class ParagraphIndexProcessor(BaseIndexProcessor): - def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. 
- splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): page_content = page_content[1:].strip() @@ -55,7 +55,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return all_documents def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: @@ -63,7 +63,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.create(documents) def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -76,17 +76,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.delete() - def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: # Set search parameters. - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. 
docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 139bfe15f3..1dbc473281 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import logging import re import threading @@ -23,34 +24,34 @@ from models.dataset import Dataset class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) # Split the text documents into nodes. all_documents = [] all_qa_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): page_content = page_content[1:] @@ -61,14 +62,18 @@ class QAIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': kwargs.get('tenant_id'), - 'document_node': doc, - 'all_qa_documents': all_qa_documents, - 'document_language': kwargs.get('doc_language', 'English')}) + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": kwargs.get("tenant_id"), + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -76,9 +81,8 @@ class QAIndexProcessor(BaseIndexProcessor): return all_qa_documents def 
format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: - # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -86,7 +90,7 @@ class QAIndexProcessor(BaseIndexProcessor): df = pd.read_csv(file) text_docs = [] for index, row in df.iterrows(): - data = Document(page_content=row[0], metadata={'answer': row[1]}) + data = Document(page_content=row[0], metadata={"answer": row[1]}) text_docs.append(data) if len(text_docs) == 0: raise ValueError("The CSV file is empty.") @@ -96,7 +100,7 @@ class QAIndexProcessor(BaseIndexProcessor): return text_docs def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) @@ -107,17 +111,29 @@ class QAIndexProcessor(BaseIndexProcessor): else: vector.delete() - def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict): + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ): # Set search parameters. - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) @@ -134,12 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor): document_qa_list = self._format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -151,10 +167,4 @@ class QAIndexProcessor(BaseIndexProcessor): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 6f3c1c5d34..0ff1fdb81c 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -55,9 +55,7 @@ class BaseDocumentTransformer(ABC): """ @abstractmethod - def transform_documents( - self, 
documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform a list of documents. Args: @@ -68,9 +66,7 @@ class BaseDocumentTransformer(ABC): """ @abstractmethod - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a list of documents. Args: diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/constants/rerank_mode.py index afbb9fd89d..d4894e3cc6 100644 --- a/api/core/rag/rerank/constants/rerank_mode.py +++ b/api/core/rag/rerank/constants/rerank_mode.py @@ -2,7 +2,5 @@ from enum import Enum class RerankMode(Enum): - - RERANKING_MODEL = 'reranking_model' - WEIGHTED_SCORE = 'weighted_score' - + RERANKING_MODEL = "reranking_model" + WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index d9067da288..6356ff87ab 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -8,8 +8,14 @@ class RerankModelRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -23,19 +29,15 @@ class RerankModelRunner: doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) documents = unique_documents rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) rerank_documents = [] @@ -45,12 +47,12 @@ class RerankModelRunner: rerank_document = Document( page_content=result.text, metadata={ - "doc_id": documents[result.index].metadata['doc_id'], - "doc_hash": documents[result.index].metadata['doc_hash'], - "document_id": documents[result.index].metadata['document_id'], - "dataset_id": documents[result.index].metadata['dataset_id'], - 'score': result.score - } + "doc_id": documents[result.index].metadata["doc_id"], + "doc_hash": documents[result.index].metadata["doc_hash"], + "document_id": documents[result.index].metadata["document_id"], + "dataset_id": documents[result.index].metadata["dataset_id"], + "score": result.score, + }, ) rerank_documents.append(rerank_document) diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d8a7873982..4375079ee5 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -13,13 +13,18 @@ from core.rag.rerank.entity.weight import VectorSetting, Weights class WeightRerankRunner: - def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = 
tenant_id self.weights = weights - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -34,8 +39,8 @@ class WeightRerankRunner: doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -47,13 +52,15 @@ class WeightRerankRunner: query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): # format document - score = self.weights.vector_setting.vector_weight * query_vector_score + \ - self.weights.keyword_setting.keyword_weight * query_score + score = ( + self.weights.vector_setting.vector_weight * query_vector_score + + self.weights.keyword_setting.keyword_weight * query_score + ) if score_threshold and score < score_threshold: continue - document.metadata['score'] = score + document.metadata["score"] = score rerank_documents.append(document) - rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True) + rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -70,7 +77,7 @@ class WeightRerankRunner: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -132,8 +139,9 @@ class WeightRerankRunner: return similarities - def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document], - vector_setting: VectorSetting) -> list[float]: + def _calculate_cosine( + self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting + ) -> list[float]: """ Calculate Cosine scores :param query: search query @@ -149,15 +157,14 @@ class WeightRerankRunner: tenant_id=tenant_id, provider=vector_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=vector_setting.embedding_model_name - + model=vector_setting.embedding_model_name, ) cache_embedding = CacheEmbedding(embedding_model) query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if 'score' in document.metadata: - query_vector_scores.append(document.metadata['score']) + if "score" in document.metadata: + query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy vec1 = np.array(query_vector) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c970e3dafa..4948ec6ba8 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -21,7 +21,7 @@ from 
core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool @@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -48,15 +45,18 @@ class DatasetRetrieval: self.application_generate_entity = application_generate_entity def retrieve( - self, app_id: str, user_id: str, tenant_id: str, - model_config: ModelConfigWithCredentialsEntity, - config: DatasetEntity, - query: str, - invoke_from: InvokeFrom, - show_retrieve_source: bool, - hit_callback: DatasetIndexToolCallbackHandler, - message_id: str, - memory: Optional[TokenBufferMemory] = None, + self, + app_id: str, + user_id: str, + tenant_id: str, + model_config: ModelConfigWithCredentialsEntity, + config: DatasetEntity, + query: str, + invoke_from: InvokeFrom, + show_retrieve_source: bool, + hit_callback: DatasetIndexToolCallbackHandler, + message_id: str, + memory: Optional[TokenBufferMemory] = None, ) -> Optional[str]: """ Retrieve dataset. 
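Note: besides the mechanical reformatting, this file is where the long-standing "retrival" misspelling is corrected: the module api/core/rag/retrieval/retrival_methods.py is renamed to retrieval_methods.py (the rename hunk appears further below), and every retrival_method keyword argument becomes retrieval_method. Out-of-tree callers have to follow the rename; a minimal sketch of the required import change (the asserted value is taken from the enum definition in this patch):

    # Before this patch (module path no longer exists):
    # from core.rag.retrieval.retrival_methods import RetrievalMethod
    # After this patch:
    from core.rag.retrieval.retrieval_methods import RetrievalMethod

    assert RetrievalMethod.SEMANTIC_SEARCH.value == "semantic_search"
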
@@ -84,16 +84,12 @@ class DatasetRetrieval:
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
-            tenant_id=tenant_id,
-            model_type=ModelType.LLM,
-            provider=model_config.provider,
-            model=model_config.model
+            tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
         )

         # get model schema
         model_schema = model_type_instance.get_model_schema(
-            model=model_config.model,
-            credentials=model_config.credentials
+            model=model_config.model, credentials=model_config.credentials
         )

         if not model_schema:
@@ -102,39 +98,46 @@
             planning_strategy = PlanningStrategy.REACT_ROUTER
             features = model_schema.features
             if features:
-                if ModelFeature.TOOL_CALL in features \
-                        or ModelFeature.MULTI_TOOL_CALL in features:
+                if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                     planning_strategy = PlanningStrategy.ROUTER
         available_datasets = []
         for dataset_id in dataset_ids:
             # get dataset from dataset id
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == tenant_id,
-                Dataset.id == dataset_id
-            ).first()
+            dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

             # pass if dataset is not available
             if not dataset:
                 continue

             # pass if dataset is not available
-            if (dataset and dataset.available_document_count == 0
-                    and dataset.available_document_count == 0):
+            if dataset and dataset.available_document_count == 0:
                 continue

             available_datasets.append(dataset)
         all_documents = []
-        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
+        user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
             all_documents = self.single_retrieve(
-                app_id, tenant_id, user_id, user_from, available_datasets, query,
+                app_id,
+                tenant_id,
+                user_id,
+                user_from,
+                available_datasets,
+                query,
                 model_instance,
-                model_config, planning_strategy, message_id
+                model_config,
+                planning_strategy,
+                message_id,
             )
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             all_documents = self.multiple_retrieve(
-                app_id, tenant_id, user_id, user_from,
-                available_datasets, query, retrieve_config.top_k,
+                app_id,
+                tenant_id,
+                user_id,
+                user_from,
+                available_datasets,
+                query,
+                retrieve_config.top_k,
                 retrieve_config.score_threshold,
                 retrieve_config.rerank_mode,
                 retrieve_config.reranking_model,
@@ -145,89 +148,89 @@ class DatasetRetrieval:
         document_score_list = {}
         for item in all_documents:
-            if item.metadata.get('score'):
-                document_score_list[item.metadata['doc_id']] = item.metadata['score']
+            if item.metadata.get("score"):
+                document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

         document_context_list = []
-        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        index_node_ids = [document.metadata["doc_id"] for document in all_documents]
         segments = DocumentSegment.query.filter(
             DocumentSegment.dataset_id.in_(dataset_ids),
             DocumentSegment.completed_at.isnot(None),
-            DocumentSegment.status == 'completed',
+            DocumentSegment.status == "completed",
             DocumentSegment.enabled == True,
-            DocumentSegment.index_node_id.in_(index_node_ids)
+            DocumentSegment.index_node_id.in_(index_node_ids),
         ).all()

         if segments:
             index_node_id_to_position = {id: position for position, id in
enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if show_retrieve_source: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ).first() - document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': invoke_from.to_source(), - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": invoke_from.to_source(), + "score": document_score_list.get(segment.index_node_id, None), } - if invoke_from.to_source() == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if invoke_from.to_source() == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 if hit_callback: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) - return '' + return "" def single_retrieve( - self, app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - model_instance: ModelInstance, - model_config: ModelConfigWithCredentialsEntity, - planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + message_id: Optional[str] = None, ): tools = [] for dataset in available_datasets: description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description 
= "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") message_tool = PromptMessageTool( name=dataset.id, description=description, @@ -235,14 +238,15 @@ class DatasetRetrieval: "type": "object", "properties": {}, "required": [], - } + }, ) tools.append(message_tool) dataset_id = None if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, - user_id, tenant_id) + dataset_id = react_multi_dataset_router.invoke( + query, tools, model_config, model_instance, user_id, tenant_id + ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() @@ -250,61 +254,61 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get retrieval method if dataset.indexing_technique == "economy": - retrival_method = 'keyword_search' + retrieval_method = "keyword_search" else: - retrival_method = retrieval_model_config['search_method'] + retrieval_method = retrieval_model_config["search_method"] # get reranking model - reranking_model = retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None + reranking_model = ( + retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None + ) # get score threshold - score_threshold = .0 + score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") with measure_time() as timer: results = RetrievalService.retrieve( - retrival_method=retrival_method, dataset_id=dataset.id, + retrieval_method=retrieval_method, + dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, + top_k=top_k, + score_threshold=score_threshold, reranking_model=reranking_model, - reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), - weights=retrieval_model_config.get('weights', None), + reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), + weights=retrieval_model_config.get("weights", None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) if results: - self._on_retrival_end(results, message_id, timer) + self._on_retrieval_end(results, message_id, timer) return results return [] def multiple_retrieve( - self, - app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - top_k: int, - score_threshold: float, - reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - reranking_enable: bool = True, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + top_k: int, + score_threshold: 
float, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reranking_enable: bool = True, + message_id: Optional[str] = None, ): threads = [] all_documents = [] @@ -312,13 +316,16 @@ class DatasetRetrieval: index_type = None for dataset in available_datasets: index_type = dataset.indexing_technique - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -327,16 +334,10 @@ class DatasetRetrieval: with measure_time() as timer: if reranking_enable: # do rerank for searched documents - data_post_processor = DataPostProcessor( - tenant_id, reranking_mode, - reranking_model, weights, False - ) + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) else: if index_type == "economy": @@ -347,40 +348,36 @@ class DatasetRetrieval: self._on_query(query, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrival_end(all_documents, message_id, timer) + self._on_retrieval_end(all_documents, message_id, timer) return all_documents - def _on_retrival_end( + def _on_retrieval_end( self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None ) -> None: - """Handle retrival end.""" + """Handle retrieval end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None + trace_manager: TraceQueueManager = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, - message_id=message_id, - documents=documents, - timer=timer + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer ) ) @@ -395,10 +392,10 @@ class DatasetRetrieval: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=app_id, created_by_role=user_from, - created_by=user_id + created_by=user_id, ) 
dataset_queries.append(dataset_query)
         if dataset_queries:
@@ -407,9 +404,7 @@
     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
         with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
+            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

             if not dataset:
                 return []
@@ -419,38 +414,42 @@
             if dataset.indexing_technique == "economy":
                 # use keyword table query
-                documents = RetrievalService.retrieve(retrival_method='keyword_search',
-                                                      dataset_id=dataset.id,
-                                                      query=query,
-                                                      top_k=top_k
-                                                      )
+                documents = RetrievalService.retrieve(
+                    retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
+                )
                 if documents:
                     all_documents.extend(documents)
             else:
                 if top_k > 0:
                     # retrieval source
-                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
-                                                          dataset_id=dataset.id,
-                                                          query=query,
-                                                          top_k=top_k,
-                                                          score_threshold=retrieval_model.get('score_threshold', .0)
-                                                          if retrieval_model['score_threshold_enabled'] else None,
-                                                          reranking_model=retrieval_model.get('reranking_model', None)
-                                                          if retrieval_model['reranking_enable'] else None,
-                                                          reranking_mode=retrieval_model.get('reranking_mode')
-                                                          if retrieval_model.get('reranking_mode') else 'reranking_model',
-                                                          weights=retrieval_model.get('weights', None),
-                                                          )
+                    documents = RetrievalService.retrieve(
+                        retrieval_method=retrieval_model["search_method"],
+                        dataset_id=dataset.id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=retrieval_model.get("score_threshold", 0.0)
+                        if retrieval_model["score_threshold_enabled"]
+                        else None,
+                        reranking_model=retrieval_model.get("reranking_model", None)
+                        if retrieval_model["reranking_enable"]
+                        else None,
+                        reranking_mode=retrieval_model.get("reranking_mode")
+                        if retrieval_model.get("reranking_mode")
+                        else "reranking_model",
+                        weights=retrieval_model.get("weights", None),
+                    )
                     all_documents.extend(documents)

-    def to_dataset_retriever_tool(self, tenant_id: str,
-                                  dataset_ids: list[str],
-                                  retrieve_config: DatasetRetrieveConfigEntity,
-                                  return_resource: bool,
-                                  invoke_from: InvokeFrom,
-                                  hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[list[DatasetRetrieverBaseTool]]:
+    def to_dataset_retriever_tool(
+        self,
+        tenant_id: str,
+        dataset_ids: list[str],
+        retrieve_config: DatasetRetrieveConfigEntity,
+        return_resource: bool,
+        invoke_from: InvokeFrom,
+        hit_callback: DatasetIndexToolCallbackHandler,
+    ) -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param tenant_id: tenant id
@@ -464,18 +463,14 @@
         available_datasets = []
         for dataset_id in dataset_ids:
             # get dataset from dataset id
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == tenant_id,
-                Dataset.id == dataset_id
-            ).first()
+            dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

             # pass if dataset is not available
             if not dataset:
                 continue

             # pass if dataset is not available
-            if (dataset and dataset.available_document_count == 0
-                    and dataset.available_document_count == 0):
+            if dataset and dataset.available_document_count == 0:
                 continue

             available_datasets.append(dataset)
@@ -483,22 +478,18 @@
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
             #
get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get score threshold score_threshold = None @@ -512,7 +503,7 @@ class DatasetRetrieval: score_threshold=score_threshold, hit_callbacks=[hit_callback], return_resource=return_resource, - retriever_from=invoke_from.to_source() + retriever_from=invoke_from.to_source(), ) tools.append(tool) @@ -525,8 +516,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), - reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), ) tools.append(tool) @@ -547,7 +538,7 @@ class DatasetRetrieval: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -606,21 +597,19 @@ class DatasetRetrieval: for document, score in zip(documents, similarities): # format document - document.metadata['score'] = score - documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True) + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents[:top_k] if top_k else documents - def calculate_vector_score(self, all_documents: list[Document], - top_k: int, score_threshold: float) -> list[Document]: + def calculate_vector_score( + self, all_documents: list[Document], top_k: int, score_threshold: float + ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata["score"] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) + filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) return filter_documents[:top_k] if top_k else filter_documents - - - diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py index 60770bd4c6..7fc78bce83 100644 --- a/api/core/rag/retrieval/output_parser/structured_chat.py +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -16,9 +16,7 @@ class StructuredChatOutputParser: if 
response["action"] == "Final Answer": return ReactFinish({"output": response["action_input"]}, text) else: - return ReactAction( - response["action"], response.get("action_input", {}), text - ) + return ReactAction(response["action"], response.get("action_input", {}), text) else: return ReactFinish({"output": text}, text) except Exception as e: diff --git a/api/core/rag/retrieval/retrival_methods.py b/api/core/rag/retrieval/retrieval_methods.py similarity index 79% rename from api/core/rag/retrieval/retrival_methods.py rename to api/core/rag/retrieval/retrieval_methods.py index 12aa28a51c..eaa00bca88 100644 --- a/api/core/rag/retrieval/retrival_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -2,9 +2,9 @@ from enum import Enum class RetrievalMethod(Enum): - SEMANTIC_SEARCH = 'semantic_search' - FULL_TEXT_SEARCH = 'full_text_search' - HYBRID_SEARCH = 'hybrid_search' + SEMANTIC_SEARCH = "semantic_search" + FULL_TEXT_SEARCH = "full_text_search" + HYBRID_SEARCH = "hybrid_search" @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 84e53952ac..06147fe7b5 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -6,14 +6,12 @@ from core.model_runtime.entities.message_entities import PromptMessageTool, Syst class FunctionCallMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -26,22 +24,18 @@ class FunctionCallMultiDatasetRouter: try: prompt_messages = [ - SystemPromptMessage(content='You are a helpful AI assistant.'), - UserPromptMessage(content=query) + SystemPromptMessage(content="You are a helpful AI assistant."), + UserPromptMessage(content=query), ] result = model_instance.invoke_llm( prompt_messages=prompt_messages, tools=dataset_tools, stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) if result.message.tool_calls: # get retrieval model config return result.message.tool_calls[0].function.name return None except Exception as e: - return None \ No newline at end of file + return None diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 92f24277c1..33841cac06 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -50,16 +50,14 @@ Action: class ReactMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - user_id: str, - tenant_id: str - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + user_id: str, + tenant_id: str, ) -> Union[str, None]: """Given input, decided what to do. 
Returns: @@ -71,23 +69,28 @@ class ReactMultiDatasetRouter: return dataset_tools[0].name try: - return self._react_invoke(query=query, model_config=model_config, - model_instance=model_instance, - tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) + return self._react_invoke( + query=query, + model_config=model_config, + model_instance=model_instance, + tools=dataset_tools, + user_id=user_id, + tenant_id=tenant_id, + ) except Exception as e: return None def _react_invoke( - self, - query: str, - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - tools: Sequence[PromptMessageTool], - user_id: str, - tenant_id: str, - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + tools: Sequence[PromptMessageTool], + user_id: str, + tenant_id: str, + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -103,18 +106,18 @@ class ReactMultiDatasetRouter: prefix=prefix, format_instructions=format_instructions, ) - stop = ['Observation:'] + stop = ["Observation:"] # handle invoke result prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=None, memory=None, - model_config=model_config + model_config=model_config, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -122,7 +125,7 @@ class ReactMultiDatasetRouter: prompt_messages=prompt_messages, stop=stop, user_id=user_id, - tenant_id=tenant_id + tenant_id=tenant_id, ) output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) @@ -130,17 +133,21 @@ class ReactMultiDatasetRouter: return react_decision.tool return None - def _invoke_llm(self, completion_param: dict, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str - ) -> tuple[str, LLMUsage]: + def _invoke_llm( + self, + completion_param: dict, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str], + user_id: str, + tenant_id: str, + ) -> tuple[str, LLMUsage]: """ - Invoke large language model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: + Invoke large language model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, @@ -151,9 +158,7 @@ class ReactMultiDatasetRouter: ) # handle invoke result - text, usage = self._handle_invoke_result( - invoke_result=invoke_result - ) + text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) @@ -168,7 +173,7 @@ class ReactMultiDatasetRouter: """ model = None prompt_messages = [] - full_text = '' + full_text = "" usage = None for result in invoke_result: text = result.delta.message.content @@ -189,40 +194,35 @@ class ReactMultiDatasetRouter: return full_text, usage def create_chat_prompt( - self, - query: str, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - 
format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: tool_strings.append( - f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}" + ) formatted_tools = "\n".join(tool_strings) unique_tool_names = {tool.name for tool in tools} tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) prompt_messages = [] - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=template - ) + system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template) prompt_messages.append(system_prompt_messages) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=query - ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query) prompt_messages.append(user_prompt_message) return prompt_messages def create_completion_prompt( - self, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. 
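Note: the two splitter diffs below carry behavior changes on top of the reformat. _split_text_with_regex now attaches a kept separator to the chunk before it rather than the chunk after it, and _merge_splits takes a precomputed lengths list (cached in _good_splits_lengths) so each split is measured once instead of on every pass of the merge loop. A minimal standalone sketch of the new keep-separator behavior, mirroring the patched list comprehension:

    import re

    separator, text = ".", "one.two.three"
    # The capture group keeps the delimiters in the split result.
    _splits = re.split(f"({re.escape(separator)})", text)  # ['one', '.', 'two', '.', 'three']
    splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
    if len(_splits) % 2 != 0:  # odd length: trailing piece with no separator
        splits += _splits[-1:]
    print(splits)  # ['one.', 'two.', 'three'] (the old code produced ['one', '.two', '.three'])
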
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 6a0804f890..53032b34d5 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -1,4 +1,5 @@
 """Functionality for splitting text."""
+
 from __future__ import annotations

 from typing import Any, Optional
@@ -18,31 +19,29 @@ from core.rag.splitter.text_splitter import (

 class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     """
-        This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
+    This class implements from_encoder to avoid depending on tiktoken.
     """

     @classmethod
     def from_encoder(
-            cls: type[TS],
-            embedding_model_instance: Optional[ModelInstance],
-            allowed_special: Union[Literal[all], Set[str]] = set(),
-            disallowed_special: Union[Literal[all], Collection[str]] = "all",
-            **kwargs: Any,
+        cls: type[TS],
+        embedding_model_instance: Optional[ModelInstance],
+        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
             if not text:
                 return 0

             if embedding_model_instance:
-                return embedding_model_instance.get_text_embedding_num_tokens(
-                    texts=[text]
-                )
+                return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
             else:
                 return GPT2Tokenizer.get_num_tokens(text)

         if issubclass(cls, TokenTextSplitter):
             extra_kwargs = {
-                "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
+                "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
                 "allowed_special": allowed_special,
                 "disallowed_special": disallowed_special,
             }
@@ -93,17 +92,21 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
             splits = list(text)
         # Now go merging things, recursively splitting longer texts.
         _good_splits = []
+        _good_splits_lengths = []  # cache the lengths of the splits
         for s in splits:
-            if self._length_function(s) < self._chunk_size:
+            s_len = self._length_function(s)
+            if s_len < self._chunk_size:
                 _good_splits.append(s)
+                _good_splits_lengths.append(s_len)
             else:
                 if _good_splits:
-                    merged_text = self._merge_splits(_good_splits, separator)
+                    merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
                     final_chunks.extend(merged_text)
                     _good_splits = []
+                    _good_splits_lengths = []
                 other_info = self.recursive_split_text(s)
                 final_chunks.extend(other_info)
         if _good_splits:
-            merged_text = self._merge_splits(_good_splits, separator)
+            merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
             final_chunks.extend(merged_text)

         return final_chunks
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index b3adcedc76..97d0721304 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -22,35 +22,32 @@ logger = logging.getLogger(__name__)
 TS = TypeVar("TS", bound="TextSplitter")


-def _split_text_with_regex(
-    text: str, separator: str, keep_separator: bool
-) -> list[str]:
+def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
     # Now that we have the separator, split the text
     if separator:
         if keep_separator:
             # The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({re.escape(separator)})", text) - splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] - if len(_splits) % 2 == 0: + splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)] + if len(_splits) % 2 != 0: splits += _splits[-1:] - splits = [_splits[0]] + splits else: splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if s != ""] + return [s for s in splits if (s != "" and s != "\n")] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( - self, - chunk_size: int = 4000, - chunk_overlap: int = 200, - length_function: Callable[[str], int] = len, - keep_separator: bool = False, - add_start_index: bool = False, + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, ) -> None: """Create a new TextSplitter. @@ -63,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): """ if chunk_overlap > chunk_size: raise ValueError( - f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " - f"({chunk_size}), should be smaller." + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap @@ -76,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents( - self, texts: list[str], metadatas: Optional[list[dict]] = None - ) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -109,7 +103,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): else: return text - def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: + def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. 
separator_len = self._length_function(separator) @@ -117,16 +111,13 @@ class TextSplitter(BaseDocumentTransformer, ABC): docs = [] current_doc: list[str] = [] total = 0 + index = 0 for d in splits: - _len = self._length_function(d) - if ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - ): + _len = lengths[index] + if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( - f"Created a chunk of size {total}, " - f"which is longer than the specified {self._chunk_size}" + f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) @@ -136,16 +127,13 @@ class TextSplitter(BaseDocumentTransformer, ABC): # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - and total > 0 + total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): - total -= self._length_function(current_doc[0]) + ( - separator_len if len(current_doc) > 1 else 0 - ) + total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) + index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) @@ -158,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC): from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - "Tokenizer received was not an instance of PreTrainedTokenizerBase" - ) + raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( - "Could not import transformers python package. " - "Please install it with `pip install transformers`." + "Could not import transformers python package. " "Please install it with `pip install transformers`." 
) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( - cls: type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> TS: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -216,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC): return cls(length_function=_tiktoken_encoder, **kwargs) - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError @@ -242,7 +223,10 @@ class CharacterTextSplitter(TextSplitter): # First we naively split the large input into a bunch of smaller ones. splits = _split_text_with_regex(text, self._separator, self._keep_separator) _separator = "" if self._keep_separator else self._separator - return self._merge_splits(splits, _separator) + _good_splits_lengths = [] # cache the lengths of the splits + for split in splits: + _good_splits_lengths.append(self._length_function(split)) + return self._merge_splits(splits, _separator, _good_splits_lengths) class LineType(TypedDict): @@ -263,9 +247,7 @@ class HeaderType(TypedDict): class MarkdownHeaderTextSplitter: """Splitting markdown files based on specified headers.""" - def __init__( - self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False - ): + def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): """Create a new MarkdownHeaderTextSplitter. 
Args: @@ -276,9 +258,7 @@ class MarkdownHeaderTextSplitter: self.return_each_line = return_each_line # Given the headers we want to split on, # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted( - headers_to_split_on, key=lambda split: len(split[0]), reverse=True - ) + self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: """Combine lines with common metadata into chunks @@ -288,10 +268,7 @@ class MarkdownHeaderTextSplitter: aggregated_chunks: list[LineType] = [] for line in lines: - if ( - aggregated_chunks - and aggregated_chunks[-1]["metadata"] == line["metadata"] - ): + if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: # If the last line in the aggregated list # has the same metadata as the current line, # append the current content to the last lines's content @@ -300,10 +277,7 @@ class MarkdownHeaderTextSplitter: # Otherwise, append the current line to the aggregated list aggregated_chunks.append(line) - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in aggregated_chunks - ] + return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] def split_text(self, text: str) -> list[Document]: """Split markdown file @@ -328,10 +302,9 @@ class MarkdownHeaderTextSplitter: for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) - or stripped_line[len(sep)] == " " + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " ): # Ensure we are tracking the header as metadata if name is not None: @@ -339,10 +312,7 @@ class MarkdownHeaderTextSplitter: current_header_level = sep.count("#") # Pop out headers of lower or same level from the stack - while ( - header_stack - and header_stack[-1]["level"] >= current_header_level - ): + while header_stack and header_stack[-1]["level"] >= current_header_level: # We have encountered a new header # at the same or higher level popped_header = header_stack.pop() @@ -355,7 +325,7 @@ class MarkdownHeaderTextSplitter: header: HeaderType = { "level": current_header_level, "name": name, - "data": stripped_line[len(sep):].strip(), + "data": stripped_line[len(sep) :].strip(), } header_stack.append(header) # Update initial_metadata with the current header @@ -388,9 +358,7 @@ class MarkdownHeaderTextSplitter: current_metadata = initial_metadata.copy() if current_content: - lines_with_metadata.append( - {"content": "\n".join(current_content), "metadata": current_metadata} - ) + lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) # lines_with_metadata has each line with associated header metadata # aggregate these into chunks based on common metadata @@ -398,8 +366,7 @@ class MarkdownHeaderTextSplitter: return self.aggregate_lines_to_chunks(lines_with_metadata) else: return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in lines_with_metadata + Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata ] @@ -432,12 +399,12 @@ 
class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( - self, - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) @@ -484,50 +451,55 @@ class RecursiveCharacterTextSplitter(TextSplitter): """ def __init__( - self, - separators: Optional[list[str]] = None, - keep_separator: bool = True, - **kwargs: Any, + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] def _split_text(self, text: str, separators: list[str]) -> list[str]: - """Split incoming text and return chunks.""" final_chunks = [] - # Get appropriate separator to use separator = separators[-1] new_separators = [] + for i, _s in enumerate(separators): if _s == "": separator = _s break if re.search(_s, text): separator = _s - new_separators = separators[i + 1:] + new_separators = separators[i + 1 :] break splits = _split_text_with_regex(text, separator, self._keep_separator) - # Now go merging things, recursively splitting longer texts. _good_splits = [] + _good_splits_lengths = [] # cache the lengths of the splits _separator = "" if self._keep_separator else separator + for s in splits: - if self._length_function(s) < self._chunk_size: + s_len = self._length_function(s) + if s_len < self._chunk_size: _good_splits.append(s) + _good_splits_lengths.append(s_len) else: if _good_splits: - merged_text = self._merge_splits(_good_splits, _separator) + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] + _good_splits_lengths = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) + if _good_splits: - merged_text = self._merge_splits(_good_splits, _separator) + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) + return final_chunks def split_text(self, text: str) -> list[str]: diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2b01b8fd8e..b988a588e9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -10,23 +10,23 @@ from core.tools.tool.tool import ToolParameter class UserTool(BaseModel): author: str - name: str # identifier - label: I18nObject # label + name: str # identifier + label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None labels: list[str] = None -UserToolProviderTypeLiteral = Optional[Literal[ - 'builtin', 'api', 'workflow' -]] + +UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] + class UserToolProvider(BaseModel): id: str author: str - name: str # identifier + name: str # identifier description: I18nObject icon: str - label: I18nObject # label + label: I18nObject # label type: ToolProviderType masked_credentials: Optional[dict] = None 
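Returning to the splitter changes above before the tool entity diffs continue: RecursiveCharacterTextSplitter._split_text now picks the first separator that matches the text, recurses into oversize pieces with the remaining finer separators, and caches split lengths the same way the fixed splitter does. A hedged usage sketch, assuming the constructor defaults shown in the hunk and the module path taken from this diff:

    from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter

    # chunk_size and chunk_overlap are counted by length_function (len by default).
    splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    chunks = splitter.split_text("para one...\n\npara two...\n\npara three...")
    # "\n\n" is tried first; only pieces still over chunk_size fall through to
    # "\n", then " ", then character-level splitting.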
original_credentials: Optional[dict] = None @@ -40,26 +40,27 @@ class UserToolProvider(BaseModel): # overwrite tool parameter types for temp fix tools = jsonable_encoder(self.tools) for tool in tools: - if tool.get('parameters'): - for parameter in tool.get('parameters'): - if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: - parameter['type'] = 'files' + if tool.get("parameters"): + for parameter in tool.get("parameters"): + if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: + parameter["type"] = "files" # ------------- return { - 'id': self.id, - 'author': self.author, - 'name': self.name, - 'description': self.description.to_dict(), - 'icon': self.icon, - 'label': self.label.to_dict(), - 'type': self.type.value, - 'team_credentials': self.masked_credentials, - 'is_team_authorization': self.is_team_authorization, - 'allow_delete': self.allow_delete, - 'tools': tools, - 'labels': self.labels, + "id": self.id, + "author": self.author, + "name": self.name, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, } + class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ToolProviderCredentials] \ No newline at end of file + credentials: dict[str, ToolProviderCredentials] diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 55e31e8c35..37a926697b 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None pt_BR: Optional[str] = None en_US: str @@ -19,8 +20,4 @@ class I18nObject(BaseModel): self.pt_BR = self.en_US def to_dict(self) -> dict: - return { - 'zh_Hans': self.zh_Hans, - 'en_US': self.en_US, - 'pt_BR': self.pt_BR - } + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR} diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index d18d27fb02..da6201c5aa 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -9,6 +9,7 @@ class ApiToolBundle(BaseModel): """ This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. 
""" + # server_url server_url: str # method diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index e31dec55d2..02b8b35be7 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,27 +7,29 @@ from core.tools.entities.common_entities import I18nObject class ToolLabelEnum(Enum): - SEARCH = 'search' - IMAGE = 'image' - VIDEOS = 'videos' - WEATHER = 'weather' - FINANCE = 'finance' - DESIGN = 'design' - TRAVEL = 'travel' - SOCIAL = 'social' - NEWS = 'news' - MEDICAL = 'medical' - PRODUCTIVITY = 'productivity' - EDUCATION = 'education' - BUSINESS = 'business' - ENTERTAINMENT = 'entertainment' - UTILITIES = 'utilities' - OTHER = 'other' + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + class ToolProviderType(Enum): """ - Enum class for tool provider + Enum class for tool provider """ + BUILT_IN = "builtin" WORKFLOW = "workflow" API = "api" @@ -35,7 +37,7 @@ class ToolProviderType(Enum): DATASET_RETRIEVAL = "dataset-retrieval" @classmethod - def value_of(cls, value: str) -> 'ToolProviderType': + def value_of(cls, value: str) -> "ToolProviderType": """ Get value of given mode. @@ -45,19 +47,21 @@ class ToolProviderType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. """ + OPENAPI = "openapi" SWAGGER = "swagger" OPENAI_PLUGIN = "openai_plugin" OPENAI_ACTIONS = "openai_actions" @classmethod - def value_of(cls, value: str) -> 'ApiProviderSchemaType': + def value_of(cls, value: str) -> "ApiProviderSchemaType": """ Get value of given mode. @@ -67,17 +71,19 @@ class ApiProviderSchemaType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. """ + NONE = "none" API_KEY = "api_key" @classmethod - def value_of(cls, value: str) -> 'ApiProviderAuthType': + def value_of(cls, value: str) -> "ApiProviderAuthType": """ Get value of given mode. 
@@ -87,7 +93,8 @@ class ApiProviderAuthType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ToolInvokeMessage(BaseModel): class MessageType(Enum): @@ -105,19 +112,21 @@ class ToolInvokeMessage(BaseModel): """ message: str | bytes | dict | None = None meta: dict[str, Any] | None = None - save_as: str = '' + save_as: str = "" + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - save_as: str = '' + save_as: str = "" file_var: Optional[dict[str, Any]] = None + class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - @field_validator('value', mode='before') + @field_validator("value", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -136,9 +145,9 @@ class ToolParameter(BaseModel): FILE = "file" class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") @@ -154,25 +163,32 @@ class ToolParameter(BaseModel): options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, - required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': + def get_simple_instance( + cls, + name: str, + llm_description: str, + type: ToolParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "ToolParameter": """ - get a simple tool parameter + get a simple tool parameter - :param name: the name of the parameter - :param llm_description: the description presented to the LLM - :param type: the type of the parameter - :param required: if the parameter is required - :param options: the options of the parameter + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter """ # convert options to ToolParameterOption if options: - options = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] + options = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ] return cls( name=name, - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, @@ -180,18 +196,24 @@ class ToolParameter(BaseModel): options=options, ) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: 
I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", ) + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") + class ToolIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -199,10 +221,12 @@ class ToolIdentity(BaseModel): provider: str = Field(..., description="The provider of the tool") icon: Optional[str] = None + class ToolCredentialsOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolProviderCredentials(BaseModel): class CredentialsType(Enum): SECRET_INPUT = "secret-input" @@ -221,7 +245,7 @@ class ToolProviderCredentials(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") @staticmethod def default(value: str) -> str: @@ -239,33 +263,38 @@ class ToolProviderCredentials(BaseModel): def to_dict(self) -> dict: return { - 'name': self.name, - 'type': self.type.value, - 'required': self.required, - 'default': self.default, - 'options': self.options, - 'help': self.help.to_dict() if self.help else None, - 'label': self.label.to_dict(), - 'url': self.url, - 'placeholder': self.placeholder.to_dict() if self.placeholder else None, + "name": self.name, + "type": self.type.value, + "required": self.required, + "default": self.default, + "options": self.options, + "help": self.help.to_dict() if self.help else None, + "label": self.label.to_dict(), + "url": self.url, + "placeholder": self.placeholder.to_dict() if self.placeholder else None, } + class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" + class ToolRuntimeVariable(BaseModel): type: ToolRuntimeVariableType = Field(..., description="The type of the variable") name: str = Field(..., description="The name of the variable") position: int = Field(..., description="The position of the variable") tool_name: str = Field(..., description="The name of the tool") + class ToolRuntimeTextVariable(ToolRuntimeVariable): value: str = Field(..., description="The value of the variable") + class ToolRuntimeImageVariable(ToolRuntimeVariable): value: str = Field(..., description="The path of the image") + class ToolRuntimeVariablePool(BaseModel): conversation_id: str = Field(..., description="The conversation id") user_id: str = Field(..., description="The user id") @@ -274,26 +303,26 @@ class ToolRuntimeVariablePool(BaseModel): pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables") def __init__(self, **data: Any): - pool = data.get('pool', []) + pool = data.get("pool", []) # convert pool into correct type for index, variable in enumerate(pool): - if variable['type'] == ToolRuntimeVariableType.TEXT.value: + if variable["type"] == ToolRuntimeVariableType.TEXT.value: pool[index] = ToolRuntimeTextVariable(**variable) - elif variable['type'] == ToolRuntimeVariableType.IMAGE.value: + elif variable["type"] == ToolRuntimeVariableType.IMAGE.value: pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) def dict(self) -> dict: return { - 'conversation_id': 
self.conversation_id, - 'user_id': self.user_id, - 'tenant_id': self.tenant_id, - 'pool': [variable.model_dump() for variable in self.pool], + "conversation_id": self.conversation_id, + "user_id": self.user_id, + "tenant_id": self.tenant_id, + "pool": [variable.model_dump() for variable in self.pool], } def set_text(self, tool_name: str, name: str, value: str) -> None: """ - set a text variable + set a text variable """ for variable in self.pool: if variable.name == name: @@ -314,10 +343,10 @@ class ToolRuntimeVariablePool(BaseModel): def set_file(self, tool_name: str, value: str, name: str = None) -> None: """ - set an image variable + set an image variable - :param tool_name: the name of the tool - :param value: the id of the file + :param tool_name: the name of the tool + :param value: the id of the file """ # check how many image variables are there image_variable_count = 0 @@ -345,22 +374,27 @@ class ToolRuntimeVariablePool(BaseModel): self.pool.append(variable) + class ModelToolPropertyKey(Enum): IMAGE_PARAMETER_NAME = "image_parameter_name" + class ModelToolConfiguration(BaseModel): """ Model tool configuration """ + type: str = Field(..., description="The type of the model tool") model: str = Field(..., description="The model") label: I18nObject = Field(..., description="The label of the model tool") properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + class ModelToolProviderConfiguration(BaseModel): """ Model tool provider configuration """ + provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") @@ -370,27 +404,30 @@ class WorkflowToolParameterConfiguration(BaseModel): """ Workflow tool configuration """ + name: str = Field(..., description="The name of the parameter") description: str = Field(..., description="The description of the parameter") form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta """ + time_cost: float = Field(..., description="The time cost of the tool invoke") error: Optional[str] = None tool_config: Optional[dict] = None @classmethod - def empty(cls) -> 'ToolInvokeMeta': + def empty(cls) -> "ToolInvokeMeta": """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> 'ToolInvokeMeta': + def error_instance(cls, error: str) -> "ToolInvokeMeta": """ Get an instance of ToolInvokeMeta with error """ @@ -398,22 +435,26 @@ class ToolInvokeMeta(BaseModel): def to_dict(self) -> dict: return { - 'time_cost': self.time_cost, - 'error': self.error, - 'tool_config': self.tool_config, + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, } + class ToolLabel(BaseModel): """ Tool label """ + name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") icon: str = Field(..., description="The icon of the tool") + class ToolInvokeFrom(Enum): """ Enum class for tool invoke """ + WORKFLOW = "workflow" AGENT = "agent" diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py index d0be5e9355..f9db190f91 100644 --- a/api/core/tools/entities/values.py +++ b/api/core/tools/entities/values.py @@ -2,73 +2,109 @@ from 
core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum ICONS = { - ToolLabelEnum.SEARCH: ''' + ToolLabelEnum.SEARCH: """ -''', - ToolLabelEnum.IMAGE: ''' +""", + ToolLabelEnum.IMAGE: """ -''', - ToolLabelEnum.VIDEOS: ''' +""", + ToolLabelEnum.VIDEOS: """ -''', - ToolLabelEnum.WEATHER: ''' +""", + ToolLabelEnum.WEATHER: """ -''', - ToolLabelEnum.FINANCE: ''' +""", + ToolLabelEnum.FINANCE: """ -''', - ToolLabelEnum.DESIGN: ''' +""", + ToolLabelEnum.DESIGN: """ -''', - ToolLabelEnum.TRAVEL: ''' +""", + ToolLabelEnum.TRAVEL: """ -''', - ToolLabelEnum.SOCIAL: ''' +""", + ToolLabelEnum.SOCIAL: """ -''', - ToolLabelEnum.NEWS: ''' +""", + ToolLabelEnum.NEWS: """ -''', - ToolLabelEnum.MEDICAL: ''' +""", + ToolLabelEnum.MEDICAL: """ -''', - ToolLabelEnum.PRODUCTIVITY: ''' +""", + ToolLabelEnum.PRODUCTIVITY: """ -''', - ToolLabelEnum.EDUCATION: ''' +""", + ToolLabelEnum.EDUCATION: """ -''', - ToolLabelEnum.BUSINESS: ''' +""", + ToolLabelEnum.BUSINESS: """ -''', - ToolLabelEnum.ENTERTAINMENT: ''' +""", + ToolLabelEnum.ENTERTAINMENT: """ -''', - ToolLabelEnum.UTILITIES: ''' +""", + ToolLabelEnum.UTILITIES: """ -''', - ToolLabelEnum.OTHER: ''' +""", + ToolLabelEnum.OTHER: """ -''' +""", } default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), - ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), - ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), - ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), - ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), - ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), - ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), - ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), - ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), - ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), - ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), - ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), - ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), - ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), - ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), - ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), + ToolLabelEnum.SEARCH: ToolLabel( + name="search", 
label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), } default_tool_labels = [v for k, v in default_tool_label_dict.items()] diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 9fd8322db1..6febf137b0 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -4,23 +4,30 @@ from core.tools.entities.tool_entities import ToolInvokeMeta class ToolProviderNotFoundError(ValueError): pass + class ToolNotFoundError(ValueError): pass + class ToolParameterValidationError(ValueError): pass + class ToolProviderCredentialValidationError(ValueError): pass + class ToolNotSupportedError(ValueError): pass + class ToolInvokeError(ValueError): pass + class ToolApiSchemaError(ValueError): pass + class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta \ No newline at end of file + meta: ToolInvokeMeta diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index b804089570..40c3356116 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,5 +1,6 @@ - google - bing +- perplexity - duckduckgo - searchapi - serper @@ -10,6 +11,7 @@ - wikipedia - nominatim - yahoo +- alphavantage - arxiv - pubmed - stablediffusion diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index 
ae80ad2114..2e6018cffc 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,4 +1,3 @@ - from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -18,85 +17,69 @@ class ApiToolProviderController(ToolProviderController): provider_id: str @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": credentials_schema = { - 'auth_type': ToolProviderCredentials( - name='auth_type', + "auth_type": ToolProviderCredentials( + name="auth_type", required=True, type=ToolProviderCredentials.CredentialsType.SELECT, options=[ - ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), - ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")), ], - default='none', - help=I18nObject( - en_US='The auth type of the api provider', - zh_Hans='api provider 的认证类型' - ) + default="none", + help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) } if auth_type == ApiProviderAuthType.API_KEY: credentials_schema = { **credentials_schema, - 'api_key_header': ToolProviderCredentials( - name='api_key_header', + "api_key_header": ToolProviderCredentials( + name="api_key_header", required=False, - default='api_key', + default="api_key", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, - help=I18nObject( - en_US='The header name of the api key', - zh_Hans='携带 api key 的 header 名称' - ) + help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), - 'api_key_value': ToolProviderCredentials( - name='api_key_value', + "api_key_value": ToolProviderCredentials( + name="api_key_value", required=True, type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, - help=I18nObject( - en_US='The api key', - zh_Hans='api key的值' - ) + help=I18nObject(en_US="The api key", zh_Hans="api key的值"), ), - 'api_key_header_prefix': ToolProviderCredentials( - name='api_key_header_prefix', + "api_key_header_prefix": ToolProviderCredentials( + name="api_key_header_prefix", required=False, - default='basic', + default="basic", type=ToolProviderCredentials.CredentialsType.SELECT, - help=I18nObject( - en_US='The prefix of the api key header', - zh_Hans='api key header 的前缀' - ), + help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"), options=[ - ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), - ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), - ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) - ] - ) + ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")), + ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")), + ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), + ], + ), } elif auth_type == ApiProviderAuthType.NONE: pass else: - raise ValueError(f'invalid auth type {auth_type}') + raise ValueError(f"invalid auth type {auth_type}") - user_name 
= db_provider.user.name if db_provider.user_id else '' + user_name = db_provider.user.name if db_provider.user_id else "" - return ApiToolProviderController(**{ - 'identity': { - 'author': user_name, - 'name': db_provider.name, - 'label': { - 'en_US': db_provider.name, - 'zh_Hans': db_provider.name + return ApiToolProviderController( + **{ + "identity": { + "author": user_name, + "name": db_provider.name, + "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description - }, - 'icon': db_provider.icon, - }, - 'credentials_schema': credentials_schema, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": credentials_schema, + "provider_id": db_provider.id or "", + } + ) @property def provider_type(self) -> ToolProviderType: @@ -104,39 +87,35 @@ class ApiToolProviderController(ToolProviderController): def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ - parse tool bundle to tool + parse tool bundle to tool - :param tool_bundle: the tool bundle - :return: the tool + :param tool_bundle: the tool bundle + :return: the tool """ - return ApiTool(**{ - 'api_bundle': tool_bundle, - 'identity' : { - 'author': tool_bundle.author, - 'name': tool_bundle.operation_id, - 'label': { - 'en_US': tool_bundle.operation_id, - 'zh_Hans': tool_bundle.operation_id + return ApiTool( + **{ + "api_bundle": tool_bundle, + "identity": { + "author": tool_bundle.author, + "name": tool_bundle.operation_id, + "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, + "icon": self.identity.icon, + "provider": self.provider_id, }, - 'icon': self.identity.icon, - 'provider': self.provider_id, - }, - 'description': { - 'human': { - 'en_US': tool_bundle.summary or '', - 'zh_Hans': tool_bundle.summary or '' + "description": { + "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, + "llm": tool_bundle.summary or "", }, - 'llm': tool_bundle.summary or '' - }, - 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], - }) + "parameters": tool_bundle.parameters if tool_bundle.parameters else [], + } + ) def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ - load bundled tools + load bundled tools - :param tools: the bundled tools - :return: the tools + :param tools: the bundled tools + :return: the tools """ self.tools = [self._parse_tool_bundle(tool) for tool in tools] @@ -144,22 +123,23 @@ class ApiToolProviderController(ToolProviderController): def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - + tools: list[Tool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == self.identity.name - ).all() + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) + .all() + ) if db_providers and len(db_providers) != 0: for db_provider in db_providers: @@ -167,16 
+147,16 @@ class ApiToolProviderController(ToolProviderController): assistant_tool = self._parse_tool_bundle(tool) assistant_tool.is_team_authorization = True tools.append(assistant_tool) - + self.tools = tools return tools - + def get_tool(self, tool_name: str) -> ApiTool: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: self.get_tools() @@ -185,4 +165,4 @@ class ApiToolProviderController(ToolProviderController): if tool.identity.name == tool_name: return tool - raise ValueError(f'tool {tool_name} not found') \ No newline at end of file + raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 2d472e0a93..01544d7e56 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,11 +11,12 @@ from models.tools import PublishedAppTool logger = logging.getLogger(__name__) + class AppToolProviderEntity(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.APP - + def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass @@ -23,9 +24,13 @@ class AppToolProviderEntity(ToolProviderController): pass def get_tools(self, user_id: str) -> list[Tool]: - db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter( - PublishedAppTool.user_id == user_id, - ).all() + db_tools: list[PublishedAppTool] = ( + db.session.query(PublishedAppTool) + .filter( + PublishedAppTool.user_id == user_id, + ) + .all() + ) if not db_tools or len(db_tools) == 0: return [] @@ -34,23 +39,17 @@ class AppToolProviderEntity(ToolProviderController): for db_tool in db_tools: tool = { - 'identity': { - 'author': db_tool.author, - 'name': db_tool.tool_name, - 'label': { - 'en_US': db_tool.tool_name, - 'zh_Hans': db_tool.tool_name - }, - 'icon': '' + "identity": { + "author": db_tool.author, + "name": db_tool.tool_name, + "label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name}, + "icon": "", }, - 'description': { - 'human': { - 'en_US': db_tool.description_i18n.en_US, - 'zh_Hans': db_tool.description_i18n.zh_Hans - }, - 'llm': db_tool.llm_description + "description": { + "human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans}, + "llm": db_tool.llm_description, }, - 'parameters': [] + "parameters": [], } # get app from db app: App = db_tool.app @@ -64,52 +63,41 @@ class AppToolProviderEntity(ToolProviderController): for input_form in user_input_form_list: # get type form_type = input_form.keys()[0] - default = input_form[form_type]['default'] - required = input_form[form_type]['required'] - label = input_form[form_type]['label'] - variable_name = input_form[form_type]['variable_name'] - options = input_form[form_type].get('options', []) - if form_type == 'paragraph' or form_type == 'text-input': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.STRING, - required=required, - default=default - )) - elif form_type == 'select': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - 
human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.SELECT, - required=required, - default=default, - options=[ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in options] - )) + default = input_form[form_type]["default"] + required = input_form[form_type]["required"] + label = input_form[form_type]["label"] + variable_name = input_form[form_type]["variable_name"] + options = input_form[form_type].get("options", []) + if form_type == "paragraph" or form_type == "text-input": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.STRING, + required=required, + default=default, + ) + ) + elif form_type == "select": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.SELECT, + required=required, + default=default, + options=[ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ], + ) + ) tools.append(Tool(**tool)) - return tools \ No newline at end of file + return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 062668fc5b..5c10f72fda 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -10,7 +10,7 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), "..")) def name_func(provider: UserToolProvider) -> str: return provider.name diff --git a/api/core/tools/provider/builtin/aippt/aippt.py b/api/core/tools/provider/builtin/aippt/aippt.py index 25133c51df..e0cbbd2992 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.py +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -6,6 +6,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class AIPPTProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__') + AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__") except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 8d6883a3b1..7cee8f9f79 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -20,16 +20,16 @@ class AIPPTGenerateTool(BuiltinTool): A tool for generating a ppt """ - _api_base_url = URL('https://co.aippt.cn/api') + _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock:Optional[Lock] = None + _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - 
_style_cache_lock:Optional[Lock] = None + _style_cache_lock: Optional[Lock] = None _task = {} _task_type_map = { - 'auto': 1, - 'markdown': 7, + "auto": 1, + "markdown": 7, } def __init__(self, **kwargs: Any): @@ -48,65 +48,55 @@ class AIPPTGenerateTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - title = tool_parameters.get('title', '') + title = tool_parameters.get("title", "") if not title: - return self.create_text_message('Please provide a title for the ppt') - - model = tool_parameters.get('model', 'aippt') + return self.create_text_message("Please provide a title for the ppt") + + model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message('Please provide a model for the ppt') - - outline = tool_parameters.get('outline', '') + return self.create_text_message("Please provide a model for the ppt") + + outline = tool_parameters.get("outline", "") # create task task_id = self._create_task( - type=self._task_type_map['auto' if not outline else 'markdown'], + type=self._task_type_map["auto" if not outline else "markdown"], title=title, content=outline, - user_id=user_id + user_id=user_id, ) # get suit - color = tool_parameters.get('color') - style = tool_parameters.get('style') + color = tool_parameters.get("color") + style = tool_parameters.get("style") - if color == '__default__': - color_id = '' + if color == "__default__": + color_id = "" else: - color_id = int(color.split('-')[1]) + color_id = int(color.split("-")[1]) - if style == '__default__': - style_id = '' + if style == "__default__": + style_id = "" else: - style_id = int(style.split('-')[1]) + style_id = int(style.split("-")[1]) suit_id = self._get_suit(style_id=style_id, colour_id=color_id) # generate outline if not outline: - self._generate_outline( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_outline(task_id=task_id, model=model, user_id=user_id) # generate content - self._generate_content( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_content(task_id=task_id, model=model, user_id=user_id) # generate ppt - _, ppt_url = self._generate_ppt( - task_id=task_id, - suit_id=suit_id, - user_id=user_id - ) + _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message('''the ppt has been created successfully,''' - f'''the ppt url is {ppt_url}''' - '''please give the ppt url to user and direct user to download it.''') + return self.create_text_message( + """the ppt has been created successfully,""" + f"""the ppt url is {ppt_url}""" + """please give the ppt url to user and direct user to download it.""" + ) def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: """ @@ -119,129 +109,121 @@ class AIPPTGenerateTool(BuiltinTool): :return: the task ID """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), + str(self._api_base_url / "ai" / "chat" / "v2" / "task"), headers=headers, - files={ - 'type': ('', str(type)), - 'title': ('', title), - 'content': ('', content) - } + 
files={"type": ("", str(type)), "title": ("", title), "content": ("", content)}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to create task: {response.get("msg")}') - return response.get('data', {}).get('id') - + return response.get("data", {}).get("id") + def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "outline" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "outline" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - outline = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + outline = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - outline += data.get('content', '') + outline += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate outline: {data}') - + elif event == "error" or event == "filter": + raise Exception(f"Failed to generate outline: {data}") + return outline - + def _generate_content(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "content" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "content" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if 
response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - if model == 'aippt': - content = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + if model == "aippt": + content = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - content += data.get('content', '') + content += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate content: {data}') - + elif event == "error" or event == "filter": + raise Exception(f"Failed to generate content: {data}") + return content - elif model == 'wenxin': + elif model == "wenxin": response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate content: {response.get("msg")}') - - return response.get('data', '') - - return '' + + return response.get("data", "") + + return "" def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: """ @@ -252,83 +234,73 @@ class AIPPTGenerateTool(BuiltinTool): :return: the cover url of the ppt and the ppt url """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'design' / 'v2' / 'save'), + str(self._api_base_url / "design" / "v2" / "save"), headers=headers, - data={ - 'task_id': task_id, - 'template_id': suit_id - } + data={"task_id": task_id, "template_id": suit_id}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - id = response.get('data', {}).get('id') - cover_url = response.get('data', {}).get('cover_url') + + id = response.get("data", {}).get("id") + cover_url = response.get("data", {}).get("cover_url") response = post( - str(self._api_base_url / 'download' / 'export' / 'file'), + str(self._api_base_url / "download" / "export" / "file"), headers=headers, - data={ - 'id': id, - 'format': 'ppt', - 'files_to_zip': False, - 'edit': True - } + data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - export_code = response.get('data') + + export_code = response.get("data") if not export_code: - raise Exception('Failed to 
generate ppt, the export code is empty') - + raise Exception("Failed to generate ppt, the export code is empty") + current_iteration = 0 while current_iteration < 50: # get ppt url response = post( - str(self._api_base_url / 'download' / 'export' / 'file' / 'result'), + str(self._api_base_url / "download" / "export" / "file" / "result"), headers=headers, - data={ - 'task_key': export_code - } + data={"task_key": export_code}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - if response.get('msg') == '导出中': + + if response.get("msg") == "导出中": current_iteration += 1 sleep(2) continue - - ppt_url = response.get('data', []) + + ppt_url = response.get("data", []) if len(ppt_url) == 0: - raise Exception('Failed to generate ppt, the ppt url is empty') - + raise Exception("Failed to generate ppt, the ppt url is empty") + return cover_url, ppt_url[0] - - raise Exception('Failed to generate ppt, the export is timeout') - + + raise Exception("Failed to generate ppt, the export is timeout") + @classmethod def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: """ @@ -337,53 +309,43 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: the API token """ - access_key = credentials['aippt_access_key'] - secret_key = credentials['aippt_secret_key'] + access_key = credentials["aippt_access_key"] + secret_key = credentials["aippt_secret_key"] - cache_key = f'{access_key}#@#{user_id}' + cache_key = f"{access_key}#@#{user_id}" with cls._api_token_cache_lock: # clear expired tokens now = time() for key in list(cls._api_token_cache.keys()): - if cls._api_token_cache[key]['expire'] < now: + if cls._api_token_cache[key]["expire"] < now: del cls._api_token_cache[key] if cache_key in cls._api_token_cache: - return cls._api_token_cache[cache_key]['token'] - + return cls._api_token_cache[cache_key]["token"] + # get token headers = { - 'x-api-key': access_key, - 'x-timestamp': str(int(now)), - 'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) + "x-api-key": access_key, + "x-timestamp": str(int(now)), + "x-signature": cls._calculate_sign(access_key, secret_key, int(now)), } - param = { - 'uid': user_id, - 'channel': '' - } + param = {"uid": user_id, "channel": ""} - response = get( - str(cls._api_base_url / 'grant' / 'token'), - params=param, - headers=headers - ) + response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') + raise Exception(f"Failed to connect to aippt: {response.text}") response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - token = response.get('data', {}).get('token') - expire = response.get('data', {}).get('time_expire') + + token = response.get("data", {}).get("token") + expire = response.get("data", {}).get("time_expire") with cls._api_token_cache_lock: - cls._api_token_cache[cache_key] = { - 'token': token, - 'expire': now + expire - } + cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire} return token @@ -391,11 +353,9 @@ class AIPPTGenerateTool(BuiltinTool): def _calculate_sign(cls, 
access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode('utf-8'), - msg=f'GET@/api/grant/token/@{timestamp}'.encode(), - digestmod=sha1 + key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 ).digest() - ).decode('utf-8') + ).decode("utf-8") @classmethod def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: @@ -408,47 +368,46 @@ class AIPPTGenerateTool(BuiltinTool): # clear expired styles now = time() for key in list(cls._style_cache.keys()): - if cls._style_cache[key]['expire'] < now: + if cls._style_cache[key]["expire"] < now: del cls._style_cache[key] key = f'{credentials["aippt_access_key"]}#@#{user_id}' if key in cls._style_cache: - return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"] headers = { - 'x-channel': '', - 'x-api-key': credentials['aippt_access_key'], - 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) + "x-channel": "", + "x-api-key": credentials["aippt_access_key"], + "x-token": cls._get_api_token(credentials=credentials, user_id=user_id), } - response = get( - str(cls._api_base_url / 'template_component' / 'suit' / 'select'), - headers=headers - ) + response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - colors = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('name'), - 'en_name': item.get('en_name', item.get('name')), - } for item in response.get('data', {}).get('colour') or []] - styles = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('title'), - } for item in response.get('data', {}).get('suit_style') or []] + + colors = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("name"), + "en_name": item.get("en_name", item.get("name")), + } + for item in response.get("data", {}).get("colour") or [] + ] + styles = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("title"), + } + for item in response.get("data", {}).get("suit_style") or [] + ] with cls._style_cache_lock: - cls._style_cache[key] = { - 'colors': colors, - 'styles': styles, - 'expire': now + 60 * 60 - } + cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60} return colors, styles @@ -459,44 +418,39 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): - raise Exception('Please provide aippt credentials') + if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + raise Exception("Please provide aippt credentials") return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) - + def _get_suit(self, style_id: int, colour_id: int) -> int: """ Get suit """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') + "x-channel": "", + 
"x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), } response = get( - str(self._api_base_url / 'template_component' / 'suit' / 'search'), + str(self._api_base_url / "template_component" / "suit" / "search"), headers=headers, - params={ - 'style_id': style_id, - 'colour_id': colour_id, - 'page': 1, - 'page_size': 1 - } + params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - if len(response.get('data', {}).get('list') or []) > 0: - return response.get('data', {}).get('list')[0].get('id') - - raise Exception('Failed to get suit, the suit does not exist, please check the style and color') - + + if len(response.get("data", {}).get("list") or []) > 0: + return response.get("data", {}).get("list")[0].get("id") + + raise Exception("Failed to get suit, the suit does not exist, please check the style and color") + def get_runtime_parameters(self) -> list[ToolParameter]: """ Get runtime parameters @@ -504,43 +458,40 @@ class AIPPTGenerateTool(BuiltinTool): Override this method to add runtime parameters to the tool. """ try: - colors, styles = self.get_styles(user_id='__dify_system__') + colors, styles = self.get_styles(user_id="__dify_system__") except Exception as e: - colors, styles = [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ], [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ] + colors, styles = ( + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + ) return [ ToolParameter( - name='color', - label=I18nObject(zh_Hans='颜色', en_US='Color'), - human_description=I18nObject(zh_Hans='颜色', en_US='Color'), + name="color", + label=I18nObject(zh_Hans="颜色", en_US="Color"), + human_description=I18nObject(zh_Hans="颜色", en_US="Color"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=colors[0]['id'], + default=colors[0]["id"], options=[ ToolParameterOption( - value=color['id'], - label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) - ) for color in colors - ] + value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"]) + ) + for color in colors + ], ), ToolParameter( - name='style', - label=I18nObject(zh_Hans='风格', en_US='Style'), - human_description=I18nObject(zh_Hans='风格', en_US='Style'), + name="style", + label=I18nObject(zh_Hans="风格", en_US="Style"), + human_description=I18nObject(zh_Hans="风格", en_US="Style"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=styles[0]['id'], + default=styles[0]["id"], options=[ - ToolParameterOption( - value=style['id'], - label=I18nObject(zh_Hans=style['name'], en_US=style['name']) - ) for style in styles - ] + ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"])) + for style in styles + ], ), - ] \ No newline at end of file + ] diff --git a/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg new file 
mode 100644 index 0000000000..785432943b --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg @@ -0,0 +1,7 @@ + [7 lines of SVG icon markup were lost in extraction; the only recoverable text is the icon title 形状结合 ("Shape Combination")] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py new file mode 100644 index 0000000000..a84630e5aa --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AlphaVantageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + QueryStockTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id="", + tool_parameters={ + "code": "AAPL", # Apple Inc. + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml new file mode 100644 index 0000000000..710510cfd8 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml @@ -0,0 +1,31 @@ +identity: + author: zhuhao + name: alphavantage + label: + en_US: AlphaVantage + zh_Hans: AlphaVantage + pt_BR: AlphaVantage + description: + en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. + zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。 + pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
+ icon: icon.svg + tags: + - finance +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: AlphaVantage API key + zh_Hans: AlphaVantage API key + pt_BR: AlphaVantage API key + placeholder: + en_US: Please input your AlphaVantage API key + zh_Hans: 请输入你的 AlphaVantage API key + pt_BR: Please input your AlphaVantage API key + help: + en_US: Get your AlphaVantage API key from AlphaVantage + zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key + pt_BR: Get your AlphaVantage API key from AlphaVantage + url: https://www.alphavantage.co/support/#api-key diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py new file mode 100644 index 0000000000..d06611acd0 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -0,0 +1,48 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query" + + +class QueryStockTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + stock_code = tool_parameters.get("code", "") + if not stock_code: + return self.create_text_message("Please tell me your stock code") + + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): + return self.create_text_message("Alpha Vantage API key is required.") + + params = { + "function": "TIME_SERIES_DAILY", + "symbol": stock_code, + "outputsize": "compact", + "datatype": "json", + "apikey": self.runtime.credentials["api_key"], + } + response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) + response.raise_for_status() + result = self._handle_response(response.json()) + return self.create_json_message(result) + + def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: + result = response.get("Time Series (Daily)", {}) + if not result: + return {} + stock_result = {} + for k, v in result.items(): + stock_result[k] = {} + stock_result[k]["open"] = v.get("1. open") + stock_result[k]["high"] = v.get("2. high") + stock_result[k]["low"] = v.get("3. low") + stock_result[k]["close"] = v.get("4. close") + stock_result[k]["volume"] = v.get("5. volume") + return stock_result diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml new file mode 100644 index 0000000000..d89f34e373 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml @@ -0,0 +1,27 @@ +identity: + name: query_stock + author: zhuhao + label: + en_US: query_stock + zh_Hans: query_stock + pt_BR: query_stock +description: + human: + en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol. 
+ zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。 + pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol + llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol +parameters: + - name: code + type: string + required: true + label: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + human_description: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + llm_description: stock code for query from alphavantage + form: llm diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 707fc69be3..ebb2d1a8c4 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -11,11 +11,10 @@ class ArxivProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index ce28373880..98d82c233e 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -8,6 +8,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool logger = logging.getLogger(__name__) + + class ArxivAPIWrapper(BaseModel): """Wrapper around ArxivAPI. @@ -86,11 +88,13 @@ class ArxivAPIWrapper(BaseModel): class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") - + + class ArxivSearchTool(BuiltinTool): """ A tool for searching articles on Arxiv. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Arxiv search tool with the given user ID and tool parameters. @@ -102,13 +106,13 @@ class ArxivSearchTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. 
""" - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + arxiv = ArxivAPIWrapper() - + response = arxiv.run(query) - + return self.create_text_message(self.summary(user_id=user_id, content=response)) diff --git a/api/core/tools/provider/builtin/aws/aws.py b/api/core/tools/provider/builtin/aws/aws.py index 13ede96015..f81b5dbd27 100644 --- a/api/core/tools/provider/builtin/aws/aws.py +++ b/api/core/tools/provider/builtin/aws/aws.py @@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "sagemaker_endpoint" : "", + "sagemaker_endpoint": "", "query": "misaka mikoto", - "candidate_texts" : "hello$$$hello world", - "topk" : 5, - "aws_region" : "" + "candidate_texts": "hello$$$hello world", + "topk": 5, + "aws_region": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 9c006733bd..d6a65b1708 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -3,6 +3,7 @@ import logging from typing import Any, Union import boto3 +from botocore.exceptions import BotoCoreError from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage @@ -11,40 +12,43 @@ from core.tools.tool.builtin_tool import BuiltinTool logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class GuardrailParameters(BaseModel): guardrail_id: str = Field(..., description="The identifier of the guardrail") guardrail_version: str = Field(..., description="The version of the guardrail") source: str = Field(..., description="The source of the content") text: str = Field(..., description="The text to apply the guardrail to") - aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client") + aws_region: str = Field(..., description="AWS region for the Bedrock client") + class ApplyGuardrailTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the ApplyGuardrail tool """ try: # Validate and parse input parameters params = GuardrailParameters(**tool_parameters) - + # Initialize AWS client - bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region) + bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region) # Apply guardrail response = bedrock_client.apply_guardrail( guardrailIdentifier=params.guardrail_id, guardrailVersion=params.guardrail_version, source=params.source, - content=[{"text": {"text": params.text}}] + content=[{"text": {"text": params.text}}], ) + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") + # Check for empty response if not response: return self.create_text_message(text="Received empty response from AWS Bedrock.") - + # Process the result action = response.get("action", "No action specified") outputs = response.get("outputs", []) @@ -55,9 +59,11 @@ class 
ApplyGuardrailTool(BuiltinTool): formatted_assessments = [] for assessment in assessments: for policy_type, policy_data in assessment.items(): - if isinstance(policy_data, dict) and 'topics' in policy_data: - for topic in policy_data['topics']: - formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}") + if isinstance(policy_data, dict) and "topics" in policy_data: + for topic in policy_data["topics"]: + formatted_assessments.append( + f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}" + ) else: formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}") @@ -65,19 +71,19 @@ class ApplyGuardrailTool(BuiltinTool): result += f"Output: {output}\n " if formatted_assessments: result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n " -# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" + # result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" return self.create_text_message(text=result) - except boto3.exceptions.BotoCoreError as e: - error_message = f'AWS service error: {str(e)}' + except BotoCoreError as e: + error_message = f"AWS service error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except json.JSONDecodeError as e: - error_message = f'JSON parsing error: {str(e)}' + error_message = f"JSON parsing error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except Exception as e: - error_message = f'An unexpected error occurred: {str(e)}' + error_message = f"An unexpected error occurred: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml index 2b7c8abb44..66044e4ea8 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml @@ -54,3 +54,14 @@ parameters: zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。 llm_description: The content used for requesting guardrail review, which can be either user input or LLM output. form: llm + - name: aws_region + type: string + required: true + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'. + zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。 + llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'. 
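With aws_region promoted to a required parameter above, every invocation must name the region for the Bedrock client. As a hedged sketch of the underlying AWS call the tool wraps (the guardrail identifier and version below are placeholders, not values from this PR):

    import boto3

    # Direct form of the ApplyGuardrail request issued by the tool above.
    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    response = client.apply_guardrail(
        guardrailIdentifier="gr-example",  # hypothetical guardrail id
        guardrailVersion="1",              # hypothetical version
        source="INPUT",                    # "INPUT" screens user text, "OUTPUT" model text
        content=[{"text": {"text": "Text to screen."}}],
    )
    # action is "GUARDRAIL_INTERVENED" when a policy fires, otherwise "NONE".
    print(response.get("action"), response.get("outputs", []))
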
+ form: form diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 005ba3deb5..48755753ac 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): - msg = { - "src_content":text_content, - "src_lang": src_lang, - "dest_lang":dest_lang, + msg = { + "src_content": text_content, + "src_lang": src_lang, + "dest_lang": dest_lang, "dictionary_id": dictionary_name, - "request_type" : request_type, - "model_id" : model_id + "request_type": request_type, + "model_id": model_id, } - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("unicode_escape") return response_str - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") line = 1 - text_content = tool_parameters.get('text_content', '') + text_content = tool_parameters.get("text_content", "") if not text_content: - return self.create_text_message('Please input text_content') - + return self.create_text_message("Please input text_content") + line = 2 - src_lang = tool_parameters.get('src_lang', '') + src_lang = tool_parameters.get("src_lang", "") if not src_lang: - return self.create_text_message('Please input src_lang') - + return self.create_text_message("Please input src_lang") + line = 3 - dest_lang = tool_parameters.get('dest_lang', '') + dest_lang = tool_parameters.get("dest_lang", "") if not dest_lang: - return self.create_text_message('Please input dest_lang') - + return self.create_text_message("Please input dest_lang") + line = 4 - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - + return self.create_text_message("Please input lambda_name") + line = 5 - request_type = tool_parameters.get('request_type', '') + request_type = tool_parameters.get("request_type", "") if not request_type: - return self.create_text_message('Please input request_type') - + return self.create_text_message("Please input request_type") + line = 6 - model_id = tool_parameters.get('model_id', '') + model_id = tool_parameters.get("model_id", "") if not model_id: - return self.create_text_message('Please input model_id') + return self.create_text_message("Please input model_id") line = 7 - dictionary_name = tool_parameters.get('dictionary_name', '') + dictionary_name = 
tool_parameters.get("dictionary_name", "") if not dictionary_name: - return self.create_text_message('Please input dictionary_name') - - result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name) + return self.create_text_message("Please input dictionary_name") + + result = self._invoke_lambda( + text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name + ) return self.create_text_message(text=result) except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py new file mode 100644 index 0000000000..f43f3b6fe0 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -0,0 +1,70 @@ +import json +import logging +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +console_handler = logging.StreamHandler() +logger.addHandler(console_handler) + + +class LambdaYamlToJsonTool(BuiltinTool): + lambda_client: Any = None + + def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str: + msg = {"body": yaml_content} + logger.info(json.dumps(msg)) + + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] + + response_str = response_body.read().decode("utf-8") + resp_json = json.loads(response_str) + + logger.info(resp_json) + if resp_json["statusCode"] != 200: + raise Exception(f"Invalid status code: {response_str}") + + return resp_json["body"] + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.lambda_client: + aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region + if aws_region: + self.lambda_client = boto3.client("lambda", region_name=aws_region) + else: + self.lambda_client = boto3.client("lambda") + + yaml_content = tool_parameters.get("yaml_content", "") + if not yaml_content: + return self.create_text_message("Please input yaml_content") + + lambda_name = tool_parameters.get("lambda_name", "") + if not lambda_name: + return self.create_text_message("Please input lambda_name") + logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}") + + result = self._invoke_lambda(lambda_name, yaml_content) + logger.debug(result) + + return self.create_text_message(result) + except Exception as e: + return self.create_text_message(f"Exception: {str(e)}") + + console_handler.flush() diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml new file mode 100644 index 0000000000..919c285348 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml @@ -0,0 +1,53 @@ +identity: + name: lambda_yaml_to_json + author: AWS + label: + en_US: LambdaYamlToJson + zh_Hans: LambdaYamlToJson + pt_BR: LambdaYamlToJson + icon: icon.svg +description: + human: + en_US: A tool to convert yaml to json using AWS Lambda. 
+ zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。 + pt_BR: A tool to convert yaml to json using AWS Lambda. + llm: A tool to convert yaml to json. +parameters: + - name: yaml_content + type: string + required: true + label: + en_US: YAML content to convert for + zh_Hans: YAML 内容 + pt_BR: YAML content to convert for + human_description: + en_US: YAML content to convert for + zh_Hans: YAML 内容 + pt_BR: YAML content to convert for + llm_description: YAML content to convert for + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of lambda + zh_Hans: Lambda 所在的region + pt_BR: region of lambda + human_description: + en_US: region of lambda + zh_Hans: Lambda 所在的region + pt_BR: region of lambda + llm_description: region of lambda + form: form + - name: lambda_name + type: string + required: false + label: + en_US: name of lambda + zh_Hans: Lambda 名称 + pt_BR: name of lambda + human_description: + en_US: name of lambda + zh_Hans: Lambda 名称 + pt_BR: name of lambda + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index d4bc446e5b..3c35b65e66 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -9,37 +9,33 @@ from core.tools.tool.builtin_tool import BuiltinTool class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - topk:int = None + sagemaker_endpoint: str = None + topk: int = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -47,25 +43,25 @@ class SageMakerReRankTool(BuiltinTool): line = 1 if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 if not self.topk: - self.topk = tool_parameters.get('topk', 5) + self.topk = tool_parameters.get("topk", 5) line = 3 - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + line = 4 - candidate_texts = tool_parameters.get('candidate_texts') + candidate_texts = 
tool_parameters.get("candidate_texts") if not candidate_texts: - return self.create_text_message('Please input candidate_texts') - + return self.create_text_message("Please input candidate_texts") + line = 5 candidate_docs = json.loads(candidate_texts) - docs = [ item.get('content') for item in candidate_docs ] + docs = [item.get("content") for item in candidate_docs] line = 6 scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) @@ -75,12 +71,10 @@ class SageMakerReRankTool(BuiltinTool): candidate_docs[idx]["score"] = scores[idx] line = 8 - sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True) line = 9 - results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False) - return self.create_text_message(text=results_str) - + return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') - \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py new file mode 100644 index 0000000000..bceeaab745 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -0,0 +1,101 @@ +import json +from enum import Enum +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class TTSModelType(Enum): + PresetVoice = "PresetVoice" + CloneVoice = "CloneVoice" + CloneVoice_CrossLingual = "CloneVoice_CrossLingual" + InstructVoice = "InstructVoice" + + +class SageMakerTTSTool(BuiltinTool): + sagemaker_client: Any = None + sagemaker_endpoint: str = None + s3_client: Any = None + comprehend_client: Any = None + + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} + + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response["Languages"][0]["LanguageCode"] + return map_dict.get(language_code, "<|zh|>") + + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): + if model_type == TTSModelType.PresetVoice.value and model_role: + return {"tts_text": content_text, "role": model_role} + if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + lang_tag = self._detect_lang_code(content_text) + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} + + raise RuntimeError(f"Invalid params for {model_type}") + + def _invoke_sagemaker(self, payload: dict, endpoint: str): + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + json_str = 
response_model["Body"].read().decode("utf8") + json_obj = json.loads(json_str) + return json_obj + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.sagemaker_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + self.comprehend_client = boto3.client("comprehend") + + if not self.sagemaker_endpoint: + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") + + tts_text = tool_parameters.get("tts_text") + tts_infer_type = tool_parameters.get("tts_infer_type") + + voice = tool_parameters.get("voice") + mock_voice_audio = tool_parameters.get("mock_voice_audio") + mock_voice_text = tool_parameters.get("mock_voice_text") + voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt") + payload = self._build_tts_payload( + tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt + ) + + result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) + + return self.create_text_message(text=result["s3_presign_url"]) + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml new file mode 100644 index 0000000000..a6a61dd4aa --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml @@ -0,0 +1,149 @@ +identity: + name: sagemaker_tts + author: AWS + label: + en_US: SagemakerTTS + zh_Hans: Sagemaker语音合成 + pt_BR: SagemakerTTS + icon: icon.svg +description: + human: + en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本 + pt_BR: A tool for Speech synthesis. + llm: A tool for Speech synthesis. 
You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool +parameters: + - name: sagemaker_endpoint + type: string + required: true + label: + en_US: sagemaker endpoint for tts + zh_Hans: 语音生成的SageMaker端点 + pt_BR: sagemaker endpoint for tts + human_description: + en_US: sagemaker endpoint for tts + zh_Hans: 语音生成的SageMaker端点 + pt_BR: sagemaker endpoint for tts + llm_description: sagemaker endpoint for tts + form: form + - name: tts_text + type: string + required: true + label: + en_US: tts text + zh_Hans: 语音合成原文 + pt_BR: tts text + human_description: + en_US: tts text + zh_Hans: 语音合成原文 + pt_BR: tts text + llm_description: tts text + form: llm + - name: tts_infer_type + type: select + required: false + label: + en_US: tts infer type + zh_Hans: 合成方式 + pt_BR: tts infer type + human_description: + en_US: tts infer type + zh_Hans: 合成方式 + pt_BR: tts infer type + llm_description: tts infer type + options: + - value: PresetVoice + label: + en_US: preset voice + zh_Hans: 预置音色 + - value: CloneVoice + label: + en_US: clone voice + zh_Hans: 克隆音色 + - value: CloneVoice_CrossLingual + label: + en_US: clone crossLingual voice + zh_Hans: 克隆音色(跨语言) + - value: InstructVoice + label: + en_US: instruct voice + zh_Hans: 指令音色 + form: form + - name: voice + type: select + required: false + label: + en_US: preset voice + zh_Hans: 预置音色 + pt_BR: preset voice + human_description: + en_US: preset voice + zh_Hans: 预置音色 + pt_BR: preset voice + llm_description: preset voice + options: + - value: 中文男 + label: + en_US: zh-cn male + zh_Hans: 中文男 + - value: 中文女 + label: + en_US: zh-cn female + zh_Hans: 中文女 + - value: 粤语女 + label: + en_US: zh-TW female + zh_Hans: 粤语女 + form: form + - name: mock_voice_audio + type: string + required: false + label: + en_US: clone voice link + zh_Hans: 克隆音频链接 + pt_BR: clone voice link + human_description: + en_US: clone voice link + zh_Hans: 克隆音频链接 + pt_BR: clone voice link + llm_description: clone voice link + form: llm + - name: mock_voice_text + type: string + required: false + label: + en_US: text of clone voice + zh_Hans: 克隆音频对应文本 + pt_BR: text of clone voice + human_description: + en_US: text of clone voice + zh_Hans: 克隆音频对应文本 + pt_BR: text of clone voice + llm_description: text of clone voice + form: llm + - name: voice_instruct_prompt + type: string + required: false + label: + en_US: instruct prompt for voice + zh_Hans: 音色指令文本 + pt_BR: instruct prompt for voice + human_description: + en_US: instruct prompt for voice + zh_Hans: 音色指令文本 + pt_BR: instruct prompt for voice + llm_description: instruct prompt for voice + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + human_description: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + llm_description: region of sagemaker endpoint + form: form diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 2981a54d3c..1fab0d03a2 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -13,12 +13,8 @@ class AzureDALLEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "square", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute 
girl, blue eyes, white hair, anime style", "size": "square", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 2ffdd38b72..7462824be1 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -9,47 +9,48 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ client = AzureOpenAI( - api_version=self.runtime.credentials['azure_openai_api_version'], - azure_endpoint=self.runtime.credentials['azure_openai_base_url'], - api_key=self.runtime.credentials['azure_openai_api_key'], + api_version=self.runtime.credentials["azure_openai_api_version"], + azure_endpoint=self.runtime.credentials["azure_openai_base_url"], + api_key=self.runtime.credentials["azure_openai_api_key"], ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} # call openapi dalle3 - model = self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials["azure_openai_api_model_name"] response = client.images.generate( prompt=prompt, model=model, @@ -58,21 +59,25 @@ class DallE3Tool(BuiltinTool): extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) - result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + ) + 
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index f85a5ed472..8bed2c556c 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -8,142 +8,135 @@ from core.tools.tool.builtin_tool import BuiltinTool class BingSearchTool(BuiltinTool): - url: str = 'https://api.bing.microsoft.com/v7.0/search' + url: str = "https://api.bing.microsoft.com/v7.0/search" - def _invoke_bing(self, - user_id: str, - server_url: str, - subscription_key: str, query: str, limit: int, - result_type: str, market: str, lang: str, - filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke_bing( + self, + user_id: str, + server_url: str, + subscription_key: str, + query: str, + limit: int, + result_type: str, + market: str, + lang: str, + filters: list[str], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke bing search + invoke bing search """ - market_code = f'{lang}-{market}' - accept_language = f'{lang},{market_code};q=0.9' - headers = { - 'Ocp-Apim-Subscription-Key': subscription_key, - 'Accept-Language': accept_language - } + market_code = f"{lang}-{market}" + accept_language = f"{lang},{market_code};q=0.9" + headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language} query = quote(query) server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - - response = response.json() - search_results = response['webPages']['value'][:limit] if 'webPages' in response else [] - related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else [] - entities = response['entities']['value'] if 'entities' in response else [] - news = response['news']['value'] if 'news' in response else [] - computation = response['computation']['value'] if 'computation' in response else None + raise Exception(f"Error {response.status_code}: {response.text}") - if result_type == 'link': + response = response.json() + search_results = response["webPages"]["value"][:limit] if "webPages" in response else [] + related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else [] + entities = response["entities"]["value"] if "entities" in response else [] + news = response["news"]["value"] if "news" in response else [] + computation = response["computation"]["value"] if "computation" in response else None + + if result_type == "link": results = [] if search_results: for result in search_results: url = f': {result["url"]}' if "url" in result else "" - results.append(self.create_text_message( - text=f'{result["name"]}{url}' - )) - + results.append(self.create_text_message(text=f'{result["name"]}{url}')) if entities: for entity in entities: url = f': {entity["url"]}' if "url" in entity else "" - 
results.append(self.create_text_message( - text=f'{entity.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}')) if news: for news_item in news: url = f': {news_item["url"]}' if "url" in news_item else "" - results.append(self.create_text_message( - text=f'{news_item.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}')) if related_searches: for related in related_searches: url = f': {related["displayText"]}' if "displayText" in related else "" - results.append(self.create_text_message( - text=f'{related.get("displayText", "")}{url}' - )) - + results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}')) + return results else: # construct text - text = '' + text = "" if search_results: for i, result in enumerate(search_results): - text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n' + text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n' - if computation and 'expression' in computation and 'value' in computation: - text += '\nComputation:\n' + if computation and "expression" in computation and "value" in computation: + text += "\nComputation:\n" text += f'{computation["expression"]} = {computation["value"]}\n' if entities: - text += '\nEntities:\n' + text += "\nEntities:\n" for entity in entities: url = f'- {entity["url"]}' if "url" in entity else "" text += f'{entity.get("name", "")}{url}\n' if news: - text += '\nNews:\n' + text += "\nNews:\n" for news_item in news: url = f'- {news_item["url"]}' if "url" in news_item else "" text += f'{news_item.get("name", "")}{url}\n' if related_searches: - text += '\n\nRelated Searches:\n' + text += "\n\nRelated Searches:\n" for related in related_searches: url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else "" text += f'{related.get("displayText", "")}{url}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: - key = credentials.get('subscription_key') + key = credentials.get("subscription_key") if not key: - raise Exception('subscription_key is required') - - server_url = credentials.get('server_url') + raise Exception("subscription_key is required") + + server_url = credentials.get("server_url") if not server_url: server_url = self.url - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' + raise Exception("query is required") - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if credentials.get('allow_entities', False): - filter.append('Entities') + if credentials.get("allow_entities", False): + filter.append("Entities") - if credentials.get('allow_computation', False): - filter.append('Computation') + if credentials.get("allow_computation", False): + filter.append("Computation") - if credentials.get('allow_news', False): - filter.append('News') + if credentials.get("allow_news", False): + filter.append("News") - if 
credentials.get('allow_related_searches', False): - filter.append('RelatedSearches') + if credentials.get("allow_related_searches", False): + filter.append("RelatedSearches") - if credentials.get('allow_web_pages', False): - filter.append('WebPages') + if credentials.get("allow_web_pages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + self._invoke_bing( - user_id='test', + user_id="test", server_url=server_url, subscription_key=key, query=query, @@ -151,50 +144,51 @@ class BingSearchTool(BuiltinTool): result_type=result_type, market=market, lang=lang, - filters=filter + filters=filter, ) - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - key = self.runtime.credentials.get('subscription_key', None) + key = self.runtime.credentials.get("subscription_key", None) if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) + raise Exception("subscription_key is required") + + server_url = self.runtime.credentials.get("server_url", None) if not server_url: server_url = self.url - - query = tool_parameters.get('query') + + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + raise Exception("query is required") + + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') + if tool_parameters.get("enable_computation", False): + filter.append("Computation") + if tool_parameters.get("enable_entities", False): + filter.append("Entities") + if tool_parameters.get("enable_news", False): + filter.append("News") + if tool_parameters.get("enable_related_search", False): + filter.append("RelatedSearches") + if tool_parameters.get("enable_webpages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + return self._invoke_bing( user_id=user_id, server_url=server_url, @@ -204,5 +198,5 @@ class BingSearchTool(BuiltinTool): result_type=result_type, market=market, lang=lang, - filters=filter - ) \ No newline at end of file + filters=filter, + ) diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e5eada80ee..c24ee67334 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -13,11 +13,10 @@ class 
diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py
index e5eada80ee..c24ee67334 100644
--- a/api/core/tools/provider/builtin/brave/brave.py
+++ b/api/core/tools/provider/builtin/brave/brave.py
@@ -13,11 +13,10 @@ class BraveProvider(BuiltinToolProviderController):
                     "credentials": credentials,
                 }
             ).invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
                     "query": "Sachin Tendulkar",
                 },
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py
index 21cbf2c7da..94a4d92844 100644
--- a/api/core/tools/provider/builtin/brave/tools/brave_search.py
+++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py
@@ -37,7 +37,7 @@ class BraveSearchWrapper(BaseModel):
             for item in web_search_results
         ]
         return json.dumps(final_results)
-
+
     def _search_request(self, query: str) -> list[dict]:
         headers = {
             "X-Subscription-Token": self.api_key,
@@ -55,6 +55,7 @@ class BraveSearchWrapper(BaseModel):
         return response.json().get("web", {}).get("results", [])

+
 class BraveSearch(BaseModel):
     """Tool that queries the BraveSearch."""

@@ -67,9 +68,7 @@ class BraveSearch(BaseModel):
     search_wrapper: BraveSearchWrapper

     @classmethod
-    def from_api_key(
-        cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
-    ) -> "BraveSearch":
+    def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch":
         """Create a tool from an api key.

         Args:
@@ -90,6 +89,7 @@ class BraveSearch(BaseModel):
         """Use the tool."""
         return self.search_wrapper.run(query)

+
 class BraveSearchTool(BuiltinTool):
     """
     Tool for performing a search using Brave search engine.
@@ -106,12 +106,12 @@ class BraveSearchTool(BuiltinTool):
         Returns:
             ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
         """
-        query = tool_parameters.get('query', '')
-        count = tool_parameters.get('count', 3)
-        api_key = self.runtime.credentials['brave_search_api_key']
+        query = tool_parameters.get("query", "")
+        count = tool_parameters.get("count", 3)
+        api_key = self.runtime.credentials["brave_search_api_key"]

         if not query:
-            return self.create_text_message('Please input query')
+            return self.create_text_message("Please input query")

         tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})

@@ -121,4 +121,3 @@ class BraveSearchTool(BuiltinTool):
             return self.create_text_message(f"No results found for '{query}' in Tavily")
         else:
             return self.create_text_message(text=results)
-
diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py
index 0865bc700a..8a24d33428 100644
--- a/api/core/tools/provider/builtin/chart/chart.py
+++ b/api/core/tools/provider/builtin/chart/chart.py
@@ -7,16 +7,34 @@ from core.tools.provider.builtin.chart.tools.line import LinearChartTool
 from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

 # use a business theme
-plt.style.use('seaborn-v0_8-darkgrid')
-plt.rcParams['axes.unicode_minus'] = False
+plt.style.use("seaborn-v0_8-darkgrid")
+plt.rcParams["axes.unicode_minus"] = False
+

 def init_fonts():
     fonts = findSystemFonts()
     popular_unicode_fonts = [
-        'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif',
-        'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans',
-        'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono'
+        "Arial Unicode MS",
+        "DejaVu Sans",
+        "DejaVu Sans Mono",
+        "DejaVu Serif",
+        "FreeMono",
+        "FreeSans",
+        "FreeSerif",
+        "Liberation Mono",
+        "Liberation Sans",
+        "Liberation Serif",
+        "Noto Mono",
+        "Noto Sans",
+        "Noto Serif",
+        "Open Sans",
+        "Roboto",
+        "Source Code Pro",
+        "Source Sans Pro",
+        "Source Serif Pro",
+        "Ubuntu",
+        "Ubuntu Mono",
     ]

     supported_fonts = []
@@ -25,21 +43,23 @@ def init_fonts():
         try:
             font = TTFont(font_path)
             # get family name
-            family_name = font['name'].getName(1, 3, 1).toUnicode()
+            family_name = font["name"].getName(1, 3, 1).toUnicode()
             if family_name in popular_unicode_fonts:
                 supported_fonts.append(family_name)
         except:
             pass

-    plt.rcParams['font.family'] = 'sans-serif'
+    plt.rcParams["font.family"] = "sans-serif"
     # sort by order of popular_unicode_fonts
     for font in popular_unicode_fonts:
         if font in supported_fonts:
-            plt.rcParams['font.sans-serif'] = font
+            plt.rcParams["font.sans-serif"] = font
             break
-
+
+
 init_fonts()
+
 class ChartProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict) -> None:
         try:
@@ -48,11 +68,10 @@ class ChartProvider(BuiltinToolProviderController):
                 "credentials": credentials,
             }
         ).invoke(
-            user_id='',
+            user_id="",
             tool_parameters={
                 "data": "1,3,5,7,9,2,4,6,8,10",
             },
         )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py
index 749ec761c6..3a47c0cfc0 100644
--- a/api/core/tools/provider/builtin/chart/tools/bar.py
+++ b/api/core/tools/provider/builtin/chart/tools/bar.py
@@ -8,12 +8,13 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class BarChartTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-            -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        data = tool_parameters.get('data', '')
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        data = tool_parameters.get("data", "")
         if not data:
-            return self.create_text_message('Please input data')
-        data = data.split(';')
+            return self.create_text_message("Please input data")
+        data = data.split(";")

         # if all data is int, convert to int
         if all(i.isdigit() for i in data):
@@ -21,29 +22,27 @@ class BarChartTool(BuiltinTool):
         else:
             data = [float(i) for i in data]

-        axis = tool_parameters.get('x_axis') or None
+        axis = tool_parameters.get("x_axis") or None
         if axis:
-            axis = axis.split(';')
+            axis = axis.split(";")
             if len(axis) != len(data):
                 axis = None

         flg, ax = plt.subplots(figsize=(10, 8))

         if axis:
-            axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
-            ax.set_xticklabels(axis, rotation=45, ha='right')
+            axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
+            ax.set_xticklabels(axis, rotation=45, ha="right")
             ax.bar(axis, data)
         else:
             ax.bar(range(len(data)), data)

         buf = io.BytesIO()
-        flg.savefig(buf, format='png')
+        flg.savefig(buf, format="png")
         buf.seek(0)
         plt.close(flg)

         return [
-            self.create_text_message('the bar chart is saved as an image.'),
-            self.create_blob_message(blob=buf.read(),
-                                     meta={'mime_type': 'image/png'})
+            self.create_text_message("the bar chart is saved as an image."),
+            self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
         ]
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py
index 608bd6623c..39e8caac7e 100644
--- a/api/core/tools/provider/builtin/chart/tools/line.py
+++ b/api/core/tools/provider/builtin/chart/tools/line.py
@@ -8,18 +8,19 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class LinearChartTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        data = tool_parameters.get('data', '')
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        data = tool_parameters.get("data", "")
         if not data:
-            return self.create_text_message('Please input data')
-        data = data.split(';')
+            return self.create_text_message("Please input data")
+        data = data.split(";")

-        axis = tool_parameters.get('x_axis') or None
+        axis = tool_parameters.get("x_axis") or None
         if axis:
-            axis = axis.split(';')
+            axis = axis.split(";")
             if len(axis) != len(data):
                 axis = None

@@ -32,20 +33,18 @@ class LinearChartTool(BuiltinTool):
         flg, ax = plt.subplots(figsize=(10, 8))

         if axis:
-            axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
-            ax.set_xticklabels(axis, rotation=45, ha='right')
+            axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
+            ax.set_xticklabels(axis, rotation=45, ha="right")
             ax.plot(axis, data)
         else:
             ax.plot(data)

         buf = io.BytesIO()
-        flg.savefig(buf, format='png')
+        flg.savefig(buf, format="png")
         buf.seek(0)
         plt.close(flg)

         return [
-            self.create_text_message('the linear chart is saved as an image.'),
-            self.create_blob_message(blob=buf.read(),
-                                     meta={'mime_type': 'image/png'})
+            self.create_text_message("the linear chart is saved as an image."),
+            self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
         ]
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py
index 4c551229e9..2c3b8a733e 100644
--- a/api/core/tools/provider/builtin/chart/tools/pie.py
+++ b/api/core/tools/provider/builtin/chart/tools/pie.py
@@ -8,15 +8,16 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class PieChartTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        data = tool_parameters.get('data', '')
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        data = tool_parameters.get("data", "")
         if not data:
-            return self.create_text_message('Please input data')
-        data = data.split(';')
-        categories = tool_parameters.get('categories') or None
+            return self.create_text_message("Please input data")
+        data = data.split(";")
+        categories = tool_parameters.get("categories") or None

         # if all data is int, convert to int
         if all(i.isdigit() for i in data):
@@ -27,7 +28,7 @@ class PieChartTool(BuiltinTool):
         flg, ax = plt.subplots()

         if categories:
-            categories = categories.split(';')
+            categories = categories.split(";")
             if len(categories) != len(data):
                 categories = None

@@ -37,12 +38,11 @@ class PieChartTool(BuiltinTool):
             ax.pie(data)

         buf = io.BytesIO()
-        flg.savefig(buf, format='png')
+        flg.savefig(buf, format="png")
         buf.seek(0)
         plt.close(flg)

         return [
-            self.create_text_message('the pie chart is saved as an image.'),
-            self.create_blob_message(blob=buf.read(),
-                                     meta={'mime_type': 'image/png'})
-        ]
\ No newline at end of file
+            self.create_text_message("the pie chart is saved as an image."),
+            self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
+        ]
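All three chart tools share the same render-to-buffer pattern; reduced to its core it looks like the sketch below (names are illustrative). Rendering into BytesIO avoids temp files, and closing the figure keeps pyplot's global state from leaking across invocations:

import io
import matplotlib.pyplot as plt

def render_png(data: list[float]) -> bytes:
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.bar(range(len(data)), data)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")  # write the PNG into memory
    buf.seek(0)                     # rewind before reading
    plt.close(fig)                  # release the figure explicitly
    return buf.read()
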
diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py
index 37645bf0d0..017fe548f7 100644
--- a/api/core/tools/provider/builtin/code/tools/simple_code.py
+++ b/api/core/tools/provider/builtin/code/tools/simple_code.py
@@ -8,15 +8,15 @@ from core.tools.tool.builtin_tool import BuiltinTool
 class SimpleCode(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
         """
-            invoke simple code
+        invoke simple code
         """
-        language = tool_parameters.get('language', CodeLanguage.PYTHON3)
-        code = tool_parameters.get('code', '')
+        language = tool_parameters.get("language", CodeLanguage.PYTHON3)
+        code = tool_parameters.get("code", "")

         if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
-            raise ValueError(f'Only python3 and javascript are supported, not {language}')
-
-        result = CodeExecutor.execute_code(language, '', code)
+            raise ValueError(f"Only python3 and javascript are supported, not {language}")

-        return self.create_text_message(result)
\ No newline at end of file
+        result = CodeExecutor.execute_code(language, "", code)
+
+        return self.create_text_message(result)
diff --git a/api/core/tools/provider/builtin/cogview/cogview.py b/api/core/tools/provider/builtin/cogview/cogview.py
index 801817ec06..6941ce8649 100644
--- a/api/core/tools/provider/builtin/cogview/cogview.py
+++ b/api/core/tools/provider/builtin/cogview/cogview.py
@@ -1,4 +1,5 @@
-""" Provide the input parameters type for the cogview provider class """
+"""Provide the input parameters type for the cogview provider class"""
+
 from typing import Any

 from core.tools.errors import ToolProviderCredentialValidationError
@@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl


 class COGVIEWProvider(BuiltinToolProviderController):
-    """ cogview provider """
+    """cogview provider"""
+
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
             CogView3Tool().fork_tool_runtime(
@@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController):
                 "credentials": credentials,
             }
         ).invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
                     "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
                     "size": "square",
-                    "n": 1
+                    "n": 1,
                 },
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e)) from e
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py
index 89ffcf3347..9776bd7dd1 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py
@@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class CogView3Tool(BuiltinTool):
-    """ CogView3 Tool """
+    """CogView3 Tool"""

-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any]
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
         Invoke CogView3 tool
         """
         client = ZhipuAI(
-            base_url=self.runtime.credentials['zhipuai_base_url'],
-            api_key=self.runtime.credentials['zhipuai_api_key'],
+            base_url=self.runtime.credentials["zhipuai_base_url"],
+            api_key=self.runtime.credentials["zhipuai_api_key"],
         )
         size_mapping = {
-            'square': '1024x1024',
-            'vertical': '1024x1792',
-            'horizontal': '1792x1024',
+            "square": "1024x1024",
+            "vertical": "1024x1792",
+            "horizontal": "1792x1024",
         }
         # prompt
-        prompt = tool_parameters.get('prompt', '')
+        prompt = tool_parameters.get("prompt", "")
         if not prompt:
-            return self.create_text_message('Please input prompt')
+            return self.create_text_message("Please input prompt")
         # get size
-        size = size_mapping[tool_parameters.get('size', 'square')]
+        size = size_mapping[tool_parameters.get("size", "square")]
         # get n
-        n = tool_parameters.get('n', 1)
+        n = tool_parameters.get("n", 1)
         # get quality
-        quality = tool_parameters.get('quality', 'standard')
-        if quality not in ['standard', 'hd']:
-            return self.create_text_message('Invalid quality')
+        quality = tool_parameters.get("quality", "standard")
+        if quality not in ["standard", "hd"]:
+            return self.create_text_message("Invalid quality")
         # get style
-        style = tool_parameters.get('style', 'vivid')
-        if style not in ['natural', 'vivid']:
-            return self.create_text_message('Invalid style')
+        style = tool_parameters.get("style", "vivid")
+        if style not in ["natural", "vivid"]:
+            return self.create_text_message("Invalid style")
         # set extra body
-        seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
-        extra_body = {'seed': seed_id}
+        seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
+        extra_body = {"seed": seed_id}
         response = client.images.generations(
             prompt=prompt,
             model="cogview-3",
@@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool):
             extra_body=extra_body,
             style=style,
             quality=quality,
-            response_format='b64_json'
+            response_format="b64_json",
         )
         result = []
         for image in response.data:
             result.append(self.create_image_message(image=image.url))
-            result.append(self.create_json_message({
-                "url": image.url,
-            }))
+            result.append(
+                self.create_json_message(
+                    {
+                        "url": image.url,
+                    }
+                )
+            )
         return result

     @staticmethod
     def _generate_random_id(length=8):
-        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
-        random_id = ''.join(random.choices(characters, k=length))
+        characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+        random_id = "".join(random.choices(characters, k=length))
         return random_id
diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py
index 404e483e0d..8ba3c1b48a 100644
--- a/api/core/tools/provider/builtin/crossref/crossref.py
+++ b/api/core/tools/provider/builtin/crossref/crossref.py
@@ -11,9 +11,9 @@ class CrossRefProvider(BuiltinToolProviderController):
                 "credentials": credentials,
             }
         ).invoke(
-            user_id='',
+            user_id="",
             tool_parameters={
-                "doi": '10.1007/s00894-022-05373-8',
+                "doi": "10.1007/s00894-022-05373-8",
             },
         )
         except Exception as e:
diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py
index a43c0989e4..746139dd69 100644
--- a/api/core/tools/provider/builtin/crossref/tools/query_doi.py
+++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py
@@ -11,15 +11,18 @@ class CrossRefQueryDOITool(BuiltinTool):
     """
     Tool for querying the metadata of a publication using its DOI.
     """
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        doi = tool_parameters.get('doi')
+
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        doi = tool_parameters.get("doi")
         if not doi:
-            raise ToolParameterValidationError('doi is required.')
+            raise ToolParameterValidationError("doi is required.")
         # doc: https://github.com/CrossRef/rest-api-doc
         url = f"https://api.crossref.org/works/{doi}"
         response = requests.get(url)
         response.raise_for_status()
         response = response.json()
-        message = response.get('message', {})
+        message = response.get("message", {})
         return self.create_json_message(message)
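The query_doi tool above boils down to a single GET against the Crossref REST API. A bare sketch for experimenting with the same endpoint outside Dify (the function name is illustrative):

import requests

def crossref_metadata(doi: str) -> dict:
    response = requests.get(f"https://api.crossref.org/works/{doi}", timeout=10)
    response.raise_for_status()
    # Crossref wraps the record in a "message" envelope
    return response.json().get("message", {})

# e.g. crossref_metadata("10.1007/s00894-022-05373-8")["title"]
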
diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py
index 946aa6dc94..e245238183 100644
--- a/api/core/tools/provider/builtin/crossref/tools/query_title.py
+++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py
@@ -12,16 +12,16 @@ def convert_time_str_to_seconds(time_str: str) -> int:
     Convert a time string to seconds.
     example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
     """
-    time_str = time_str.lower().strip().replace(' ', '')
+    time_str = time_str.lower().strip().replace(" ", "")
     seconds = 0
-    if 'h' in time_str:
-        hours, time_str = time_str.split('h')
+    if "h" in time_str:
+        hours, time_str = time_str.split("h")
         seconds += int(hours) * 3600
-    if 'm' in time_str:
-        minutes, time_str = time_str.split('m')
+    if "m" in time_str:
+        minutes, time_str = time_str.split("m")
         seconds += int(minutes) * 60
-    if 's' in time_str:
-        seconds += int(time_str.replace('s', ''))
+    if "s" in time_str:
+        seconds += int(time_str.replace("s", ""))
     return seconds

@@ -30,6 +30,7 @@ class CrossRefQueryTitleAPI:
     Tool for querying the metadata of a publication using its title.
     Crossref API doc: https://github.com/CrossRef/rest-api-doc
     """
+
     query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
     rate_limit: int = 50
     rate_interval: float = 1
@@ -38,7 +39,15 @@ class CrossRefQueryTitleAPI:
     def __init__(self, mailto: str):
         self.mailto = mailto

-    def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
+    def _query(
+        self,
+        query: str,
+        rows: int = 5,
+        offset: int = 0,
+        sort: str = "relevance",
+        order: str = "desc",
+        fuzzy_query: bool = False,
+    ) -> list[dict]:
         """
         Query the metadata of a publication using its title.
         :param query: the title of the publication
@@ -47,33 +56,37 @@ class CrossRefQueryTitleAPI:
         :param order: the sort order
         :param fuzzy_query: whether to return all items that match the query
         """
-        url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
+        url = self.query_url_template.format(
+            query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto
+        )
         response = requests.get(url)
         response.raise_for_status()
-        rate_limit = int(response.headers['x-ratelimit-limit'])
+        rate_limit = int(response.headers["x-ratelimit-limit"])
         # convert time string to seconds
-        rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
+        rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"])

         self.rate_limit = rate_limit
         self.rate_interval = rate_interval

         response = response.json()
-        if response['status'] != 'ok':
+        if response["status"] != "ok":
             return []

-        message = response['message']
+        message = response["message"]
         if fuzzy_query:
             # fuzzy query return all items
-            return message['items']
+            return message["items"]
         else:
-            for paper in message['items']:
-                title = paper['title'][0]
+            for paper in message["items"]:
+                title = paper["title"][0]
                 if title.lower() != query.lower():
                     continue
                 return [paper]
         return []

-    def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
+    def query(
+        self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False
+    ) -> list[dict]:
         """
         Query the metadata of a publication using its title.
         :param query: the title of the publication
@@ -89,7 +102,14 @@ class CrossRefQueryTitleAPI:

         results = []
         for i in range(query_times):
-            result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
+            result = self._query(
+                query,
+                rows=self.rate_limit,
+                offset=i * self.rate_limit,
+                sort=sort,
+                order=order,
+                fuzzy_query=fuzzy_query,
+            )
             if fuzzy_query:
                 results.extend(result)
             else:
@@ -107,13 +127,16 @@ class CrossRefQueryTitleTool(BuiltinTool):
     """
     Tool for querying the metadata of a publication using its title.
     """
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        query = tool_parameters.get('query')
-        fuzzy_query = tool_parameters.get('fuzzy_query', False)
-        rows = tool_parameters.get('rows', 3)
-        sort = tool_parameters.get('sort', 'relevance')
-        order = tool_parameters.get('order', 'desc')
-        mailto = self.runtime.credentials['mailto']
+
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        query = tool_parameters.get("query")
+        fuzzy_query = tool_parameters.get("fuzzy_query", False)
+        rows = tool_parameters.get("rows", 3)
+        sort = tool_parameters.get("sort", "relevance")
+        order = tool_parameters.get("order", "desc")
+        mailto = self.runtime.credentials["mailto"]

         result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)
diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py
index 1c8019364d..5bd16e49e8 100644
--- a/api/core/tools/provider/builtin/dalle/dalle.py
+++ b/api/core/tools/provider/builtin/dalle/dalle.py
@@ -13,13 +13,8 @@ class DALLEProvider(BuiltinToolProviderController):
                 "credentials": credentials,
             }
         ).invoke(
-            user_id='',
-            tool_parameters={
-                "prompt": "cute girl, blue eyes, white hair, anime style",
-                "size": "small",
-                "n": 1
-            },
+            user_id="",
+            tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1},
         )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/dalle/dalle.yaml b/api/core/tools/provider/builtin/dalle/dalle.yaml
index f09a9177f2..37cf93c28a 100644
--- a/api/core/tools/provider/builtin/dalle/dalle.yaml
+++ b/api/core/tools/provider/builtin/dalle/dalle.yaml
@@ -29,7 +29,7 @@ credentials_for_provider:
       en_US: Please input your OpenAI API key
      zh_Hans: 请输入你的 OpenAI API key
       pt_BR: Please input your OpenAI API key
-  openai_organizaion_id:
+  openai_organization_id:
    type: text-input
    required: false
    label:
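The docstring of convert_time_str_to_seconds above pins down its contract; the checks below restate it as executable assertions, assuming the function can be imported from the query_title module (the import path here is a guess based on the file layout):

from core.tools.provider.builtin.crossref.tools.query_title import convert_time_str_to_seconds

assert convert_time_str_to_seconds("1s") == 1
assert convert_time_str_to_seconds("1m30s") == 90
assert convert_time_str_to_seconds("1h30m") == 5400
assert convert_time_str_to_seconds("1h30m30s") == 5430
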
diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py
index 450e782281..fbd7397292 100644
--- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py
+++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py
@@ -9,59 +9,58 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class DallE2Tool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
+        openai_organization = self.runtime.credentials.get("openai_organization_id", None)
         if not openai_organization:
             openai_organization = None
-        openai_base_url = self.runtime.credentials.get('openai_base_url', None)
+        openai_base_url = self.runtime.credentials.get("openai_base_url", None)
         if not openai_base_url:
             openai_base_url = None
         else:
-            openai_base_url = str(URL(openai_base_url) / 'v1')
+            openai_base_url = str(URL(openai_base_url) / "v1")

         client = OpenAI(
-            api_key=self.runtime.credentials['openai_api_key'],
+            api_key=self.runtime.credentials["openai_api_key"],
             base_url=openai_base_url,
-            organization=openai_organization
+            organization=openai_organization,
         )

         SIZE_MAPPING = {
-            'small': '256x256',
-            'medium': '512x512',
-            'large': '1024x1024',
+            "small": "256x256",
+            "medium": "512x512",
+            "large": "1024x1024",
         }

         # prompt
-        prompt = tool_parameters.get('prompt', '')
+        prompt = tool_parameters.get("prompt", "")
         if not prompt:
-            return self.create_text_message('Please input prompt')
-
+            return self.create_text_message("Please input prompt")
+
         # get size
-        size = SIZE_MAPPING[tool_parameters.get('size', 'large')]
+        size = SIZE_MAPPING[tool_parameters.get("size", "large")]
         # get n
-        n = tool_parameters.get('n', 1)
+        n = tool_parameters.get("n", 1)

         # call openapi dalle2
-        response = client.images.generate(
-            prompt=prompt,
-            model='dall-e-2',
-            size=size,
-            n=n,
-            response_format='b64_json'
-        )
+        response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json")

         result = []

         for image in response.data:
-            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
-                                                   meta={ 'mime_type': 'image/png' },
-                                                   save_as=self.VARIABLE_KEY.IMAGE.value))
+            result.append(
+                self.create_blob_message(
+                    blob=b64decode(image.b64_json),
+                    meta={"mime_type": "image/png"},
+                    save_as=self.VariableKey.IMAGE.value,
+                )
+            )

         return result
diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py
index f985deade5..bcfa2212b6 100644
--- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py
+++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py
@@ -10,69 +10,64 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class DallE3Tool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
+        openai_organization = self.runtime.credentials.get("openai_organization_id", None)
         if not openai_organization:
             openai_organization = None
-        openai_base_url = self.runtime.credentials.get('openai_base_url', None)
+        openai_base_url = self.runtime.credentials.get("openai_base_url", None)
         if not openai_base_url:
             openai_base_url = None
         else:
-            openai_base_url = str(URL(openai_base_url) / 'v1')
+            openai_base_url = str(URL(openai_base_url) / "v1")

         client = OpenAI(
-            api_key=self.runtime.credentials['openai_api_key'],
+            api_key=self.runtime.credentials["openai_api_key"],
             base_url=openai_base_url,
-            organization=openai_organization
+            organization=openai_organization,
         )

         SIZE_MAPPING = {
-            'square': '1024x1024',
-            'vertical': '1024x1792',
-            'horizontal': '1792x1024',
+            "square": "1024x1024",
+            "vertical": "1024x1792",
+            "horizontal": "1792x1024",
         }

         # prompt
-        prompt = tool_parameters.get('prompt', '')
+        prompt = tool_parameters.get("prompt", "")
         if not prompt:
-            return self.create_text_message('Please input prompt')
+            return self.create_text_message("Please input prompt")

         # get size
-        size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
+        size = SIZE_MAPPING[tool_parameters.get("size", "square")]
         # get n
-        n = tool_parameters.get('n', 1)
+        n = tool_parameters.get("n", 1)
         # get quality
-        quality = tool_parameters.get('quality', 'standard')
-        if quality not in ['standard', 'hd']:
-            return self.create_text_message('Invalid quality')
+        quality = tool_parameters.get("quality", "standard")
+        if quality not in ["standard", "hd"]:
+            return self.create_text_message("Invalid quality")
         # get style
-        style = tool_parameters.get('style', 'vivid')
-        if style not in ['natural', 'vivid']:
-            return self.create_text_message('Invalid style')
+        style = tool_parameters.get("style", "vivid")
+        if style not in ["natural", "vivid"]:
+            return self.create_text_message("Invalid style")

         # call openapi dalle3
         response = client.images.generate(
-            prompt=prompt,
-            model='dall-e-3',
-            size=size,
-            n=n,
-            style=style,
-            quality=quality,
-            response_format='b64_json'
+            prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json"
         )

         result = []

         for image in response.data:
             mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
-            blob_message = self.create_blob_message(blob=blob_image,
-                                                    meta={'mime_type': mime_type},
-                                                    save_as=self.VARIABLE_KEY.IMAGE.value)
+            blob_message = self.create_blob_message(
+                blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
+            )
             result.append(blob_message)
         return result

@@ -86,7 +81,7 @@ class DallE3Tool(BuiltinTool):
         :return: A tuple containing the MIME type and the decoded image bytes
         """
         if DallE3Tool._is_plain_base64(base64_image):
-            return 'image/png', base64.b64decode(base64_image)
+            return "image/png", base64.b64decode(base64_image)
         else:
             return DallE3Tool._extract_mime_and_data(base64_image)

@@ -98,7 +93,7 @@ class DallE3Tool(BuiltinTool):
         :param encoded_str: Base64 encoded image string
         :return: True if the string is plain base64, False otherwise
         """
-        return not encoded_str.startswith('data:image')
+        return not encoded_str.startswith("data:image")

     @staticmethod
     def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
@@ -108,13 +103,13 @@ class DallE3Tool(BuiltinTool):
         :param encoded_str: Base64 encoded image string with MIME type prefix
         :return: A tuple containing the MIME type and the decoded image bytes
         """
-        mime_type = encoded_str.split(';')[0].split(':')[1]
-        image_data_base64 = encoded_str.split(',')[1]
+        mime_type = encoded_str.split(";")[0].split(":")[1]
+        image_data_base64 = encoded_str.split(",")[1]

         decoded_data = base64.b64decode(image_data_base64)
         return mime_type, decoded_data

     @staticmethod
     def _generate_random_id(length=8):
-        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
-        random_id = ''.join(random.choices(characters, k=length))
+        characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+        random_id = "".join(random.choices(characters, k=length))
        return random_id
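_decode_image above handles two shapes of b64_json payload: a bare base64 string and a data: URI with a MIME prefix. The same branching, stripped of the class machinery (function name is illustrative):

import base64

def decode_image(encoded: str) -> tuple[str, bytes]:
    if not encoded.startswith("data:image"):
        # plain base64: assume PNG, as the tool does
        return "image/png", base64.b64decode(encoded)
    # data URI: "data:image/png;base64,...."
    header, payload = encoded.split(",", 1)
    mime_type = header.split(";")[0].split(":")[1]
    return mime_type, base64.b64decode(payload)
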
"python~3.12", "topic": "library/code", @@ -19,4 +19,3 @@ class DevDocsProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py index 1a244c5db3..e1effd066c 100644 --- a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py +++ b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py @@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel): class SearchDevDocsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the DevDocs search tool with the given user ID and tool parameters. @@ -24,13 +26,13 @@ class SearchDevDocsTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - doc = tool_parameters.get('doc', '') - topic = tool_parameters.get('topic', '') + doc = tool_parameters.get("doc", "") + topic = tool_parameters.get("topic", "") if not doc: - return self.create_text_message('Please provide the documentation name.') + return self.create_text_message("Please provide the documentation name.") if not topic: - return self.create_text_message('Please provide the topic path.') + return self.create_text_message("Please provide the topic path.") url = f"https://documents.devdocs.io/{doc}/{topic}.html" response = requests.get(url) @@ -39,4 +41,6 @@ class SearchDevDocsTool(BuiltinTool): content = response.text return self.create_text_message(self.summary(user_id=user_id, content=content)) else: - return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}") \ No newline at end of file + return self.create_text_message( + f"Failed to retrieve the documentation. 
Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/did/did.py b/api/core/tools/provider/builtin/did/did.py index b4bf172131..5af78794f6 100644 --- a/api/core/tools/provider/builtin/did/did.py +++ b/api/core/tools/provider/builtin/did/did.py @@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the D-ID talks tool - TalksTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png", "text_input": "Hello, welcome to use D-ID tool in Dify", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py index 964e82b729..c68878630d 100644 --- a/api/core/tools/provider/builtin/did/did_appx.py +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -12,14 +12,14 @@ logger = logging.getLogger(__name__) class DIDApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.d-id.com' + self.base_url = base_url or "https://api.d-id.com" if not self.api_key: - raise ValueError('API key is required') + raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'} + headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( @@ -44,44 +44,44 @@ class DIDApp: return None def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/talks' + endpoint = f"{self.base_url}/talks" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval) + return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval) return id def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/animations' + endpoint = f"{self.base_url}/animations" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple 
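The searchDevDocs tool above fetches pre-rendered pages from the DevDocs document store using a two-parameter URL scheme. The same request can be exercised directly (helper name is illustrative):

import requests

def fetch_devdocs(doc: str, topic: str) -> str | None:
    response = requests.get(f"https://documents.devdocs.io/{doc}/{topic}.html", timeout=10)
    # non-200 means the doc/topic pair does not exist in the store
    return response.text if response.status_code == 200 else None

# e.g. fetch_devdocs("python~3.12", "library/code")
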
diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py
index 964e82b729..c68878630d 100644
--- a/api/core/tools/provider/builtin/did/did_appx.py
+++ b/api/core/tools/provider/builtin/did/did_appx.py
@@ -12,14 +12,14 @@ logger = logging.getLogger(__name__)
 class DIDApp:
     def __init__(self, api_key: str | None = None, base_url: str | None = None):
         self.api_key = api_key
-        self.base_url = base_url or 'https://api.d-id.com'
+        self.base_url = base_url or "https://api.d-id.com"
         if not self.api_key:
-            raise ValueError('API key is required')
+            raise ValueError("API key is required")

     def _prepare_headers(self, idempotency_key: str | None = None):
-        headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'}
+        headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"}
         if idempotency_key:
-            headers['Idempotency-Key'] = idempotency_key
+            headers["Idempotency-Key"] = idempotency_key
         return headers

     def _request(
@@ -44,44 +44,44 @@ class DIDApp:
         return None

     def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
-        endpoint = f'{self.base_url}/talks'
+        endpoint = f"{self.base_url}/talks"
         headers = self._prepare_headers(idempotency_key)
-        data = kwargs['params']
-        logger.debug(f'Send request to {endpoint=} body={data}')
-        response = self._request('POST', endpoint, data, headers)
+        data = kwargs["params"]
+        logger.debug(f"Send request to {endpoint=} body={data}")
+        response = self._request("POST", endpoint, data, headers)
         if response is None:
-            raise HTTPError('Failed to initiate D-ID talks after multiple retries')
-        id: str = response['id']
+            raise HTTPError("Failed to initiate D-ID talks after multiple retries")
+        id: str = response["id"]
         if wait:
-            return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval)
+            return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval)
         return id

     def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
-        endpoint = f'{self.base_url}/animations'
+        endpoint = f"{self.base_url}/animations"
         headers = self._prepare_headers(idempotency_key)
-        data = kwargs['params']
-        logger.debug(f'Send request to {endpoint=} body={data}')
-        response = self._request('POST', endpoint, data, headers)
+        data = kwargs["params"]
+        logger.debug(f"Send request to {endpoint=} body={data}")
+        response = self._request("POST", endpoint, data, headers)
         if response is None:
-            raise HTTPError('Failed to initiate D-ID talks after multiple retries')
-        id: str = response['id']
+            raise HTTPError("Failed to initiate D-ID talks after multiple retries")
+        id: str = response["id"]
         if wait:
-            return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval)
+            return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval)
         return id

     def check_did_status(self, target: str, id: str):
-        endpoint = f'{self.base_url}/{target}/{id}'
+        endpoint = f"{self.base_url}/{target}/{id}"
         headers = self._prepare_headers()
-        response = self._request('GET', endpoint, headers=headers)
+        response = self._request("GET", endpoint, headers=headers)
         if response is None:
-            raise HTTPError(f'Failed to check status for talks {id} after multiple retries')
+            raise HTTPError(f"Failed to check status for talks {id} after multiple retries")
         return response

     def _monitor_job_status(self, target: str, id: str, poll_interval: int):
         while True:
             status = self.check_did_status(target=target, id=id)
-            if status['status'] == 'done':
+            if status["status"] == "done":
                 return status
-            elif status['status'] == 'error' or status['status'] == 'rejected':
-                raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
+            elif status["status"] == "error" or status["status"] == "rejected":
+                raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
             time.sleep(poll_interval)
diff --git a/api/core/tools/provider/builtin/did/tools/animations.py b/api/core/tools/provider/builtin/did/tools/animations.py
index e1d9de603f..bc9d17e40d 100644
--- a/api/core/tools/provider/builtin/did/tools/animations.py
+++ b/api/core/tools/provider/builtin/did/tools/animations.py
@@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool):
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
+        app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])

-        driver_expressions_str = tool_parameters.get('driver_expressions')
+        driver_expressions_str = tool_parameters.get("driver_expressions")
         driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None

         config = {
-            'stitch': tool_parameters.get('stitch', True),
-            'mute': tool_parameters.get('mute'),
-            'result_format': tool_parameters.get('result_format') or 'mp4',
+            "stitch": tool_parameters.get("stitch", True),
+            "mute": tool_parameters.get("mute"),
+            "result_format": tool_parameters.get("result_format") or "mp4",
         }
-        config = {k: v for k, v in config.items() if v is not None and v != ''}
+        config = {k: v for k, v in config.items() if v is not None and v != ""}

         options = {
-            'source_url': tool_parameters['source_url'],
-            'driver_url': tool_parameters.get('driver_url'),
-            'config': config,
+            "source_url": tool_parameters["source_url"],
+            "driver_url": tool_parameters.get("driver_url"),
+            "config": config,
         }
-        options = {k: v for k, v in options.items() if v is not None and v != ''}
+        options = {k: v for k, v in options.items() if v is not None and v != ""}

-        if not options.get('source_url'):
-            raise ValueError('Source URL is required')
+        if not options.get("source_url"):
+            raise ValueError("Source URL is required")

-        if config.get('logo_url'):
-            if not config.get('logo_x'):
-                raise ValueError('Logo X position is required when logo URL is provided')
-            if not config.get('logo_y'):
-                raise ValueError('Logo Y position is required when logo URL is provided')
+        if config.get("logo_url"):
+            if not config.get("logo_x"):
+                raise ValueError("Logo X position is required when logo URL is provided")
+            if not config.get("logo_y"):
+                raise ValueError("Logo Y position is required when logo URL is provided")

         animations_result = app.animations(params=options, wait=True)

@@ -44,6 +44,6 @@ class AnimationsTool(BuiltinTool):
             animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4)

         if not animations_result:
-            return self.create_text_message('D-ID animations request failed.')
+            return self.create_text_message("D-ID animations request failed.")

         return self.create_text_message(animations_result)
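The D-ID client above submits a job and then polls until a terminal state. The loop in _monitor_job_status reduces to this shape, where check_status stands in for DIDApp.check_did_status (names are illustrative):

import time
from requests.exceptions import HTTPError

def wait_until_done(check_status, poll_interval: int = 5) -> dict:
    while True:
        status = check_status()
        if status["status"] == "done":
            return status
        if status["status"] in ("error", "rejected"):
            # terminal failure: surface it instead of polling forever
            raise HTTPError(f"job failed: {status['status']}")
        time.sleep(poll_interval)
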
URL is required") - if script.get('type') == 'audio': - script.pop('input', None) - if not script.get('audio_url'): - raise ValueError('Audio URL is required for audio script type') + if script.get("type") == "audio": + script.pop("input", None) + if not script.get("audio_url"): + raise ValueError("Audio URL is required for audio script type") - if script.get('type') == 'text': - script.pop('audio_url', None) - script.pop('reduce_noise', None) - if not script.get('input'): - raise ValueError('Text input is required for text script type') + if script.get("type") == "text": + script.pop("audio_url", None) + script.pop("reduce_noise", None) + if not script.get("input"): + raise ValueError("Text input is required for text script type") talks_result = app.talks(params=options, wait=True) @@ -60,6 +60,6 @@ class TalksTool(BuiltinTool): talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4) if not talks_result: - return self.create_text_message('D-ID talks request failed.') + return self.create_text_message("D-ID talks request failed.") return self.create_text_message(talks_result) diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py index c247c3bd6b..f33ad5be59 100644 --- a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py @@ -13,38 +13,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class DingTalkGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - Dingtalk custom group robot API docs: - https://open.dingtalk.com/document/orgapp/custom-robot-access + invoke tools + Dingtalk custom group robot API docs: + https://open.dingtalk.com/document/orgapp/custom-robot-access """ - content = tool_parameters.get('content') + content = tool_parameters.get("content") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - access_token = tool_parameters.get('access_token') + access_token = tool_parameters.get("access_token") if not access_token: - return self.create_text_message('Invalid parameter access_token. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter access_token. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - sign_secret = tool_parameters.get('sign_secret') + sign_secret = tool_parameters.get("sign_secret") if not sign_secret: - return self.create_text_message('Invalid parameter sign_secret. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter sign_secret. 
" + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - msgtype = 'text' - api_url = 'https://oapi.dingtalk.com/robot/send' + msgtype = "text" + api_url = "https://oapi.dingtalk.com/robot/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'access_token': access_token, + "access_token": access_token, } self._apply_security_mechanism(params, sign_secret) @@ -53,7 +58,7 @@ class DingTalkGroupBotTool(BuiltinTool): "msgtype": msgtype, "text": { "content": content, - } + }, } try: @@ -62,7 +67,8 @@ class DingTalkGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) @@ -70,14 +76,14 @@ class DingTalkGroupBotTool(BuiltinTool): def _apply_security_mechanism(params: dict[str, Any], sign_secret: str): try: timestamp = str(round(time.time() * 1000)) - secret_enc = sign_secret.encode('utf-8') - string_to_sign = f'{timestamp}\n{sign_secret}' - string_to_sign_enc = string_to_sign.encode('utf-8') + secret_enc = sign_secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{sign_secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - params['timestamp'] = timestamp - params['sign'] = sign + params["timestamp"] = timestamp + params["sign"] = sign except Exception: msg = "Failed to apply security mechanism to the request." 
diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py
index c247c3bd6b..f33ad5be59 100644
--- a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py
+++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py
@@ -13,38 +13,43 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class DingTalkGroupBotTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
-            Dingtalk custom group robot API docs:
-            https://open.dingtalk.com/document/orgapp/custom-robot-access
+        invoke tools
+        Dingtalk custom group robot API docs:
+        https://open.dingtalk.com/document/orgapp/custom-robot-access
         """
-        content = tool_parameters.get('content')
+        content = tool_parameters.get("content")
         if not content:
-            return self.create_text_message('Invalid parameter content')
+            return self.create_text_message("Invalid parameter content")

-        access_token = tool_parameters.get('access_token')
+        access_token = tool_parameters.get("access_token")
         if not access_token:
-            return self.create_text_message('Invalid parameter access_token. '
-                                            'Regarding information about security details,'
-                                            'please refer to the DingTalk docs:'
-                                            'https://open.dingtalk.com/document/robots/customize-robot-security-settings')
+            return self.create_text_message(
+                "Invalid parameter access_token. "
+                "Regarding information about security details,"
+                "please refer to the DingTalk docs:"
+                "https://open.dingtalk.com/document/robots/customize-robot-security-settings"
+            )

-        sign_secret = tool_parameters.get('sign_secret')
+        sign_secret = tool_parameters.get("sign_secret")
         if not sign_secret:
-            return self.create_text_message('Invalid parameter sign_secret. '
-                                            'Regarding information about security details,'
-                                            'please refer to the DingTalk docs:'
-                                            'https://open.dingtalk.com/document/robots/customize-robot-security-settings')
+            return self.create_text_message(
+                "Invalid parameter sign_secret. "
+                "Regarding information about security details,"
+                "please refer to the DingTalk docs:"
+                "https://open.dingtalk.com/document/robots/customize-robot-security-settings"
+            )

-        msgtype = 'text'
-        api_url = 'https://oapi.dingtalk.com/robot/send'
+        msgtype = "text"
+        api_url = "https://oapi.dingtalk.com/robot/send"
         headers = {
-            'Content-Type': 'application/json',
+            "Content-Type": "application/json",
         }
         params = {
-            'access_token': access_token,
+            "access_token": access_token,
         }

         self._apply_security_mechanism(params, sign_secret)
@@ -53,7 +58,7 @@ class DingTalkGroupBotTool(BuiltinTool):
             "msgtype": msgtype,
             "text": {
                 "content": content,
-            }
+            },
         }

         try:
@@ -62,7 +67,8 @@ class DingTalkGroupBotTool(BuiltinTool):
                 return self.create_text_message("Text message sent successfully")
             else:
                 return self.create_text_message(
-                    f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
+                    f"Failed to send the text message, status code: {res.status_code}, response: {res.text}"
+                )
         except Exception as e:
             return self.create_text_message("Failed to send message to group chat bot. {}".format(e))

@@ -70,14 +76,14 @@ class DingTalkGroupBotTool(BuiltinTool):
     def _apply_security_mechanism(params: dict[str, Any], sign_secret: str):
         try:
             timestamp = str(round(time.time() * 1000))
-            secret_enc = sign_secret.encode('utf-8')
-            string_to_sign = f'{timestamp}\n{sign_secret}'
-            string_to_sign_enc = string_to_sign.encode('utf-8')
+            secret_enc = sign_secret.encode("utf-8")
+            string_to_sign = f"{timestamp}\n{sign_secret}"
+            string_to_sign_enc = string_to_sign.encode("utf-8")
             hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest()

             sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
-            params['timestamp'] = timestamp
-            params['sign'] = sign
+            params["timestamp"] = timestamp
+            params["sign"] = sign
         except Exception:
             msg = "Failed to apply security mechanism to the request."
             logging.exception(msg)
diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py
index 2292e89fa6..8269167127 100644
--- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py
+++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py
@@ -11,11 +11,10 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
                 "credentials": credentials,
             }
         ).invoke(
-            user_id='',
+            user_id="",
             tool_parameters={
                 "query": "John Doe",
             },
         )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py
index 878b0d8645..8bdd638f4a 100644
--- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py
+++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py
@@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool):

     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
         query_dict = {
-            "keywords": tool_parameters.get('query'),
-            "model": tool_parameters.get('model'),
+            "keywords": tool_parameters.get("query"),
+            "model": tool_parameters.get("model"),
         }
         response = DDGS().chat(**query_dict)
         return self.create_text_message(text=response)
diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py
index bca53f6b4b..396570248a 100644
--- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py
+++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py
@@ -14,18 +14,17 @@ class DuckDuckGoImageSearchTool(BuiltinTool):

     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
         query_dict = {
-            "keywords": tool_parameters.get('query'),
-            "timelimit": tool_parameters.get('timelimit'),
-            "size": tool_parameters.get('size'),
-            "max_results": tool_parameters.get('max_results'),
+            "keywords": tool_parameters.get("query"),
+            "timelimit": tool_parameters.get("timelimit"),
+            "size": tool_parameters.get("size"),
+            "max_results": tool_parameters.get("max_results"),
         }
         response = DDGS().images(**query_dict)
         result = []
         for res in response:
-            res['transfer_method'] = FileTransferMethod.REMOTE_URL
-            msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                                    message=res.get('image'),
-                                    save_as='',
-                                    meta=res)
+            res["transfer_method"] = FileTransferMethod.REMOTE_URL
+            msg = ToolInvokeMessage(
+                type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res
+            )
             result.append(msg)
         return result
diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py
index dfaeb734d8..cbd65d2e77 100644
--- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py
+++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py
@@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool):
     """
     Tool for performing a search using DuckDuckGo search engine.
     """
+
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
-        query = tool_parameters.get('query')
-        max_results = tool_parameters.get('max_results', 5)
-        require_summary = tool_parameters.get('require_summary', False)
+        query = tool_parameters.get("query")
+        max_results = tool_parameters.get("max_results", 5)
+        require_summary = tool_parameters.get("require_summary", False)
         response = DDGS().text(query, max_results=max_results)
         if require_summary:
             results = "\n".join([res.get("body") for res in response])
@@ -34,7 +35,11 @@ class DuckDuckGoSearchTool(BuiltinTool):

     def summary_results(self, user_id: str, content: str, query: str) -> str:
         prompt = SUMMARY_PROMPT.format(query=query, content=content)
-        summary = self.invoke_model(user_id=user_id, prompt_messages=[
-            SystemPromptMessage(content=prompt),
-        ], stop=[])
+        summary = self.invoke_model(
+            user_id=user_id,
+            prompt_messages=[
+                SystemPromptMessage(content=prompt),
+            ],
+            stop=[],
+        )
         return summary.message.content
diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py
index 9822b37cf0..396ce21b18 100644
--- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py
+++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py
@@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool):

     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
         query_dict = {
-            "keywords": tool_parameters.get('query'),
-            "to": tool_parameters.get('translate_to'),
+            "keywords": tool_parameters.get("query"),
+            "to": tool_parameters.get("translate_to"),
         }
-        response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!')
+        response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!")
         return self.create_text_message(text=response)
= { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { "msg_type": msg_type, "content": { "text": content, - } + }, } try: @@ -45,6 +45,7 @@ class FeishuGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.py b/api/core/tools/provider/builtin/feishu_base/feishu_base.py index febb769ff8..04056af53b 100644 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.py +++ b/api/core/tools/provider/builtin/feishu_base/feishu_base.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class FeishuBaseProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: GetTenantAccessTokenTool() - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py index be43b43ce4..4a605fbffe 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py @@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class AddBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + 
res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to add base record, status code: {res.status_code}, response: {res.text}") + f"Failed to add base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to add base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py index 639644e7f0..6b755e2007 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py @@ -8,28 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class CreateBaseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - name = tool_parameters.get('name', '') - folder_token = tool_parameters.get('folder_token', '') + name = tool_parameters.get("name", "") + folder_token = tool_parameters.get("folder_token", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "name": name, - "folder_token": folder_token - } + payload = {"name": name, "folder_token": folder_token} try: res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) @@ -38,6 +35,7 @@ class CreateBaseTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base, status code: {res.status_code}, response: {res.text}") + f"Failed to create base, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py index e9062e8730..b05d700113 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py @@ -8,37 +8,32 @@ from core.tools.tool.builtin_tool import BuiltinTool class CreateBaseTableTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - name = tool_parameters.get('name', '') + name = tool_parameters.get("name", "") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table": { - "name": name, - "fields": json.loads(fields) - } - } + payload = {"table": {"name": name, "fields": json.loads(fields)}} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -47,6 +42,7 @@ class CreateBaseTableTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base table, status code: {res.status_code}, response: {res.text}") + f"Failed to create base table, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base table. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py index aa13aad6fa..862eb2171b 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py @@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_ids = tool_parameters.get('record_ids', '') + record_ids = tool_parameters.get("record_ids", "") if not record_ids: - return self.create_text_message('Invalid parameter record_ids') + return self.create_text_message("Invalid parameter record_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "records": json.loads(record_ids) - } + payload = {"records": json.loads(record_ids)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base records, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base records. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py index c4280ebc21..f512186303 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py @@ -8,32 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_ids = tool_parameters.get('table_ids', '') + table_ids = tool_parameters.get("table_ids", "") if not table_ids: - return self.create_text_message('Invalid parameter table_ids') + return self.create_text_message("Invalid parameter table_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table_ids": json.loads(table_ids) - } + payload = {"table_ids": json.loads(table_ids)} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -42,6 +40,7 @@ class DeleteBaseTablesTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base tables. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py index de70f2ed93..f664bbeed0 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py @@ -8,22 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetBaseInfoTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: @@ -33,6 +33,7 @@ class GetBaseInfoTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get base info, status code: {res.status_code}, response: {res.text}") + f"Failed to get base info, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get base info. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py index 88507bda60..2ea61d0068 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py @@ -8,27 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetTenantAccessTokenTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" - app_id = tool_parameters.get('app_id', '') + app_id = tool_parameters.get("app_id", "") if not app_id: - return self.create_text_message('Invalid parameter app_id') + return self.create_text_message("Invalid parameter app_id") - app_secret = tool_parameters.get('app_secret', '') + app_secret = tool_parameters.get("app_secret", "") if not app_secret: - return self.create_text_message('Invalid parameter app_secret') + return self.create_text_message("Invalid parameter app_secret") headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} """ { @@ -45,6 +42,7 @@ class GetTenantAccessTokenTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}") + f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get tenant access token. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py index 2a4229f137..e579d02f69 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py @@ -8,31 +8,31 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') - sort_condition = tool_parameters.get('sort_condition', '') - filter_condition = tool_parameters.get('filter_condition', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") + sort_condition = tool_parameters.get("sort_condition", "") + filter_condition = tool_parameters.get("filter_condition", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -40,22 +40,26 @@ class ListBaseRecordsTool(BuiltinTool): "page_size": page_size, } - payload = { - "automatic_fields": True - } + payload = {"automatic_fields": True} if sort_condition: payload["sort"] = json.loads(sort_condition) if filter_condition: payload["filter"] = json.loads(filter_condition) try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base records, status code: {res.status_code}, response: {res.text}") + f"Failed to list base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base records. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py index 6d82490eb3..4ec9a476bc 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py @@ -8,25 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -41,6 +41,7 @@ class ListBaseTablesTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to list base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base tables. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py index bb4bd6c3a6..fb818f8380 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py @@ -8,40 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class ReadBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: - res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - timeout=30) + res = httpx.get( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30 + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to read base record, status code: {res.status_code}, response: {res.text}") + f"Failed to read base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to read base record. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py index 6551053ce2..6d7e33f3ff 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py @@ -8,49 +8,53 @@ from core.tools.tool.builtin_tool import BuiltinTool class UpdateBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - params=params, json=payload, timeout=30) + res = httpx.put( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to update base record, status code: {res.status_code}, response: {res.text}") + f"Failed to update base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to update base record. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py index c4f8f26e2c..b0a1e393eb 100644 --- a/api/core/tools/provider/builtin/feishu_document/feishu_document.py +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class FeishuDocumentProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 0ff82e621b..090a0828e8 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get('title') - content = tool_parameters.get('content') - folder_token = tool_parameters.get('folder_token') + title = tool_parameters.get("title") + content = tool_parameters.get("content") + folder_token = tool_parameters.get("folder_token") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py index 16ef90908b..83073e0822 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py @@ -7,11 +7,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class GetDocumentRawContentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') + document_id = tool_parameters.get("document_id") res = client.get_document_raw_content(document_id) - return self.create_json_message(res) \ No newline at end of file + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py index 
97d17bdb04..8c0c4a3c97 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - page_size = tool_parameters.get('page_size', 500) - page_token = tool_parameters.get('page_token', '') + document_id = tool_parameters.get("document_id") + page_size = tool_parameters.get("page_size", 500) + page_token = tool_parameters.get("page_token", "") res = client.list_document_block(document_id, page_token, page_size) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py index 914a44dce6..6061250e48 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - content = tool_parameters.get('content') - position = tool_parameters.get('position') + document_id = tool_parameters.get("document_id") + content = tool_parameters.get("content") + position = tool_parameters.get("position") res = client.write_document(document_id, content, position) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py index 6d7fed330c..7b3adb9293 100644 --- a/api/core/tools/provider/builtin/feishu_message/feishu_message.py +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class FeishuMessageProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py index 74f6866ba3..1dd315d0e2 100644 --- 
a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -7,14 +7,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class SendBotMessageTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - receive_id_type = tool_parameters.get('receive_id_type') - receive_id = tool_parameters.get('receive_id') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + receive_id_type = tool_parameters.get("receive_id_type") + receive_id = tool_parameters.get("receive_id") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py index 7159f59ffa..44e70e0a15 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -6,14 +6,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class SendWebhookMessageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - webhook = tool_parameters.get('webhook') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + webhook = tool_parameters.get("webhook") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_webhook_message(webhook, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 24dc35759d..01455d7206 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the ScrapeTool, only scraping title for minimize content - ScrapeTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', - tool_parameters={ - "url": "https://google.com", - "onlyIncludeTags": 'title' - } + ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"} ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index 3b3f78731b..a0e4cdf933 100644 --- 
a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -13,27 +13,24 @@ logger = logging.getLogger(__name__) class FirecrawlApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' + self.base_url = base_url or "https://api.firecrawl.dev" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( - self, - method: str, - url: str, - data: Mapping[str, Any] | None = None, - headers: Mapping[str, str] | None = None, - retries: int = 3, - backoff_factor: float = 0.3, + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, ) -> Mapping[str, Any] | None: if not headers: headers = self._prepare_headers() @@ -44,54 +41,54 @@ class FirecrawlApp: return response.json() except requests.exceptions.RequestException as e: if i < retries - 1: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None def scrape_url(self, url: str, **kwargs): - endpoint = f'{self.base_url}/v0/scrape' - data = {'url': url, **kwargs} + endpoint = f"{self.base_url}/v0/scrape" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to scrape URL after multiple retries") return response def search(self, query: str, **kwargs): - endpoint = f'{self.base_url}/v0/search' - data = {'query': query, **kwargs} + endpoint = f"{self.base_url}/v0/search" + data = {"query": query, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to perform search after multiple retries") return response def crawl_url( - self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs + self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint = f'{self.base_url}/v0/crawl' + endpoint = f"{self.base_url}/v0/crawl" headers = self._prepare_headers(idempotency_key) - data = {'url': url, **kwargs} + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response['jobId'] + job_id: str = response["jobId"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/status/{job_id}' - response = self._request('GET', endpoint) + endpoint = f"{self.base_url}/v0/crawl/status/{job_id}" + response = self._request("GET", endpoint) if 
             raise HTTPError(f"Failed to check status for job {job_id} after multiple retries")
         return response

     def cancel_crawl_job(self, job_id: str):
-        endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}'
-        response = self._request('DELETE', endpoint)
+        endpoint = f"{self.base_url}/v0/crawl/cancel/{job_id}"
+        response = self._request("DELETE", endpoint)
         if response is None:
             raise HTTPError(f"Failed to cancel job {job_id} after multiple retries")
         return response
@@ -99,9 +96,9 @@ class FirecrawlApp:
     def _monitor_job_status(self, job_id: str, poll_interval: int):
         while True:
             status = self.check_crawl_status(job_id)
-            if status['status'] == 'completed':
+            if status["status"] == "completed":
                 return status
-            elif status['status'] == 'failed':
+            elif status["status"] == "failed":
                 raise HTTPError(f'Job {job_id} failed: {status["error"]}')
             time.sleep(poll_interval)

@@ -109,7 +106,7 @@ class FirecrawlApp:
 def get_array_params(tool_parameters: dict[str, Any], key):
     param = tool_parameters.get(key)
     if param:
-        return param.split(',')
+        return param.split(",")

 def get_json_params(tool_parameters: dict[str, Any], key):
diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py
index 08c40a4064..94717cbbfb 100644
--- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py
+++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py
@@ -11,38 +11,36 @@ class CrawlTool(BuiltinTool):
         the crawlerOptions and pageOptions comes from doc here:
         https://docs.firecrawl.dev/api-reference/endpoint/crawl
         """
-        app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
-                           base_url=self.runtime.credentials['base_url'])
+        app = FirecrawlApp(
+            api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
+        )
         crawlerOptions = {}
         pageOptions = {}

-        wait_for_results = tool_parameters.get('wait_for_results', True)
+        wait_for_results = tool_parameters.get("wait_for_results", True)

-        crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes')
-        crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes')
-        crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False)
-        crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth')
-        crawlerOptions['mode'] = tool_parameters.get('mode')
-        crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False)
-        crawlerOptions['limit'] = tool_parameters.get('limit', 5)
-        crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False)
-        crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False)
+        crawlerOptions["excludes"] = get_array_params(tool_parameters, "excludes")
+        crawlerOptions["includes"] = get_array_params(tool_parameters, "includes")
+        crawlerOptions["returnOnlyUrls"] = tool_parameters.get("returnOnlyUrls", False)
+        crawlerOptions["maxDepth"] = tool_parameters.get("maxDepth")
+        crawlerOptions["mode"] = tool_parameters.get("mode")
+        crawlerOptions["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False)
+        crawlerOptions["limit"] = tool_parameters.get("limit", 5)
+        crawlerOptions["allowBackwardCrawling"] = tool_parameters.get("allowBackwardCrawling", False)
+        crawlerOptions["allowExternalContentLinks"] = tool_parameters.get("allowExternalContentLinks", False)

-        pageOptions['headers'] = get_json_params(tool_parameters, 'headers')
-        pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
-        pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
-        pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags')
-        pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags')
-        pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
-        pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False)
-        pageOptions['screenshot'] = tool_parameters.get('screenshot', False)
-        pageOptions['waitFor'] = tool_parameters.get('waitFor', 0)
+        pageOptions["headers"] = get_json_params(tool_parameters, "headers")
+        pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False)
+        pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False)
+        pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags")
+        pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags")
+        pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False)
+        pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False)
+        pageOptions["screenshot"] = tool_parameters.get("screenshot", False)
+        pageOptions["waitFor"] = tool_parameters.get("waitFor", 0)

         crawl_result = app.crawl_url(
-            url=tool_parameters['url'],
-            wait=wait_for_results,
-            crawlerOptions=crawlerOptions,
-            pageOptions=pageOptions
+            url=tool_parameters["url"], wait=wait_for_results, crawlerOptions=crawlerOptions, pageOptions=pageOptions
         )

         return self.create_json_message(crawl_result)
diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py
index fa6c1f87ee..0d2486c7ca 100644
--- a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py
+++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py
@@ -7,14 +7,15 @@ from core.tools.tool.builtin_tool import BuiltinTool

 class CrawlJobTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
-        app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
-                           base_url=self.runtime.credentials['base_url'])
-        operation = tool_parameters.get('operation', 'get')
-        if operation == 'get':
-            result = app.check_crawl_status(job_id=tool_parameters['job_id'])
-        elif operation == 'cancel':
-            result = app.cancel_crawl_job(job_id=tool_parameters['job_id'])
+        app = FirecrawlApp(
+            api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
+        )
+        operation = tool_parameters.get("operation", "get")
+        if operation == "get":
+            result = app.check_crawl_status(job_id=tool_parameters["job_id"])
+        elif operation == "cancel":
+            result = app.cancel_crawl_job(job_id=tool_parameters["job_id"])
         else:
-            raise ValueError(f'Invalid operation: {operation}')
+            raise ValueError(f"Invalid operation: {operation}")

         return self.create_json_message(result)
diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py
index 91412da548..962570bf73 100644
--- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py
+++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py
@@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool

 class ScrapeTool(BuiltinTool):
-
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
         """
         the pageOptions and extractorOptions comes from doc here:
         https://docs.firecrawl.dev/api-reference/endpoint/scrape
         """
-        app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
-                           base_url=self.runtime.credentials['base_url'])
+        app = FirecrawlApp(
+            api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
+        )

         pageOptions = {}
         extractorOptions = {}

-        pageOptions['headers'] = get_json_params(tool_parameters, 'headers')
-        pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
-        pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
-        pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags')
-        pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags')
-        pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
-        pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False)
-        pageOptions['screenshot'] = tool_parameters.get('screenshot', False)
-        pageOptions['waitFor'] = tool_parameters.get('waitFor', 0)
+        pageOptions["headers"] = get_json_params(tool_parameters, "headers")
+        pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False)
+        pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False)
+        pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags")
+        pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags")
+        pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False)
+        pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False)
+        pageOptions["screenshot"] = tool_parameters.get("screenshot", False)
+        pageOptions["waitFor"] = tool_parameters.get("waitFor", 0)

-        extractorOptions['mode'] = tool_parameters.get('mode', '')
-        extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '')
-        extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema')
+        extractorOptions["mode"] = tool_parameters.get("mode", "")
+        extractorOptions["extractionPrompt"] = tool_parameters.get("extractionPrompt", "")
+        extractorOptions["extractionSchema"] = get_json_params(tool_parameters, "extractionSchema")

-        crawl_result = app.scrape_url(url=tool_parameters['url'],
-                                      pageOptions=pageOptions,
-                                      extractorOptions=extractorOptions)
+        crawl_result = app.scrape_url(
+            url=tool_parameters["url"], pageOptions=pageOptions, extractorOptions=extractorOptions
+        )

         return self.create_json_message(crawl_result)
diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py
index e2b2ac6b4d..f077e7d8ea 100644
--- a/api/core/tools/provider/builtin/firecrawl/tools/search.py
+++ b/api/core/tools/provider/builtin/firecrawl/tools/search.py
@@ -11,18 +11,17 @@ class SearchTool(BuiltinTool):
         the pageOptions and searchOptions comes from doc here:
         https://docs.firecrawl.dev/api-reference/endpoint/search
         """
-        app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
-                           base_url=self.runtime.credentials['base_url'])
+        app = FirecrawlApp(
+            api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
+        )
         pageOptions = {}
-        pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
-        pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True)
-        pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
-        pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
-        searchOptions = {'limit': tool_parameters.get('limit')}
+        pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False)
+        pageOptions["fetchPageContent"] = tool_parameters.get("fetchPageContent", True)
+        pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False)
+        pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False)
+        searchOptions = {"limit": tool_parameters.get("limit")}
         search_result = app.search(
-            query=tool_parameters['keyword'],
-            pageOptions=pageOptions,
-            searchOptions=searchOptions
+            query=tool_parameters["keyword"], pageOptions=pageOptions, searchOptions=searchOptions
         )

         return self.create_json_message(search_result)
diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py
index b55d93e07b..a3e50da001 100644
--- a/api/core/tools/provider/builtin/gaode/gaode.py
+++ b/api/core/tools/provider/builtin/gaode/gaode.py
@@ -9,17 +9,19 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 class GaodeProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict) -> None:
         try:
-            if 'api_key' not in credentials or not credentials.get('api_key'):
+            if "api_key" not in credentials or not credentials.get("api_key"):
                 raise ToolProviderCredentialValidationError("Gaode API key is required.")

             try:
-                response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}"
-                                            "".format(address=urllib.parse.quote('广东省广州市天河区广州塔'),
-                                                      apikey=credentials.get('api_key')))
-                if response.status_code == 200 and (response.json()).get('info') == 'OK':
+                response = requests.get(
+                    url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" "".format(
+                        address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key")
+                    )
+                )
+                if response.status_code == 200 and (response.json()).get("info") == "OK":
                     pass
                 else:
-                    raise ToolProviderCredentialValidationError((response.json()).get('info'))
+                    raise ToolProviderCredentialValidationError((response.json()).get("info"))
             except Exception as e:
                 raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
         except Exception as e:
{}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index efd11cedce..843504eefd 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -8,50 +8,57 @@ from core.tools.tool.builtin_tool import BuiltinTool class GaodeRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - city = tool_parameters.get('city', '') + city = tool_parameters.get("city", "") if not city: - return self.create_text_message('Please tell me your city') + return self.create_text_message("Please tell me your city") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Gaode API key is required.") try: s = requests.session() - api_domain = 'https://restapi.amap.com/v3' - city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"}, - url="{url}/config/district?keywords={keywords}" - "&subdistrict=0&extensions=base&key={apikey}" - "".format(url=api_domain, keywords=city, - apikey=self.runtime.credentials.get('api_key'))) + api_domain = "https://restapi.amap.com/v3" + city_response = s.request( + method="GET", + headers={"Content-Type": "application/json; charset=utf-8"}, + url="{url}/config/district?keywords={keywords}" "&subdistrict=0&extensions=base&key={apikey}" "".format( + url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key") + ), + ) City_data = city_response.json() - if city_response.status_code == 200 and City_data.get('info') == 'OK': - if len(City_data.get('districts')) > 0: - CityCode = City_data['districts'][0]['adcode'] - weatherInfo_response = s.request(method='GET', - url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" - "".format(url=api_domain, citycode=CityCode, - apikey=self.runtime.credentials.get('api_key'))) + if city_response.status_code == 200 and City_data.get("info") == "OK": + if len(City_data.get("districts")) > 0: + CityCode = City_data["districts"][0]["adcode"] + weatherInfo_response = s.request( + method="GET", + url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" + "".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")), + ) weatherInfo_data = weatherInfo_response.json() - if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': + if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK": contents = [] - if len(weatherInfo_data.get('forecasts')) > 0: - for item in weatherInfo_data['forecasts'][0]['casts']: + if len(weatherInfo_data.get("forecasts")) > 0: + for item in weatherInfo_data["forecasts"][0]["casts"]: content = {} - content['date'] = item.get('date') - content['week'] = item.get('week') - content['dayweather'] = item.get('dayweather') - content['daytemp_float'] = item.get('daytemp_float') - content['daywind'] = item.get('daywind') - content['nightweather'] = item.get('nightweather') - content['nighttemp_float'] = 
item.get('nighttemp_float') + content["date"] = item.get("date") + content["week"] = item.get("week") + content["dayweather"] = item.get("dayweather") + content["daytemp_float"] = item.get("daytemp_float") + content["daywind"] = item.get("daywind") + content["nightweather"] = item.get("nightweather") + content["nighttemp_float"] = item.get("nighttemp_float") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) s.close() - return self.create_text_message(f'No weather information for {city} was found.') + return self.create_text_message(f"No weather information for {city} was found.") except Exception as e: return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.py b/api/core/tools/provider/builtin/getimgai/getimgai.py index c81d5fa333..bbd07d120f 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai.py @@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the text2image tool - Text2ImageTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "prompt": "A fire egg", "response_format": "url", "style": "photorealism", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py index e28c57649c..0e95a5f654 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py @@ -8,18 +8,16 @@ from requests.exceptions import HTTPError logger = logging.getLogger(__name__) + class GetImgAIApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.getimg.ai/v1' + self.base_url = base_url or "https://api.getimg.ai/v1" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return headers def _request( @@ -38,22 +36,20 @@ class GetImgAIApp: return response.json() except requests.exceptions.RequestException as e: if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None - def text2image( - self, mode: str, **kwargs - ): - data = kwargs['params'] - if not data.get('prompt'): + def text2image(self, mode: str, **kwargs): + data = kwargs["params"] + if not data.get("prompt"): raise ValueError("Prompt is required") - endpoint = f'{self.base_url}/{mode}/text-to-image' + endpoint = f"{self.base_url}/{mode}/text-to-image" headers = self._prepare_headers() logger.debug(f"Send request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = 
self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate getimg.ai after multiple retries") return response diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.py b/api/core/tools/provider/builtin/getimgai/tools/text2image.py index dad7314479..c556749552 100644 --- a/api/core/tools/provider/builtin/getimgai/tools/text2image.py +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.py @@ -7,28 +7,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class Text2ImageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url']) + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = GetImgAIApp( + api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"] + ) options = { - 'style': tool_parameters.get('style'), - 'prompt': tool_parameters.get('prompt'), - 'aspect_ratio': tool_parameters.get('aspect_ratio'), - 'output_format': tool_parameters.get('output_format', 'jpeg'), - 'response_format': tool_parameters.get('response_format', 'url'), - 'width': tool_parameters.get('width'), - 'height': tool_parameters.get('height'), - 'steps': tool_parameters.get('steps'), - 'negative_prompt': tool_parameters.get('negative_prompt'), - 'prompt_2': tool_parameters.get('prompt_2'), + "style": tool_parameters.get("style"), + "prompt": tool_parameters.get("prompt"), + "aspect_ratio": tool_parameters.get("aspect_ratio"), + "output_format": tool_parameters.get("output_format", "jpeg"), + "response_format": tool_parameters.get("response_format", "url"), + "width": tool_parameters.get("width"), + "height": tool_parameters.get("height"), + "steps": tool_parameters.get("steps"), + "negative_prompt": tool_parameters.get("negative_prompt"), + "prompt_2": tool_parameters.get("prompt_2"), } options = {k: v for k, v in options.items() if v} - text2image_result = app.text2image( - mode=tool_parameters.get('mode', 'essential-v2'), - params=options, - wait=True - ) + text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True) if not isinstance(text2image_result, str): text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4) diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index 9275504208..87a34ac3e8 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -4,28 +4,28 @@ from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -class GihubProvider(BuiltinToolProviderController): +class GithubProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Github API Access Tokens is required.") - if 'api_version' not in credentials or not credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in credentials or not credentials.get("api_version"): + api_version = "2022-11-28" 
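(The GithubProvider hunk above keeps the original validation flow while reformatting it: when no api_version credential is supplied it falls back to "2022-11-28", then probes a cheap read-only GitHub endpoint with the token. A minimal standalone sketch of that pattern, with the function name and probe query purely illustrative rather than part of the patch:

import requests

def validate_github_token(access_token: str, api_version: str = "2022-11-28") -> None:
    # Probe a read-only endpoint; any non-200 response means the token
    # (or the version header) is not usable.
    headers = {
        "Content-Type": "application/vnd.github+json",
        "Authorization": f"Bearer {access_token}",
        "X-GitHub-Api-Version": api_version,
    }
    response = requests.get("https://api.github.com/search/users?q=charli117", headers=headers)
    if response.status_code != 200:
        raise ValueError(response.json().get("message"))

The patch wraps the same logic in ToolProviderCredentialValidationError instead of ValueError.)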
else: - api_version = credentials.get('api_version') + api_version = credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } response = requests.get( - url="https://api.github.com/search/users?q={account}".format(account='charli117'), - headers=headers) + url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers + ) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index a2f1e07fd4..3eab8bf8dc 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -9,54 +9,62 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -class GihubRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: +class GithubRepositoriesTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - top_n = tool_parameters.get('top_n', 5) - query = tool_parameters.get('query', '') + top_n = tool_parameters.get("top_n", 5) + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input symbol') + return self.create_text_message("Please input symbol") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Github API Access Tokens is required.") - if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = self.runtime.credentials.get('api_version') + api_version = self.runtime.credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } s = requests.session() - api_domain = 'https://api.github.com' - response = s.request(method='GET', headers=headers, - url=f"{api_domain}/search/repositories?" - f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") + api_domain = "https://api.github.com" + response = s.request( + method="GET", + headers=headers, + url=f"{api_domain}/search/repositories?" 
f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc", + ) response_data = response.json() - if response.status_code == 200 and isinstance(response_data.get('items'), list): + if response.status_code == 200 and isinstance(response_data.get("items"), list): contents = [] - if len(response_data.get('items')) > 0: - for item in response_data.get('items'): + if len(response_data.get("items")) > 0: + for item in response_data.get("items"): content = {} - updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") - content['owner'] = item['owner']['login'] - content['name'] = item['name'] - content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description'] - content['url'] = item['html_url'] - content['star'] = item['watchers'] - content['forks'] = item['forks'] - content['updated'] = updated_at_object.strftime("%Y-%m-%d") + updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ") + content["owner"] = item["owner"]["login"] + content["name"] = item["name"] + content["description"] = ( + item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"] + ) + content["url"] = item["html_url"] + content["star"] = item["watchers"] + content["forks"] = item["forks"] + content["updated"] = updated_at_object.strftime("%Y-%m-%d") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) else: - return self.create_text_message(f'No items related to {query} were found.') + return self.create_text_message(f"No items related to {query} were found.") else: - return self.create_text_message((response.json()).get('message')) + return self.create_text_message((response.json()).get("message")) except Exception as e: return self.create_text_message("Github API Key and Api Version is invalid. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py index 0c13ec662a..9bd4a0bd52 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.py +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -9,13 +9,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GitlabProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") - - if 'site_url' not in credentials or not credentials.get('site_url'): - site_url = 'https://gitlab.com' + + if "site_url" not in credentials or not credentials.get("site_url"): + site_url = "https://gitlab.com" else: - site_url = credentials.get('site_url') + site_url = credentials.get("site_url") try: headers = { @@ -23,12 +23,10 @@ class GitlabProvider(BuiltinToolProviderController): "Authorization": f"Bearer {credentials.get('access_tokens')}", } - response = requests.get( - url= f"{site_url}/api/v4/user", - headers=headers) + response = requests.get(url=f"{site_url}/api/v4/user", headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e)) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index 880d722bda..45ab15f437 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -1,4 +1,5 @@ import json +import urllib.parse from datetime import datetime, timedelta from typing import Any, Union @@ -9,103 +10,133 @@ from core.tools.tool.builtin_tool import BuiltinTool class GitlabCommitsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") + employee = tool_parameters.get("employee", "") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + change_type = tool_parameters.get("change_type", "all") - project = tool_parameters.get('project', '') - employee = tool_parameters.get('employee', '') - start_time = tool_parameters.get('start_time', '') - end_time = tool_parameters.get('end_time', '') - change_type = tool_parameters.get('change_type', 'all') - - if not project: - return self.create_text_message('Project is required') + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() if not end_time: end_time = datetime.utcnow().isoformat() - access_token = 
self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get commit content - result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) + if repository: + result = self.fetch_commits( + site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True + ) + else: + result = self.fetch_commits( + site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False + ) return [self.create_json_message(item) for item in result] - - def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: + + def fetch_commits( + self, + site_url: str, + access_token: str, + identifier: str, + employee: str, + start_time: str, + end_time: str, + change_type: str, + is_repository: bool, + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # Get all of projects - url = f"{domain}/api/v4/projects" - response = requests.get(url, headers=headers) - response.raise_for_status() - projects = response.json() + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits" + else: + # Get all projects + url = f"{domain}/api/v4/projects" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() - filtered_projects = [p for p in projects if project == "*" or p['name'] == project] + filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier] - for project in filtered_projects: - project_id = project['id'] - project_name = project['name'] - print(f"Project: {project_name}") + for project in filtered_projects: + project_id = project["id"] + project_name = project["name"] + print(f"Project: {project_name}") - # Get all of proejct commits - commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = { - 'since': start_time, - 'until': end_time - } - if employee: - params['author'] = employee + commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - commits_response = requests.get(commits_url, headers=headers, params=params) - commits_response.raise_for_status() - commits = commits_response.json() + params = {"since": start_time, "until": end_time} + if employee: + params["author"] = employee - for commit in commits: - commit_sha = commit['id'] - author_name = commit['author_name'] + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + for commit in commits: + commit_sha = 
commit["id"] + author_name = commit["author_name"] + + if is_repository: + diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff" + else: diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" - diff_response = requests.get(diff_url, headers=headers) - diff_response.raise_for_status() - diffs = diff_response.json() - - for diff in diffs: - # Caculate code lines of changed - added_lines = diff['diff'].count('\n+') - removed_lines = diff['diff'].count('\n-') - total_changes = added_lines + removed_lines - if change_type == "new": - if added_lines > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code - }) - else: - if total_changes > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) - final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code_escaped - }) + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() + + for diff in diffs: + # Calculate code lines of changes + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") + total_changes = added_lines + removed_lines + + if change_type == "new": + if added_lines > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}) + else: + if total_changes > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) + final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} + ) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml index dd4e31d663..669378ac97 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -6,7 +6,7 @@ identity: zh_Hans: GitLab 提交内容查询 description: human: - en_US: A tool for query GitLab commits, Input should be a exists username or projec. + en_US: A tool for query GitLab commits, Input should be a exists username or project. zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。 llm: A tool for query GitLab commits, Input should be a exists username or project. parameters: @@ -21,9 +21,20 @@ parameters: zh_Hans: 员工用户名 llm_description: User name for GitLab form: llm + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. 
+ form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目名 diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 7fa1d0d112..7606eee7af 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -1,3 +1,4 @@ +import urllib.parse from typing import Any, Union import requests @@ -7,47 +8,82 @@ from core.tools.tool.builtin_tool import BuiltinTool class GitlabFilesTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - branch = tool_parameters.get('branch', '') - path = tool_parameters.get('path', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") + branch = tool_parameters.get("branch", "") + path = tool_parameters.get("path", "") - - if not project: - return self.create_text_message('Project is required') + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not branch: - return self.create_text_message('Branch is required') - + return self.create_text_message("Branch is required") if not path: - return self.create_text_message('Path is required') + return self.create_text_message("Path is required") - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - - # Get project ID from project name - project_id = self.get_project_id(site_url, access_token, project) - if not project_id: - return self.create_text_message(f"Project '{project}' not found.") + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" - # Get commit content - result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + # Get file content + if repository: + result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True) + else: + result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False) return [self.create_json_message(item) for item in result] - - def extract_project_name_and_path(self, path: str) -> tuple[str, str]: - parts = path.split('/', 1) - if len(parts) < 2: - return None, None - return parts[0], parts[1] + + def fetch_files( + self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool + ) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + tree_url = 
f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}" + else: + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, identifier) + if not project_id: + return self.create_text_message(f"Project '{identifier}' not found.") + tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + + response = requests.get(tree_url, headers=headers) + response.raise_for_status() + items = response.json() + + for item in items: + item_path = item["path"] + if item["type"] == "tree": # It's a directory + results.extend( + self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository) + ) + else: # It's a file + if is_repository: + file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}" + else: + file_url = ( + f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + ) + + file_response = requests.get(file_url, headers=headers) + file_response.raise_for_status() + file_content = file_response.text + results.append({"path": item_path, "branch": branch, "content": file_content}) + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: headers = {"PRIVATE-TOKEN": access_token} @@ -57,39 +93,8 @@ class GitlabFilesTool(BuiltinTool): response.raise_for_status() projects = response.json() for project in projects: - if project['name'] == project_name: - return project['id'] + if project["name"] == project_name: + return project["id"] except requests.RequestException as e: print(f"Error fetching project ID from GitLab: {e}") return None - - def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: - domain = site_url - headers = {"PRIVATE-TOKEN": access_token} - results = [] - - try: - # List files and directories in the given path - url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" - response = requests.get(url, headers=headers) - response.raise_for_status() - items = response.json() - - for item in items: - item_path = item['path'] - if item['type'] == 'tree': # It's a directory - results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) - else: # It's a file - file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" - file_response = requests.get(file_url, headers=headers) - file_response.raise_for_status() - file_content = file_response.text - results.append({ - "path": item_path, - "branch": branch, - "content": file_content - }) - except requests.RequestException as e: - print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml index d99b6254c1..4c733673f1 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -10,9 +10,20 @@ description: zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 llm: A tool for query GitLab files, Input should be a exists file or directory path. 
parameters: + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目 diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 8f4b9a4a4e..6b5395f9d3 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -13,12 +13,8 @@ class GoogleProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 09d0326fb4..a9f65925d8 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -9,7 +9,6 @@ SERP_API_URL = "https://serpapi.com/search" class GoogleSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledge_graph" in response: @@ -17,25 +16,23 @@ class GoogleSearchTool(BuiltinTool): result["description"] = response["knowledge_graph"].get("description", "") if "organic_results" in response: result["organic_results"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic_results"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: params = { - "api_key": self.runtime.credentials['serpapi_api_key'], - "q": tool_parameters['query'], + "api_key": self.runtime.credentials["serpapi_api_key"], + "q": tool_parameters["query"], "engine": "google", "google_domain": "google.com", "gl": "us", - "hl": "en" + "hl": "en", } response = requests.get(url=SERP_API_URL, params=params) response.raise_for_status() diff --git a/api/core/tools/provider/builtin/google_translate/google_translate.py b/api/core/tools/provider/builtin/google_translate/google_translate.py index f6e1d65834..ea53aa4eeb 100644 --- a/api/core/tools/provider/builtin/google_translate/google_translate.py +++ b/api/core/tools/provider/builtin/google_translate/google_translate.py @@ -8,10 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - GoogleTranslate().invoke(user_id='', - tool_parameters={ - "content": "这是一段测试文本", - "dest": "en" - }) + GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/google_translate/tools/translate.py 
b/api/core/tools/provider/builtin/google_translate/tools/translate.py index 4314182b06..5d57b5fabf 100644 --- a/api/core/tools/provider/builtin/google_translate/tools/translate.py +++ b/api/core/tools/provider/builtin/google_translate/tools/translate.py @@ -7,46 +7,40 @@ from core.tools.tool.builtin_tool import BuiltinTool class GoogleTranslate(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - dest = tool_parameters.get('dest', '') + dest = tool_parameters.get("dest", "") if not dest: - return self.create_text_message('Invalid parameter destination language') + return self.create_text_message("Invalid parameter destination language") try: result = self._translate(content, dest) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Translation service error, please check the network') + return self.create_text_message("Translation service error, please check the network") def _translate(self, content: str, dest: str) -> str: try: url = "https://translate.googleapis.com/translate_a/single" - params = { - "client": "gtx", - "sl": "auto", - "tl": dest, - "dt": "t", - "q": content - } + params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content} headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - response_json = requests.get( - url, params=params, headers=headers).json() + response_json = requests.get(url, params=params, headers=headers).json() result = response_json[0] - translated_text = ''.join([item[0] for item in result if item[0]]) + translated_text = "".join([item[0] for item in result if item[0]]) return str(translated_text) except Exception as e: return str(e) diff --git a/api/core/tools/provider/builtin/hap/hap.py b/api/core/tools/provider/builtin/hap/hap.py index e0a48e05a5..cbdf950465 100644 --- a/api/core/tools/provider/builtin/hap/hap.py +++ b/api/core/tools/provider/builtin/hap/hap.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class HapProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index 0e101dc67d..f2288ed81c 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -8,41 +8,40 @@ from core.tools.tool.builtin_tool import BuiltinTool class AddWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return 
self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Worksheet ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/addRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to add the new record. {res_json['error_msg']}") return self.create_text_message(f"New record added successfully. 
The record ID is {res_json['data']}.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index ba25952c9f..1df5f6d5cf 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/deleteRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=30) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to delete the record. 
{res_json['error_msg']}") return self.create_text_message("Successfully deleted the record.") except httpx.RequestError as e: return self.create_text_message(f"Failed to delete the record, request error: {e}") except Exception as e: - return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") \ No newline at end of file + return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 2c46d9dd4e..69cf8aa740 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -8,43 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetWorksheetFieldsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Worksheet ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to get the worksheet information. 
{res_json['error_msg']}") - - fields_json, fields_table = self.get_controls(res_json['data']['controls']) - result_type = tool_parameters.get('result_type', 'table') + + fields_json, fields_table = self.get_controls(res_json["data"]["controls"]) + result_type = tool_parameters.get("result_type", "table") return self.create_text_message( - text=json.dumps(fields_json, ensure_ascii=False) if result_type == 'json' else fields_table + text=json.dumps(fields_json, ensure_ascii=False) if result_type == "json" else fields_table ) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet information, request error: {e}") @@ -88,61 +87,65 @@ class GetWorksheetFieldsTool(BuiltinTool): 50: "Text", 51: "Query Record", } - return field_type_map.get(field_type_id, '') + return field_type_map.get(field_type_id, "") def get_controls(self, controls: list) -> dict: fields = [] - fields_list = ['|fieldId|fieldName|fieldType|fieldTypeId|description|options|','|'+'---|'*6] + fields_list = ["|fieldId|fieldName|fieldType|fieldTypeId|description|options|", "|" + "---|" * 6] for control in controls: - if control['type'] in self._get_ignore_types(): + if control["type"] in self._get_ignore_types(): continue - field_type_id = control['type'] - field_type = self.get_field_type_by_id(control['type']) + field_type_id = control["type"] + field_type = self.get_field_type_by_id(control["type"]) if field_type_id == 30: - source_type = control['sourceControl']['type'] + source_type = control["sourceControl"]["type"] if source_type in self._get_ignore_types(): continue else: field_type_id = source_type field_type = self.get_field_type_by_id(source_type) field = { - 'id': control['controlId'], - 'name': control['controlName'], - 'type': field_type, - 'typeId': field_type_id, - 'description': control['remark'].replace('\n', ' ').replace('\t', ' '), - 'options': self._extract_options(control), + "id": control["controlId"], + "name": control["controlName"], + "type": field_type, + "typeId": field_type_id, + "description": control["remark"].replace("\n", " ").replace("\t", " "), + "options": self._extract_options(control), } fields.append(field) - fields_list.append(f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|") + fields_list.append( + f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|" + ) - fields.append({ - 'id': 'ctime', - 'name': 'Created Time', - 'type': self.get_field_type_by_id(16), - 'typeId': 16, - 'description': '', - 'options': [] - }) + fields.append( + { + "id": "ctime", + "name": "Created Time", + "type": self.get_field_type_by_id(16), + "typeId": 16, + "description": "", + "options": [], + } + ) fields_list.append("|ctime|Created Time|Date|16|||") - return fields, '\n'.join(fields_list) + return fields, "\n".join(fields_list) def _extract_options(self, control: dict) -> list: options = [] - if control['type'] in [9, 10, 11]: - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) - elif control['type'] in [28, 36]: - itemnames = control['advancedSetting'].get('itemnames') - if itemnames and itemnames.startswith('[{'): + if control["type"] in [9, 10, 11]: + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) + elif control["type"] in [28, 36]: + itemnames = control["advancedSetting"].get("itemnames") + if 
itemnames and itemnames.startswith("[{"): try: options = json.loads(itemnames) except json.JSONDecodeError: pass - elif control['type'] == 30: - source_type = control['sourceControl']['type'] + elif control["type"] == 30: + source_type = control["sourceControl"]["type"] if source_type not in self._get_ignore_types(): - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) return options - + def _get_ignore_types(self): - return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} \ No newline at end of file + return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 6bf1caa65e..6b831f3145 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -8,64 +8,66 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetWorksheetPivotDataTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - x_column_fields = tool_parameters.get('x_column_fields', '') - if not x_column_fields or not x_column_fields.startswith('['): - return self.create_text_message('Invalid parameter Column Fields') - y_row_fields = tool_parameters.get('y_row_fields', '') - if y_row_fields and not y_row_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Row Fields') + return self.create_text_message("Invalid parameter Worksheet ID") + x_column_fields = tool_parameters.get("x_column_fields", "") + if not x_column_fields or not x_column_fields.startswith("["): + return self.create_text_message("Invalid parameter Column Fields") + y_row_fields = tool_parameters.get("y_row_fields", "") + if y_row_fields and not y_row_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Row Fields") elif not y_row_fields: - y_row_fields = '[]' - value_fields = tool_parameters.get('value_fields', '') - if not value_fields or not value_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Value Fields') - - host = tool_parameters.get('host', '') + y_row_fields = "[]" + value_fields = tool_parameters.get("value_fields", "") + if not value_fields or not value_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Value Fields") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not 
host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/report/getPivotData" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "options": {"showTotal": True}} try: x_column_fields = json.loads(x_column_fields) - payload['columns'] = x_column_fields + payload["columns"] = x_column_fields y_row_fields = json.loads(y_row_fields) - if y_row_fields: payload['rows'] = y_row_fields + if y_row_fields: + payload["rows"] = y_row_fields value_fields = json.loads(value_fields) - payload['values'] = value_fields - sort_fields = tool_parameters.get('sort_fields', '') - if not sort_fields: sort_fields = '[]' + payload["values"] = value_fields + sort_fields = tool_parameters.get("sort_fields", "") + if not sort_fields: + sort_fields = "[]" sort_fields = json.loads(sort_fields) - if sort_fields: payload['options']['sort'] = sort_fields + if sort_fields: + payload["options"]["sort"] = sort_fields res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('status') != 1: + if res_json.get("status") != 1: return self.create_text_message(f"Failed to get the worksheet pivot data. {res_json['msg']}") - - pivot_json = self.generate_pivot_json(res_json['data']) - pivot_table = self.generate_pivot_table(res_json['data']) - result_type = tool_parameters.get('result_type', '') - text = pivot_table if result_type == 'table' else json.dumps(pivot_json, ensure_ascii=False) + + pivot_json = self.generate_pivot_json(res_json["data"]) + pivot_table = self.generate_pivot_table(res_json["data"]) + result_type = tool_parameters.get("result_type", "") + text = pivot_table if result_type == "table" else json.dumps(pivot_json, ensure_ascii=False) return self.create_text_message(text) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet pivot data, request error: {e}") @@ -75,27 +77,31 @@ class GetWorksheetPivotDataTool(BuiltinTool): return self.create_text_message(f"Failed to get the worksheet pivot data, unexpected error: {e}") def generate_pivot_table(self, data: dict[str, Any]) -> str: - columns = data['metadata']['columns'] - rows = data['metadata']['rows'] - values = data['metadata']['values'] + columns = data["metadata"]["columns"] + rows = data["metadata"]["rows"] + values = data["metadata"]["values"] - rows_data = data['data'] + rows_data = data["data"] - header = ([row['displayName'] for row in rows] if rows else []) + [column['displayName'] for column in columns] + [value['displayName'] for value in values] - line = (['---'] * len(rows) if rows else []) + ['---'] * len(columns) + ['--:'] * len(values) + header = ( + ([row["displayName"] for row in rows] if rows else []) + + [column["displayName"] for column in columns] + + [value["displayName"] for value in values] + ) + line = (["---"] * len(rows) if rows else []) + ["---"] * len(columns) + ["--:"] * len(values) table = [header, line] for row in rows_data: - row_data = [self.replace_pipe(row['rows'][r['controlId']]) for r in rows] if rows else [] - row_data.extend([self.replace_pipe(row['columns'][column['controlId']]) for column in columns]) - row_data.extend([self.replace_pipe(str(row['values'][value['controlId']])) for value in values]) + 
row_data = [self.replace_pipe(row["rows"][r["controlId"]]) for r in rows] if rows else [] + row_data.extend([self.replace_pipe(row["columns"][column["controlId"]]) for column in columns]) + row_data.extend([self.replace_pipe(str(row["values"][value["controlId"]])) for value in values]) table.append(row_data) - return '\n'.join([('|'+'|'.join(row) +'|') for row in table]) - + return "\n".join([("|" + "|".join(row) + "|") for row in table]) + def replace_pipe(self, text: str) -> str: - return text.replace('|', '▏').replace('\n', ' ') - + return text.replace("|", "▏").replace("\n", " ") + def generate_pivot_json(self, data: dict[str, Any]) -> dict: fields = { "x-axis": [ @@ -103,13 +109,14 @@ class GetWorksheetPivotDataTool(BuiltinTool): for column in data["metadata"]["columns"] ], "y-axis": [ - {"fieldId": row["controlId"], "fieldName": row["displayName"]} - for row in data["metadata"]["rows"] - ] if data["metadata"]["rows"] else [], + {"fieldId": row["controlId"], "fieldName": row["displayName"]} for row in data["metadata"]["rows"] + ] + if data["metadata"]["rows"] + else [], "values": [ {"fieldId": value["controlId"], "fieldName": value["displayName"]} for value in data["metadata"]["values"] - ] + ], } # fields = ([ # {"fieldId": row["controlId"], "fieldName": row["displayName"]} @@ -127,4 +134,4 @@ class GetWorksheetPivotDataTool(BuiltinTool): row_data.update(row["columns"]) row_data.update(row["values"]) rows.append(row_data) - return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} \ No newline at end of file + return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index dddc041cc1..592fa230cf 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -9,152 +9,173 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListWorksheetRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') + return self.create_text_message("Invalid parameter App Key") - sign = tool_parameters.get('sign', '') + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') + return self.create_text_message("Invalid parameter Sign") - worksheet_id = tool_parameters.get('worksheet_id', '') + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') + return self.create_text_message("Invalid parameter Worksheet ID") - host = tool_parameters.get('host', '') + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" - + url_fields = 
f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} - field_ids = tool_parameters.get('field_ids', '') + field_ids = tool_parameters.get("field_ids", "") try: res = httpx.post(url_fields, headers=headers, json=payload, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the worksheet information. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to get the worksheet information. {}".format(res_json["error_msg"]) + ) else: - worksheet_name = res_json['data']['name'] - fields, schema, table_header = self.get_schema(res_json['data']['controls'], field_ids) + worksheet_name = res_json["data"]["name"] + fields, schema, table_header = self.get_schema(res_json["data"]["controls"], field_ids) else: return self.create_text_message( - f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}") + f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to get the worksheet information, something went wrong: {}".format(e)) + return self.create_text_message( + "Failed to get the worksheet information, something went wrong: {}".format(e) + ) if field_ids: - payload['controls'] = [v.strip() for v in field_ids.split(',')] if field_ids else [] - filters = tool_parameters.get('filters', '') + payload["controls"] = [v.strip() for v in field_ids.split(",")] if field_ids else [] + filters = tool_parameters.get("filters", "") if filters: - payload['filters'] = json.loads(filters) - sort_id = tool_parameters.get('sort_id', '') - sort_is_asc = tool_parameters.get('sort_is_asc', False) + payload["filters"] = json.loads(filters) + sort_id = tool_parameters.get("sort_id", "") + sort_is_asc = tool_parameters.get("sort_is_asc", False) if sort_id: - payload['sortId'] = sort_id - payload['isAsc'] = sort_is_asc - limit = tool_parameters.get('limit', 50) - payload['pageSize'] = limit - page_index = tool_parameters.get('page_index', 1) - payload['pageIndex'] = page_index - payload['useControlId'] = True - payload['listType'] = 1 + payload["sortId"] = sort_id + payload["isAsc"] = sort_is_asc + limit = tool_parameters.get("limit", 50) + payload["pageSize"] = limit + page_index = tool_parameters.get("page_index", 1) + payload["pageIndex"] = page_index + payload["useControlId"] = True + payload["listType"] = 1 url = f"{host}/v2/open/worksheet/getFilterRows" try: res = httpx.post(url, headers=headers, json=payload, timeout=90) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the records. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message("Failed to get the records. 
{}".format(res_json["error_msg"])) else: result = { "fields": fields, "rows": [], "total": res_json.get("data", {}).get("total"), - "payload": {key: payload[key] for key in ['worksheetId', 'controls', 'filters', 'sortId', 'isAsc', 'pageSize', 'pageIndex'] if key in payload} + "payload": { + key: payload[key] + for key in [ + "worksheetId", + "controls", + "filters", + "sortId", + "isAsc", + "pageSize", + "pageIndex", + ] + if key in payload + }, } rows = res_json.get("data", {}).get("rows", []) - result_type = tool_parameters.get('result_type', '') - if not result_type: result_type = 'table' - if result_type == 'json': + result_type = tool_parameters.get("result_type", "") + if not result_type: + result_type = "table" + if result_type == "json": for row in rows: - result['rows'].append(self.get_row_field_value(row, schema)) + result["rows"].append(self.get_row_field_value(row, schema)) return self.create_text_message(json.dumps(result, ensure_ascii=False)) else: result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." - if result['total'] > 0: + if result["total"] > 0: result_text += f" The following are {result['total'] if result['total'] < limit else limit} pieces of data presented in a table format:\n\n{table_header}" for row in rows: result_values = [] for f in fields: - result_values.append(self.handle_value_type(row[f['fieldId']], schema[f['fieldId']])) - result_text += '\n|'+'|'.join(result_values)+'|' + result_values.append( + self.handle_value_type(row[f["fieldId"]], schema[f["fieldId"]]) + ) + result_text += "\n|" + "|".join(result_values) + "|" return self.create_text_message(result_text) else: return self.create_text_message( - f"Failed to get the records, status code: {res.status_code}, response: {res.text}") + f"Failed to get the records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get the records, something went wrong: {}".format(e)) - def get_row_field_value(self, row: dict, schema: dict): row_value = {"rowid": row["rowid"]} for field in schema: row_value[field] = self.handle_value_type(row[field], schema[field]) return row_value - - def get_schema(self, controls: list, fieldids: str): - allow_fields = {v.strip() for v in fieldids.split(',')} if fieldids else set() + def get_schema(self, controls: list, fieldids: str): + allow_fields = {v.strip() for v in fieldids.split(",")} if fieldids else set() fields = [] schema = {} field_names = [] for control in controls: control_type_id = self.get_real_type_id(control) - if (control_type_id in self._get_ignore_types()) or (allow_fields and not control['controlId'] in allow_fields): + if (control_type_id in self._get_ignore_types()) or ( + allow_fields and control["controlId"] not in allow_fields + ): continue else: - fields.append({'fieldId': control['controlId'], 'fieldName': control['controlName']}) - schema[control['controlId']] = {'typeId': control_type_id, 'options': self.set_option(control)} - field_names.append(control['controlName']) - if (not allow_fields or ('ctime' in allow_fields)): - fields.append({'fieldId': 'ctime', 'fieldName': 'Created Time'}) - schema['ctime'] = {'typeId': 16, 'options': {}} + fields.append({"fieldId": control["controlId"], "fieldName": control["controlName"]}) + schema[control["controlId"]] = {"typeId": control_type_id, "options": self.set_option(control)} + field_names.append(control["controlName"]) + if not allow_fields or ("ctime" in allow_fields): + fields.append({"fieldId": "ctime", 
"fieldName": "Created Time"}) + schema["ctime"] = {"typeId": 16, "options": {}} field_names.append("Created Time") - fields.append({'fieldId':'rowid', 'fieldName': 'Record Row ID'}) - schema['rowid'] = {'typeId': 2, 'options': {}} + fields.append({"fieldId": "rowid", "fieldName": "Record Row ID"}) + schema["rowid"] = {"typeId": 2, "options": {}} field_names.append("Record Row ID") - return fields, schema, '|'+'|'.join(field_names)+'|\n|'+'---|'*len(field_names) - + return fields, schema, "|" + "|".join(field_names) + "|\n|" + "---|" * len(field_names) + def get_real_type_id(self, control: dict) -> int: - return control['sourceControlType'] if control['type'] == 30 else control['type'] - + return control["sourceControlType"] if control["type"] == 30 else control["type"] + def set_option(self, control: dict) -> dict: options = {} - if control.get('options'): - options = {option['key']: option['value'] for option in control['options']} - elif control.get('advancedSetting', {}).get('itemnames'): + if control.get("options"): + options = {option["key"]: option["value"] for option in control["options"]} + elif control.get("advancedSetting", {}).get("itemnames"): try: - itemnames = json.loads(control['advancedSetting']['itemnames']) - options = {item['key']: item['value'] for item in itemnames} + itemnames = json.loads(control["advancedSetting"]["itemnames"]) + options = {item["key"]: item["value"] for item in itemnames} except json.JSONDecodeError: pass return options def _get_ignore_types(self): return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} - + def handle_value_type(self, value, field): type_id = field.get("typeId") if type_id == 10: @@ -167,33 +188,33 @@ class ListWorksheetRecordsTool(BuiltinTool): value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) - return self.rich_text_to_plain_text(value) if value else '' + return self.rich_text_to_plain_text(value) if value else "" def process_value(self, value): if isinstance(value, str): - if value.startswith("[{\"accountId\""): + if value.startswith('[{"accountId"'): value = json.loads(value) - value = ', '.join([item['fullname'] for item in value]) - elif value.startswith("[{\"departmentId\""): + value = ", ".join([item["fullname"] for item in value]) + elif value.startswith('[{"departmentId"'): value = json.loads(value) - value = '、'.join([item['departmentName'] for item in value]) - elif value.startswith("[{\"organizeId\""): + value = "、".join([item["departmentName"] for item in value]) + elif value.startswith('[{"organizeId"'): value = json.loads(value) - value = '、'.join([item['organizeName'] for item in value]) - elif value.startswith("[{\"file_id\""): - value = '' - elif value == '[]': - value = '' - elif hasattr(value, 'accountId'): - value = value['fullname'] + value = "、".join([item["organizeName"] for item in value]) + elif value.startswith('[{"file_id"'): + value = "" + elif value == "[]": + value = "" + elif hasattr(value, "accountId"): + value = value["fullname"] return value def parse_cascade_or_associated(self, field, value): - if (field['typeId'] == 35 and value.startswith('[')) or (field['typeId'] == 29 and value.startswith('[{')): + if (field["typeId"] == 35 and value.startswith("[")) or (field["typeId"] == 29 and value.startswith("[{")): value = json.loads(value) - value = value[0]['name'] if len(value) > 0 else '' + value = value[0]["name"] if len(value) > 0 else "" else: - value = '' + value = "" return value def parse_location(self, value): @@ -205,5 +226,5 @@ class 
ListWorksheetRecordsTool(BuiltinTool): return value def rich_text_to_plain_text(self, rich_text): - text = re.sub(r'<[^>]+>', '', rich_text) if '<' in rich_text else rich_text - return text.replace("|", "▏").replace("\n", " ") \ No newline at end of file + text = re.sub(r"<[^>]+>", "", rich_text) if "<" in rich_text else rich_text + return text.replace("|", "▏").replace("\n", " ") diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 960cbd10ac..4dba2df1f1 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -8,75 +8,76 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListWorksheetsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Sign") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v1/open/app/get" - result_type = tool_parameters.get('result_type', '') + result_type = tool_parameters.get("result_type", "") if not result_type: - result_type = 'table' + result_type = "table" - headers = { 'Content-Type': 'application/json' } - params = { "appKey": appkey, "sign": sign, } + headers = {"Content-Type": "application/json"} + params = { + "appKey": appkey, + "sign": sign, + } try: res = httpx.get(url, headers=headers, params=params, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to access the application. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to access the application. 
{}".format(res_json["error_msg"]) + ) else: - if result_type == 'json': + if result_type == "json": worksheets = [] - for section in res_json['data']['sections']: + for section in res_json["data"]["sections"]: worksheets.extend(self._extract_worksheets(section, result_type)) return self.create_text_message(text=json.dumps(worksheets, ensure_ascii=False)) else: - worksheets = '|worksheetId|worksheetName|description|\n|---|---|---|' - for section in res_json['data']['sections']: + worksheets = "|worksheetId|worksheetName|description|\n|---|---|---|" + for section in res_json["data"]["sections"]: worksheets += self._extract_worksheets(section, result_type) return self.create_text_message(worksheets) else: return self.create_text_message( - f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}") + f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list worksheets, something went wrong: {}".format(e)) def _extract_worksheets(self, section, type): items = [] - tables = '' - for item in section.get('items', []): - if item.get('type') == 0 and (not 'notes' in item or item.get('notes') != 'NO'): - if type == 'json': - filtered_item = { - 'id': item['id'], - 'name': item['name'], - 'notes': item.get('notes', '') - } + tables = "" + for item in section.get("items", []): + if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"): + if type == "json": + filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")} items.append(filtered_item) else: tables += f"\n|{item['id']}|{item['name']}|{item.get('notes', '')}|" - for child_section in section.get('childSections', []): - if type == 'json': - items.extend(self._extract_worksheets(child_section, 'json')) + for child_section in section.get("childSections", []): + if type == "json": + items.extend(self._extract_worksheets(child_section, "json")) else: - tables += self._extract_worksheets(child_section, 'table') - - return items if type == 'json' else tables \ No newline at end of file + tables += self._extract_worksheets(child_section, "table") + + return items if type == "json" else tables diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 6ca1b98d90..32abb18f9a 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -8,44 +8,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class UpdateWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return 
self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Record Row ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/editRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to update the record. {res_json['error_msg']}") return self.create_text_message("Record updated successfully.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index 12e5058cdc..154e15db01 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -10,27 +10,29 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if credentials['api_key'] is None: - credentials['api_key'] = '' + if credentials["api_key"] is None: + credentials["api_key"] = "" else: - result = JinaReaderTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - "url": "https://example.com", - }, - )[0] + result = ( + JinaReaderTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "url": "https://example.com", + }, + )[0] + ) message = json.loads(result.message) - if message['code'] != 200: - raise ToolProviderCredentialValidationError(message['message']) + if message["code"] != 200: + raise ToolProviderCredentialValidationError(message["message"]) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - + def _get_tool_labels(self) -> list[ToolLabelEnum]: - return [ - ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY - ] \ No newline at end of file + return [ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY] diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index cee46cee23..0dd55c6529 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ 
+++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py
@@ -9,26 +9,25 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class JinaReaderTool(BuiltinTool):
-    _jina_reader_endpoint = 'https://r.jina.ai/'
+    _jina_reader_endpoint = "https://r.jina.ai/"

-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        url = tool_parameters['url']
+        url = tool_parameters["url"]

-        headers = {
-            'Accept': 'application/json'
-        }
+        headers = {"Accept": "application/json"}

-        if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
-            headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')
+        if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"):
+            headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key")

-        request_params = tool_parameters.get('request_params')
-        if request_params is not None and request_params != '':
+        request_params = tool_parameters.get("request_params")
+        if request_params is not None and request_params != "":
             try:
                 request_params = json.loads(request_params)
                 if not isinstance(request_params, dict):
@@ -36,40 +35,40 @@ class JinaReaderTool(BuiltinTool):
             except (json.JSONDecodeError, ValueError) as e:
                 raise ValueError(f"Invalid request_params: {e}")

-        target_selector = tool_parameters.get('target_selector')
-        if target_selector is not None and target_selector != '':
-            headers['X-Target-Selector'] = target_selector
+        target_selector = tool_parameters.get("target_selector")
+        if target_selector is not None and target_selector != "":
+            headers["X-Target-Selector"] = target_selector

-        wait_for_selector = tool_parameters.get('wait_for_selector')
-        if wait_for_selector is not None and wait_for_selector != '':
-            headers['X-Wait-For-Selector'] = wait_for_selector
+        wait_for_selector = tool_parameters.get("wait_for_selector")
+        if wait_for_selector is not None and wait_for_selector != "":
+            headers["X-Wait-For-Selector"] = wait_for_selector

-        if tool_parameters.get('image_caption', False):
-            headers['X-With-Generated-Alt'] = 'true'
+        if tool_parameters.get("image_caption", False):
+            headers["X-With-Generated-Alt"] = "true"

-        if tool_parameters.get('gather_all_links_at_the_end', False):
-            headers['X-With-Links-Summary'] = 'true'
+        if tool_parameters.get("gather_all_links_at_the_end", False):
+            headers["X-With-Links-Summary"] = "true"

-        if tool_parameters.get('gather_all_images_at_the_end', False):
-            headers['X-With-Images-Summary'] = 'true'
+        if tool_parameters.get("gather_all_images_at_the_end", False):
+            headers["X-With-Images-Summary"] = "true"

-        proxy_server = tool_parameters.get('proxy_server')
-        if proxy_server is not None and proxy_server != '':
-            headers['X-Proxy-Url'] = proxy_server
+        proxy_server = tool_parameters.get("proxy_server")
+        if proxy_server is not None and proxy_server != "":
+            headers["X-Proxy-Url"] = proxy_server

-        if tool_parameters.get('no_cache', False):
-            headers['X-No-Cache'] = 'true'
+        if tool_parameters.get("no_cache", False):
+            headers["X-No-Cache"] = "true"

-        max_retries = tool_parameters.get('max_retries', 3)
+        max_retries = tool_parameters.get("max_retries", 3)
         response = ssrf_proxy.get(
             str(URL(self._jina_reader_endpoint + url)),
             headers=headers,
             params=request_params,
             timeout=(10, 60),
-            max_retries=max_retries
+            max_retries=max_retries,
         )

-        if tool_parameters.get('summary', False):
+        if tool_parameters.get("summary", False):
             return self.create_text_message(self.summary(user_id, response.text))

         return self.create_text_message(response.text)
diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py
index d4a81cd096..30af6de783 100644
--- a/api/core/tools/provider/builtin/jina/tools/jina_search.py
+++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py
@@ -8,44 +8,39 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class JinaSearchTool(BuiltinTool):
-    _jina_search_endpoint = 'https://s.jina.ai/'
+    _jina_search_endpoint = "https://s.jina.ai/"

     def _invoke(
         self,
         user_id: str,
         tool_parameters: dict[str, Any],
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        query = tool_parameters['query']
+        query = tool_parameters["query"]

-        headers = {
-            'Accept': 'application/json'
-        }
+        headers = {"Accept": "application/json"}

-        if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
-            headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')
+        if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"):
+            headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key")

-        if tool_parameters.get('image_caption', False):
-            headers['X-With-Generated-Alt'] = 'true'
+        if tool_parameters.get("image_caption", False):
+            headers["X-With-Generated-Alt"] = "true"

-        if tool_parameters.get('gather_all_links_at_the_end', False):
-            headers['X-With-Links-Summary'] = 'true'
+        if tool_parameters.get("gather_all_links_at_the_end", False):
+            headers["X-With-Links-Summary"] = "true"

-        if tool_parameters.get('gather_all_images_at_the_end', False):
-            headers['X-With-Images-Summary'] = 'true'
+        if tool_parameters.get("gather_all_images_at_the_end", False):
+            headers["X-With-Images-Summary"] = "true"

-        proxy_server = tool_parameters.get('proxy_server')
-        if proxy_server is not None and proxy_server != '':
-            headers['X-Proxy-Url'] = proxy_server
+        proxy_server = tool_parameters.get("proxy_server")
+        if proxy_server is not None and proxy_server != "":
+            headers["X-Proxy-Url"] = proxy_server

-        if tool_parameters.get('no_cache', False):
-            headers['X-No-Cache'] = 'true'
+        if tool_parameters.get("no_cache", False):
+            headers["X-No-Cache"] = "true"

-        max_retries = tool_parameters.get('max_retries', 3)
+        max_retries = tool_parameters.get("max_retries", 3)
         response = ssrf_proxy.get(
-            str(URL(self._jina_search_endpoint + query)),
-            headers=headers,
-            timeout=(10, 60),
-            max_retries=max_retries
+            str(URL(self._jina_search_endpoint + query)), headers=headers, timeout=(10, 60), max_retries=max_retries
         )

         return self.create_text_message(response.text)
diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py
index 0d018e3ca2..06dabcc9c2 100644
--- a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py
+++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py
@@ -6,33 +6,29 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class JinaTokenizerTool(BuiltinTool):
-    _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/'
+    _jina_tokenizer_endpoint = "https://tokenize.jina.ai/"

     def _invoke(
         self,
         user_id: str,
         tool_parameters: dict[str, Any],
     ) -> ToolInvokeMessage:
-        content = tool_parameters['content']
-        body = {
-            "content": content
-        }
+        content = tool_parameters["content"]
+ body = {"content": content} - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('return_chunks', False): - body['return_chunks'] = True - - if tool_parameters.get('return_tokens', False): - body['return_tokens'] = True - - if tokenizer := tool_parameters.get('tokenizer'): - body['tokenizer'] = tokenizer + if tool_parameters.get("return_chunks", False): + body["return_chunks"] = True + + if tool_parameters.get("return_tokens", False): + body["return_tokens"] = True + + if tokenizer := tool_parameters.get("tokenizer"): + body["tokenizer"] = tokenizer response = ssrf_proxy.post( self._jina_tokenizer_endpoint, diff --git a/api/core/tools/provider/builtin/json_process/json_process.py b/api/core/tools/provider/builtin/json_process/json_process.py index f6eed3c628..10746210b5 100644 --- a/api/core/tools/provider/builtin/json_process/json_process.py +++ b/api/core/tools/provider/builtin/json_process/json_process.py @@ -8,10 +8,9 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - JSONParseTool().invoke(user_id='', - tool_parameters={ - 'content': '{"name": "John", "age": 30, "city": "New York"}', - 'json_filter': '$.name' - }) + JSONParseTool().invoke( + user_id="", + tool_parameters={"content": '{"name": "John", "age": 30, "city": "New York"}', "json_filter": "$.name"}, + ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index 1b49cfe2f3..fcab3d71a9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -8,34 +8,35 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONDeleteTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the JSON delete tool """ # Get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # Get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._delete(content, query, ensure_ascii) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to delete JSON content: {str(e)}') + return 
self.create_text_message(f"Failed to delete JSON content: {str(e)}") def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str: try: input_data = json.loads(origin_json) - expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $ + expr = parse("$." + query.lstrip("$.")) # Ensure query path starts with $ matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 48d1bdcab4..793c74e5f9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -8,46 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get new value - new_value = tool_parameters.get('new_value', '') + new_value = tool_parameters.get("new_value", "") if not new_value: - return self.create_text_message('Invalid parameter new_value') + return self.create_text_message("Invalid parameter new_value") # get insert position - index = tool_parameters.get('index') + index = tool_parameters.get("index") # get create path - create_path = tool_parameters.get('create_path', False) + create_path = tool_parameters.get("create_path", False) # get value decode. 
         # if true, it will be decoded to an dict
-        value_decode = tool_parameters.get('value_decode', False)
+        value_decode = tool_parameters.get("value_decode", False)

-        ensure_ascii = tool_parameters.get('ensure_ascii', True)
+        ensure_ascii = tool_parameters.get("ensure_ascii", True)
         try:
             result = self._insert(content, query, new_value, ensure_ascii, value_decode, index, create_path)
             return self.create_text_message(str(result))
         except Exception:
-            return self.create_text_message('Failed to insert JSON content')
+            return self.create_text_message("Failed to insert JSON content")

-    def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False):
+    def _insert(
+        self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False
+    ):
         try:
             input_data = json.loads(origin_json)
             expr = parse(query)
@@ -61,13 +64,13 @@ class JSONParseTool(BuiltinTool):
             if not matches and create_path:
                 # create new path
-                path_parts = query.strip('$').strip('.').split('.')
+                path_parts = query.strip("$").strip(".").split(".")
                 current = input_data
                 for i, part in enumerate(path_parts):
-                    if '[' in part and ']' in part:
+                    if "[" in part and "]" in part:
                         # process array index
-                        array_name, index = part.split('[')
-                        index = int(index.rstrip(']'))
+                        array_name, index = part.split("[")
+                        index = int(index.rstrip("]"))
                         if array_name not in current:
                             current[array_name] = []
                         while len(current[array_name]) <= index:
diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py
index ecd39113ae..37cae40153 100644
--- a/api/core/tools/provider/builtin/json_process/tools/parse.py
+++ b/api/core/tools/provider/builtin/json_process/tools/parse.py
@@ -8,29 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class JSONParseTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
         # get content
-        content = tool_parameters.get('content', '')
+        content = tool_parameters.get("content", "")
         if not content:
-            return self.create_text_message('Invalid parameter content')
+            return self.create_text_message("Invalid parameter content")

         # get json filter
-        json_filter = tool_parameters.get('json_filter', '')
+        json_filter = tool_parameters.get("json_filter", "")
         if not json_filter:
-            return self.create_text_message('Invalid parameter json_filter')
+            return self.create_text_message("Invalid parameter json_filter")

-        ensure_ascii = tool_parameters.get('ensure_ascii', True)
+        ensure_ascii = tool_parameters.get("ensure_ascii", True)
         try:
             result = self._extract(content, json_filter, ensure_ascii)
             return self.create_text_message(str(result))
         except Exception:
-            return self.create_text_message('Failed to extract JSON content')
+            return self.create_text_message("Failed to extract JSON content")

     # Extract data from JSON content
     def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str:
diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py
index b19198aa93..383825c2d0 100644
--- a/api/core/tools/provider/builtin/json_process/tools/replace.py
+++ b/api/core/tools/provider/builtin/json_process/tools/replace.py
@@ -8,55 +8,60 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class JSONReplaceTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
         # get content
-        content = tool_parameters.get('content', '')
+        content = tool_parameters.get("content", "")
         if not content:
-            return self.create_text_message('Invalid parameter content')
+            return self.create_text_message("Invalid parameter content")

         # get query
-        query = tool_parameters.get('query', '')
+        query = tool_parameters.get("query", "")
         if not query:
-            return self.create_text_message('Invalid parameter query')
+            return self.create_text_message("Invalid parameter query")

         # get replace value
-        replace_value = tool_parameters.get('replace_value', '')
+        replace_value = tool_parameters.get("replace_value", "")
         if not replace_value:
-            return self.create_text_message('Invalid parameter replace_value')
+            return self.create_text_message("Invalid parameter replace_value")

         # get replace model
-        replace_model = tool_parameters.get('replace_model', '')
+        replace_model = tool_parameters.get("replace_model", "")
         if not replace_model:
-            return self.create_text_message('Invalid parameter replace_model')
+            return self.create_text_message("Invalid parameter replace_model")

         # get value decode.
         # if true, it will be decoded to an dict
-        value_decode = tool_parameters.get('value_decode', False)
+        value_decode = tool_parameters.get("value_decode", False)

-        ensure_ascii = tool_parameters.get('ensure_ascii', True)
+        ensure_ascii = tool_parameters.get("ensure_ascii", True)
         try:
-            if replace_model == 'pattern':
+            if replace_model == "pattern":
                 # get replace pattern
-                replace_pattern = tool_parameters.get('replace_pattern', '')
+                replace_pattern = tool_parameters.get("replace_pattern", "")
                 if not replace_pattern:
-                    return self.create_text_message('Invalid parameter replace_pattern')
-                result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii, value_decode)
-            elif replace_model == 'key':
+                    return self.create_text_message("Invalid parameter replace_pattern")
+                result = self._replace_pattern(
+                    content, query, replace_pattern, replace_value, ensure_ascii, value_decode
+                )
+            elif replace_model == "key":
                 result = self._replace_key(content, query, replace_value, ensure_ascii)
-            elif replace_model == 'value':
+            elif replace_model == "value":
                 result = self._replace_value(content, query, replace_value, ensure_ascii, value_decode)
             return self.create_text_message(str(result))
         except Exception:
-            return self.create_text_message('Failed to replace JSON content')
+            return self.create_text_message("Failed to replace JSON content")

     # Replace pattern
-    def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str:
+    def _replace_pattern(
+        self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool
+    ) -> str:
         try:
             input_data = json.loads(content)
             expr = parse(query)
@@ -102,7 +107,9 @@ class JSONReplaceTool(BuiltinTool):
             return str(e)

     # Replace value
-    def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str:
+    def _replace_value(
+        self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool
+    ) -> str:
         try:
             input_data = json.loads(content)
             expr = parse(query)
diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py
index bac6576797..50db74dd9e 100644
--- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py
+++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py
@@ -13,7 +13,7 @@ class Judge0CEProvider(BuiltinToolProviderController):
                     "credentials": credentials,
                 }
             ).invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
                     "source_code": "print('hello world')",
                     "language_id": 71,
@@ -21,4 +21,3 @@ class Judge0CEProvider(BuiltinToolProviderController):
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py
index 6031687c03..b8d654ff63 100644
--- a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py
+++ b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py
@@ -9,11 +9,13 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class ExecuteCodeTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
         invoke tools
         """
-        api_key = self.runtime.credentials['X-RapidAPI-Key']
+        api_key = self.runtime.credentials["X-RapidAPI-Key"]

         url = "https://judge0-ce.p.rapidapi.com/submissions"

@@ -22,15 +24,15 @@ class ExecuteCodeTool(BuiltinTool):
         headers = {
             "Content-Type": "application/json",
             "X-RapidAPI-Key": api_key,
-            "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
+            "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com",
         }

         payload = {
-            "language_id": tool_parameters['language_id'],
-            "source_code": tool_parameters['source_code'],
-            "stdin": tool_parameters.get('stdin', ''),
-            "expected_output": tool_parameters.get('expected_output', ''),
-            "additional_files": tool_parameters.get('additional_files', ''),
+            "language_id": tool_parameters["language_id"],
+            "source_code": tool_parameters["source_code"],
+            "stdin": tool_parameters.get("stdin", ""),
+            "expected_output": tool_parameters.get("expected_output", ""),
+            "additional_files": tool_parameters.get("additional_files", ""),
         }

         response = post(url, data=json.dumps(payload), headers=headers, params=querystring)
@@ -38,22 +40,22 @@ class ExecuteCodeTool(BuiltinTool):
         if response.status_code != 201:
             raise Exception(response.text)

-        token = response.json()['token']
+        token = response.json()["token"]

         url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}"
-        headers = {
-            "X-RapidAPI-Key": api_key
-        }
-
+        headers = {"X-RapidAPI-Key": api_key}
+
         response = requests.get(url, headers=headers)
         if response.status_code == 200:
             result = response.json()
-            return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n"
-                                                 f"stderr: {result.get('stderr', '')}\n"
-                                                 f"compile_output: {result.get('compile_output', '')}\n"
-                                                 f"message: {result.get('message', '')}\n"
-                                                 f"status: {result['status']['description']}\n"
-                                                 f"time: {result.get('time', '')} seconds\n"
-                                                 f"memory: {result.get('memory', '')} bytes")
+            return self.create_text_message(
+                text=f"stdout: {result.get('stdout', '')}\n"
+                f"stderr: {result.get('stderr', '')}\n"
+                f"compile_output: {result.get('compile_output', '')}\n"
+                f"message: {result.get('message', '')}\n"
+                f"status: {result['status']['description']}\n"
+                f"time: {result.get('time', '')} seconds\n"
+                f"memory: {result.get('memory', '')} bytes"
+            )
         else:
-            return self.create_text_message(text=f"Error retrieving submission details: {response.text}")
\ No newline at end of file
+            return self.create_text_message(text=f"Error retrieving submission details: {response.text}")
diff --git a/api/core/tools/provider/builtin/maths/maths.py b/api/core/tools/provider/builtin/maths/maths.py
index 7226a5c168..d4b449ec87 100644
--- a/api/core/tools/provider/builtin/maths/maths.py
+++ b/api/core/tools/provider/builtin/maths/maths.py
@@ -9,9 +9,9 @@ class MathsProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
             EvaluateExpressionTool().invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
-                    'expression': '1+(2+3)*4',
+                    "expression": "1+(2+3)*4",
                 },
             )
         except Exception as e:
diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py
index bf73ed6918..0c5b5e41cb 100644
--- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py
+++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py
@@ -8,22 +8,23 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class EvaluateExpressionTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
         # get expression
-        expression = tool_parameters.get('expression', '').strip()
+        expression = tool_parameters.get("expression", "").strip()
         if not expression:
-            return self.create_text_message('Invalid expression')
+            return self.create_text_message("Invalid expression")

         try:
             result = ne.evaluate(expression)
             result_str = str(result)
         except Exception as e:
-            logging.exception(f'Error evaluating expression: {expression}')
-            return self.create_text_message(f'Invalid expression: {expression}, error: {str(e)}')
-        return self.create_text_message(f'The result of the expression "{expression}" is {result_str}')
\ No newline at end of file
+            logging.exception(f"Error evaluating expression: {expression}")
+            return self.create_text_message(f"Invalid expression: {expression}, error: {str(e)}")
+        return self.create_text_message(f'The result of the expression "{expression}" is {result_str}')
diff --git a/api/core/tools/provider/builtin/nominatim/nominatim.py b/api/core/tools/provider/builtin/nominatim/nominatim.py
index b6f29b5feb..5a24bed750 100644
--- a/api/core/tools/provider/builtin/nominatim/nominatim.py
+++ b/api/core/tools/provider/builtin/nominatim/nominatim.py
@@ -8,16 +8,20 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 class NominatimProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
-            result = NominatimSearchTool().fork_tool_runtime(
-                runtime={
-                    "credentials": credentials,
-                }
-            ).invoke(
-                user_id='',
-                tool_parameters={
-                    'query': 'London',
-                    'limit': 1,
-                },
+            result = (
+                NominatimSearchTool()
+                .fork_tool_runtime(
+                    runtime={
+                        "credentials": credentials,
+                    }
+                )
+                .invoke(
+                    user_id="",
+                    tool_parameters={
+                        "query": "London",
+                        "limit": 1,
+                    },
+                )
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py
index e21ce14f54..ffa8ad0fcc 100644
--- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py
+++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py
@@ -8,40 +8,33 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class NominatimLookupTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        osm_ids = tool_parameters.get('osm_ids', '')
-
-        if not osm_ids:
-            return self.create_text_message('Please provide OSM IDs')
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        osm_ids = tool_parameters.get("osm_ids", "")

-        params = {
-            'osm_ids': osm_ids,
-            'format': 'json',
-            'addressdetails': 1
-        }
-
-        return self._make_request(user_id, 'lookup', params)
+        if not osm_ids:
+            return self.create_text_message("Please provide OSM IDs")
+
+        params = {"osm_ids": osm_ids, "format": "json", "addressdetails": 1}
+
+        return self._make_request(user_id, "lookup", params)

     def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage:
-        base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org')
-
+        base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org")
+
         try:
-            headers = {
-                "User-Agent": "DifyNominatimTool/1.0"
-            }
+            headers = {"User-Agent": "DifyNominatimTool/1.0"}

             s = requests.session()
-            response = s.request(
-                method='GET',
-                headers=headers,
-                url=f"{base_url}/{endpoint}",
-                params=params
-            )
+            response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params)
             response_data = response.json()
-
+
             if response.status_code == 200:
                 s.close()
-                return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)))
+                return self.create_text_message(
+                    self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))
+                )
             else:
                 return self.create_text_message(f"Error: {response.status_code} - {response.text}")
         except Exception as e:
-            return self.create_text_message(f"An error occurred: {str(e)}")
\ No newline at end of file
+            return self.create_text_message(f"An error occurred: {str(e)}")
diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py
index 438d5219e9..f46691e1a3 100644
--- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py
+++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py
@@ -8,42 +8,34 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class NominatimReverseTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        lat = tool_parameters.get('lat')
-        lon = tool_parameters.get('lon')
-
-        if lat is None or lon is None:
-            return self.create_text_message('Please provide both latitude and longitude')
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        lat = tool_parameters.get("lat")
+        lon = tool_parameters.get("lon")

-        params = {
-            'lat': lat,
-            'lon': lon,
-            'format': 'json',
-            'addressdetails': 1
-        }
-
-        return self._make_request(user_id, 'reverse', params)
+        if lat is None or lon is None:
+            return self.create_text_message("Please provide both latitude and longitude")
+
+        params = {"lat": lat, "lon": lon, "format": "json", "addressdetails": 1}
+
+        return self._make_request(user_id, "reverse", params)

     def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage:
-        base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org')
-
+        base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org")
+
         try:
-            headers = {
-                "User-Agent": "DifyNominatimTool/1.0"
-            }
+            headers = {"User-Agent": "DifyNominatimTool/1.0"}

             s = requests.session()
-            response = s.request(
-                method='GET',
-                headers=headers,
-                url=f"{base_url}/{endpoint}",
-                params=params
-            )
+            response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params)
             response_data = response.json()
-
+
             if response.status_code == 200:
                 s.close()
-                return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)))
+                return self.create_text_message(
+                    self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))
+                )
             else:
                 return self.create_text_message(f"Error: {response.status_code} - {response.text}")
         except Exception as e:
-            return self.create_text_message(f"An error occurred: {str(e)}")
\ No newline at end of file
+            return self.create_text_message(f"An error occurred: {str(e)}")
diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py
index 983cbc0e34..34851d86dc 100644
--- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py
+++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py
@@ -8,42 +8,34 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class NominatimSearchTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        query = tool_parameters.get('query', '')
-        limit = tool_parameters.get('limit', 10)
-
-        if not query:
-            return self.create_text_message('Please input a search query')
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        query = tool_parameters.get("query", "")
+        limit = tool_parameters.get("limit", 10)

-        params = {
-            'q': query,
-            'format': 'json',
-            'limit': limit,
-            'addressdetails': 1
-        }
-
-        return self._make_request(user_id, 'search', params)
+        if not query:
+            return self.create_text_message("Please input a search query")
+
+        params = {"q": query, "format": "json", "limit": limit, "addressdetails": 1}
+
+        return self._make_request(user_id, "search", params)

     def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage:
-        base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org')
-
+        base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org")
+
         try:
-            headers = {
-                "User-Agent": "DifyNominatimTool/1.0"
-            }
+            headers = {"User-Agent": "DifyNominatimTool/1.0"}

             s = requests.session()
-            response = s.request(
-                method='GET',
-                headers=headers,
-                url=f"{base_url}/{endpoint}",
-                params=params
-            )
+            response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params)
             response_data = response.json()
-
+
             if response.status_code == 200:
                 s.close()
-                return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)))
+                return self.create_text_message(
+                    self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))
+                )
             else:
                 return self.create_text_message(f"Error: {response.status_code} - {response.text}")
         except Exception as e:
-            return self.create_text_message(f"An error occurred: {str(e)}")
\ No newline at end of file
+            return self.create_text_message(f"An error occurred: {str(e)}")
diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
index b753be4791..762e158459 100644
--- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
+++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
@@ -12,10 +12,10 @@ class NovitaAiToolBase:
         if not loras_str:
             return []

-        loras_ori_list = lora_str.strip().split(';')
+        loras_ori_list = lora_str.strip().split(";")
         result_list = []
         for lora_str in loras_ori_list:
-            lora_info = lora_str.strip().split(',')
+            lora_info = lora_str.strip().split(",")
             lora = Txt2ImgV3LoRA(
                 model_name=lora_info[0].strip(),
                 strength=float(lora_info[1]),
@@ -28,43 +28,39 @@ class NovitaAiToolBase:
         if not embeddings_str:
             return []

-        embeddings_ori_list = embeddings_str.strip().split(';')
+        embeddings_ori_list = embeddings_str.strip().split(";")
         result_list = []
         for embedding_str in embeddings_ori_list:
-            embedding = Txt2ImgV3Embedding(
-                model_name=embedding_str.strip()
-            )
+            embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip())
             result_list.append(embedding)

         return result_list

     def _extract_hires_fix(self, hires_fix_str: str):
-        hires_fix_info = hires_fix_str.strip().split(',')
-        if 'upscaler' in hires_fix_info:
+        hires_fix_info = hires_fix_str.strip().split(",")
+        if "upscaler" in hires_fix_info:
             hires_fix = Txt2ImgV3HiresFix(
                 target_width=int(hires_fix_info[0]),
                 target_height=int(hires_fix_info[1]),
                 strength=float(hires_fix_info[2]),
-                upscaler=hires_fix_info[3].strip()
+                upscaler=hires_fix_info[3].strip(),
             )
         else:
             hires_fix = Txt2ImgV3HiresFix(
                 target_width=int(hires_fix_info[0]),
                 target_height=int(hires_fix_info[1]),
-                strength=float(hires_fix_info[2])
+                strength=float(hires_fix_info[2]),
             )

         return hires_fix

     def _extract_refiner(self, switch_at: str):
-        refiner = Txt2ImgV3Refiner(
-            switch_at=float(switch_at)
-        )
+        refiner = Txt2ImgV3Refiner(switch_at=float(switch_at))
         return refiner

     def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
         """
-            is hit nsfw
+        is hit nsfw
         """
         if image.nsfw_detection_result is None:
             return False
diff --git a/api/core/tools/provider/builtin/novitaai/novitaai.py b/api/core/tools/provider/builtin/novitaai/novitaai.py
index 1e7d9757c3..d5e32eff29 100644
--- a/api/core/tools/provider/builtin/novitaai/novitaai.py
+++ b/api/core/tools/provider/builtin/novitaai/novitaai.py
@@ -8,23 +8,27 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 class NovitaAIProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
-            result = NovitaAiTxt2ImgTool().fork_tool_runtime(
-                runtime={
-                    "credentials": credentials,
-                }
-            ).invoke(
-                user_id='',
-                tool_parameters={
-                    'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors',
-                    'prompt': 'a futuristic city with flying cars',
-                    'negative_prompt': '',
-                    'width': 128,
-                    'height': 128,
-                    'image_num': 1,
-                    'guidance_scale': 7.5,
-                    'seed': -1,
-                    'steps': 1,
-                },
+            result = (
+                NovitaAiTxt2ImgTool()
+                .fork_tool_runtime(
+                    runtime={
+                        "credentials": credentials,
+                    }
+                )
+                .invoke(
+                    user_id="",
+                    tool_parameters={
+                        "model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors",
+                        "prompt": "a futuristic city with flying cars",
+                        "negative_prompt": "",
+                        "width": 128,
+                        "height": 128,
+                        "image_num": 1,
+                        "guidance_scale": 7.5,
+                        "seed": -1,
+                        "steps": 1,
+                    },
+                )
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
index e63c891957..0b4f2edff3 100644
--- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
+++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
@@ -12,17 +12,18 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class NovitaAiCreateTileTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
+        if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
             raise ToolProviderCredentialValidationError("Novita AI API Key is required.")

-        api_key = self.runtime.credentials.get('api_key')
+        api_key = self.runtime.credentials.get("api_key")

         client = NovitaClient(api_key=api_key)
         param = self._process_parameters(tool_parameters)
@@ -30,21 +31,23 @@ class NovitaAiCreateTileTool(BuiltinTool):

         results = []
         results.append(
-            self.create_blob_message(blob=b64decode(client_result.image_file),
-                                     meta={'mime_type': f'image/{client_result.image_type}'},
-                                     save_as=self.VARIABLE_KEY.IMAGE.value)
+            self.create_blob_message(
+                blob=b64decode(client_result.image_file),
+                meta={"mime_type": f"image/{client_result.image_type}"},
+                save_as=self.VariableKey.IMAGE.value,
+            )
         )

         return results

     def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
-            process parameters
+        process parameters
         """
         res_parameters = deepcopy(parameters)

         # delete none and empty
-        keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == '']
+        keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""]
         for k in keys_to_delete:
             del res_parameters[k]

diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
index ec2927675e..fe105f70a7 100644
--- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
+++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
@@ -12,127 +12,137 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class NovitaAiModelQueryTool(BuiltinTool):
-    _model_query_endpoint = 'https://api.novita.ai/v3/model'
+    _model_query_endpoint = "https://api.novita.ai/v3/model"

-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer " + api_key - } + api_key = self.runtime.credentials.get("api_key") + headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key} params = self._process_parameters(tool_parameters) - result_type = params.get('result_type') - del params['result_type'] + result_type = params.get("result_type") + del params["result_type"] models_data = self._query_models( models_data=[], headers=headers, params=params, - recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True + recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True, ) - result_str = '' - if result_type == 'first sd_name': - result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else '' - elif result_type == 'first name sd_name pair': - result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else '' - elif result_type == 'sd_name array': - sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else [] + result_str = "" + if result_type == "first sd_name": + result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else "" + elif result_type == "first name sd_name pair": + result_str = ( + json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]}) + if len(models_data) > 0 + else "" + ) + elif result_type == "sd_name array": + sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(sd_name_array) - elif result_type == 'name array': - name_array = [model['name'] for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name array": + name_array = [model["name"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(name_array) - elif result_type == 'name sd_name pair array': - name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} - for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name sd_name pair array": + name_sd_name_pair_array = ( + [{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data] + if len(models_data) > 0 + else [] + ) result_str = json.dumps(name_sd_name_pair_array) - elif result_type == 'whole info array': + elif result_type == "whole info array": result_str = json.dumps(models_data) else: raise NotImplementedError return self.create_text_message(result_str) - def _query_models(self, models_data: list, headers: dict[str, Any], - params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list: + def _query_models( + self, + models_data: list, + headers: dict[str, Any], + params: dict[str, Any], + pagination_cursor: str = "", + recursive: bool = True, + ) -> list: """ - query models + query models """ inside_params = deepcopy(params) - if pagination_cursor != '': - inside_params['pagination.cursor'] = pagination_cursor + if pagination_cursor != "": + inside_params["pagination.cursor"] = pagination_cursor response = ssrf_proxy.get( - url=str(URL(self._model_query_endpoint)), - headers=headers, - params=params, - timeout=(10, 60) + url=str(URL(self._model_query_endpoint)), 
headers=headers, params=inside_params, timeout=(10, 60)
+        )
 
         res_data = response.json()
-        models_data.extend(res_data['models'])
+        models_data.extend(res_data["models"])
 
-        res_data_len = len(res_data['models'])
-        if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False:
+        res_data_len = len(res_data["models"])
+        if res_data_len == 0 or res_data_len < int(params["pagination.limit"]) or recursive is False:
             # deduplicate
             df = DataFrame.from_dict(models_data)
-            df_unique = df.drop_duplicates(subset=['id'])
-            models_data = df_unique.to_dict('records')
+            df_unique = df.drop_duplicates(subset=["id"])
+            models_data = df_unique.to_dict("records")
             return models_data
 
         return self._query_models(
             models_data=models_data,
             headers=headers,
             params=inside_params,
-            pagination_cursor=res_data['pagination']['next_cursor']
+            pagination_cursor=res_data["pagination"]["next_cursor"],
         )
 
     def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
-            process parameters
+        process parameters
         """
         process_parameters = deepcopy(parameters)
         res_parameters = {}
 
         # delete none or empty
-        keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == '']
+        keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""]
         for k in keys_to_delete:
             del process_parameters[k]
 
-        if 'query' in process_parameters and process_parameters.get('query') != 'unspecified':
-            res_parameters['filter.query'] = process_parameters['query']
+        if "query" in process_parameters and process_parameters.get("query") != "unspecified":
+            res_parameters["filter.query"] = process_parameters["query"]
 
-        if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified':
-            res_parameters['filter.visibility'] = process_parameters['visibility']
+        if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified":
+            res_parameters["filter.visibility"] = process_parameters["visibility"]
 
-        if 'source' in process_parameters and process_parameters.get('source') != 'unspecified':
-            res_parameters['filter.source'] = process_parameters['source']
+        if "source" in process_parameters and process_parameters.get("source") != "unspecified":
+            res_parameters["filter.source"] = process_parameters["source"]
 
-        if 'type' in process_parameters and process_parameters.get('type') != 'unspecified':
-            res_parameters['filter.types'] = process_parameters['type']
+        if "type" in process_parameters and process_parameters.get("type") != "unspecified":
+            res_parameters["filter.types"] = process_parameters["type"]
 
-        if 'is_sdxl' in process_parameters:
-            if process_parameters['is_sdxl'] == 'true':
-                res_parameters['filter.is_sdxl'] = True
-            elif process_parameters['is_sdxl'] == 'false':
-                res_parameters['filter.is_sdxl'] = False
+        if "is_sdxl" in process_parameters:
+            if process_parameters["is_sdxl"] == "true":
+                res_parameters["filter.is_sdxl"] = True
+            elif process_parameters["is_sdxl"] == "false":
+                res_parameters["filter.is_sdxl"] = False
 
-        res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name')
+        res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name")
 
-        res_parameters['pagination.limit'] = 1 \
-            if res_parameters.get('result_type') == 'first sd_name' \
-            or res_parameters.get('result_type') == 'first name sd_name pair'\
+        res_parameters["pagination.limit"] = (
+            1
+            if res_parameters.get("result_type") == "first sd_name"
+            or res_parameters.get("result_type") == "first name sd_name pair"
else 100 + ) return res_parameters diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 5fef3d2da7..9c61eab9f9 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -13,17 +13,18 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -32,56 +33,58 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): results = [] for image_encoded, image in zip(client_result.images_encoded, client_result.images): if self._is_hit_nsfw_detection(image, 0.8): - results = self.create_text_message(text='NSFW detected!') + results = self.create_text_message(text="NSFW detected!") break results.append( - self.create_blob_message(blob=b64decode(image_encoded), - meta={'mime_type': f'image/{image.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(image_encoded), + meta={"mime_type": f"image/{image.image_type}"}, + save_as=self.VariableKey.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] - if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0: - del res_parameters['clip_skip'] + if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0: + del res_parameters["clip_skip"] - if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0: - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0: + del res_parameters["refiner_switch_at"] - if 'enabled_enterprise_plan' in res_parameters: - res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']} - del res_parameters['enabled_enterprise_plan'] + if "enabled_enterprise_plan" in res_parameters: + res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]} + del res_parameters["enabled_enterprise_plan"] - if 'nsfw_detection_level' in res_parameters: - res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level']) + if "nsfw_detection_level" in res_parameters: + res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"]) # process loras - if 'loras' in res_parameters: - 
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras')) + if "loras" in res_parameters: + res_parameters["loras"] = self._extract_loras(res_parameters.get("loras")) # process embeddings - if 'embeddings' in res_parameters: - res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings')) + if "embeddings" in res_parameters: + res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings")) # process hires_fix - if 'hires_fix' in res_parameters: - res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix')) + if "hires_fix" in res_parameters: + res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix")) # process refiner - if 'refiner_switch_at' in res_parameters: - res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at')) - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters: + res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at")) + del res_parameters["refiner_switch_at"] return res_parameters diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py index 42f321e919..b8e5ed24d6 100644 --- a/api/core/tools/provider/builtin/onebot/onebot.py +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -5,8 +5,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class OneBotProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - if not credentials.get("ob11_http_url"): - raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') + raise ToolProviderCredentialValidationError("OneBot HTTP URL is required.") diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py index 2a1a9f86de..9c95bbc2ae 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -11,54 +11,29 @@ class SendGroupMsg(BuiltinTool): """OneBot v11 Tool: Send Group Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_group_id = tool_parameters.get('group_id', '') - - message = tool_parameters.get('message', '') + send_group_id = tool_parameters.get("group_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' 
- } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_group_msg" resp = requests.post( url, - json={ - 'group_id': send_group_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"group_id": send_group_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send group message: {resp.text}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {resp.text}"}) - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send group message: {e}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {e}"}) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py index 8ef4d72ab6..1174c7f07d 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -11,54 +11,29 @@ class SendPrivateMsg(BuiltinTool): """OneBot v11 Tool: Send Private Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_user_id = tool_parameters.get('user_id', '') - - message = tool_parameters.get('message', '') + send_user_id = tool_parameters.get("user_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' 
- } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_private_msg" resp = requests.post( url, - json={ - 'user_id': send_user_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"user_id": send_user_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send private message: {resp.text}' - } - ) - - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"error": f"Failed to send private message: {resp.text}"}) + + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send private message: {e}' - } - ) \ No newline at end of file + return self.create_json_message({"error": f"Failed to send private message: {e}"}) diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py index a2827177a3..9e40249aba 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.py +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -5,7 +5,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): - url = "https://api.openweathermap.org/data/2.5/weather" params = {"q": city, "appid": api_key, "units": units, "lang": language} @@ -16,21 +15,15 @@ class OpenweatherProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: if "api_key" not in credentials or not credentials.get("api_key"): - raise ToolProviderCredentialValidationError( - "Open weather API key is required." - ) + raise ToolProviderCredentialValidationError("Open weather API key is required.") apikey = credentials.get("api_key") try: response = query_weather(api_key=apikey) if response.status_code == 200: pass else: - raise ToolProviderCredentialValidationError( - (response.json()).get("info") - ) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: - raise ToolProviderCredentialValidationError( - "Open weather API Key is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("Open weather API Key is invalid. 
{}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py index 536a3511f4..ed4ec487fa 100644 --- a/api/core/tools/provider/builtin/openweather/tools/weather.py +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -17,10 +17,7 @@ class OpenweatherTool(BuiltinTool): city = tool_parameters.get("city", "") if not city: return self.create_text_message("Please tell me your city") - if ( - "api_key" not in self.runtime.credentials - or not self.runtime.credentials.get("api_key") - ): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("OpenWeather API key is required.") units = tool_parameters.get("units", "metric") @@ -29,7 +26,7 @@ class OpenweatherTool(BuiltinTool): # request URL url = "https://api.openweathermap.org/data/2.5/weather" - # request parmas + # request params params = { "q": city, "appid": self.runtime.credentials.get("api_key"), @@ -39,12 +36,9 @@ class OpenweatherTool(BuiltinTool): response = requests.get(url, params=params) if response.status_code == 200: - data = response.json() return self.create_text_message( - self.summary( - user_id=user_id, content=json.dumps(data, ensure_ascii=False) - ) + self.summary(user_id=user_id, content=json.dumps(data, ensure_ascii=False)) ) else: error_message = { @@ -55,6 +49,4 @@ class OpenweatherTool(BuiltinTool): return json.dumps(error_message) except Exception as e: - return self.create_text_message( - "Openweather API Key is invalid. {}".format(e) - ) + return self.create_text_message("Openweather API Key is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/perplexity/_assets/icon.svg b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg new file mode 100644 index 0000000000..c2974c142f --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py new file mode 100644 index 0000000000..80518853fb --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -0,0 +1,38 @@ +from typing import Any + +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PerplexityProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + headers = { + "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", + "Content-Type": "application/json", + } + + payload = { + "model": "llama-3.1-sonar-small-128k-online", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ], + "max_tokens": 5, + "temperature": 0.1, + "top_p": 0.9, + "stream": False, + } + + try: + response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + except requests.RequestException as e: + raise ToolProviderCredentialValidationError(f"Failed to validate Perplexity API key: {str(e)}") + + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + f"Perplexity API key is invalid. 
Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.yaml b/api/core/tools/provider/builtin/perplexity/perplexity.yaml new file mode 100644 index 0000000000..c0b504f300 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.yaml @@ -0,0 +1,26 @@ +identity: + author: Dify + name: perplexity + label: + en_US: Perplexity + zh_Hans: Perplexity + description: + en_US: Perplexity.AI + zh_Hans: Perplexity.AI + icon: icon.svg + tags: + - search +credentials_for_provider: + perplexity_api_key: + type: secret-input + required: true + label: + en_US: Perplexity API key + zh_Hans: Perplexity API key + placeholder: + en_US: Please input your Perplexity API key + zh_Hans: 请输入你的 Perplexity API key + help: + en_US: Get your Perplexity API key from Perplexity + zh_Hans: 从 Perplexity 获取您的 Perplexity API key + url: https://www.perplexity.ai/settings/api diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py new file mode 100644 index 0000000000..5ed4b9ca99 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -0,0 +1,67 @@ +import json +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + + +class PerplexityAITool(BuiltinTool): + def _parse_response(self, response: dict) -> dict: + """Parse the response from Perplexity AI API""" + if "choices" in response and len(response["choices"]) > 0: + message = response["choices"][0]["message"] + return { + "content": message.get("content", ""), + "role": message.get("role", ""), + "citations": response.get("citations", []), + } + else: + return {"content": "Unable to get a valid response", "role": "assistant", "citations": []} + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", + "Content-Type": "application/json", + } + + payload = { + "model": tool_parameters.get("model", "llama-3.1-sonar-small-128k-online"), + "messages": [ + {"role": "system", "content": "Be precise and concise."}, + {"role": "user", "content": tool_parameters["query"]}, + ], + "max_tokens": tool_parameters.get("max_tokens", 4096), + "temperature": tool_parameters.get("temperature", 0.7), + "top_p": tool_parameters.get("top_p", 1), + "top_k": tool_parameters.get("top_k", 5), + "presence_penalty": tool_parameters.get("presence_penalty", 0), + "frequency_penalty": tool_parameters.get("frequency_penalty", 1), + "stream": False, + } + + if "search_recency_filter" in tool_parameters: + payload["search_recency_filter"] = tool_parameters["search_recency_filter"] + if "return_citations" in tool_parameters: + payload["return_citations"] = tool_parameters["return_citations"] + if "search_domain_filter" in tool_parameters: + if isinstance(tool_parameters["search_domain_filter"], str): + payload["search_domain_filter"] = [tool_parameters["search_domain_filter"]] + elif isinstance(tool_parameters["search_domain_filter"], list): + payload["search_domain_filter"] = tool_parameters["search_domain_filter"] + + response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + valuable_res = 
self._parse_response(response.json()) + + return [ + self.create_json_message(valuable_res), + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)), + ] diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml new file mode 100644 index 0000000000..02a645df33 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml @@ -0,0 +1,178 @@ +identity: + name: perplexity + author: Dify + label: + en_US: Perplexity Search +description: + human: + en_US: Search information using Perplexity AI's language models. + llm: This tool is used to search information using Perplexity AI's language models. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 查询 + human_description: + en_US: The text query to be processed by the AI model. + zh_Hans: 要由 AI 模型处理的文本查询。 + form: llm + - name: model + type: select + required: false + label: + en_US: Model Name + zh_Hans: 模型名称 + human_description: + en_US: The Perplexity AI model to use for generating the response. + zh_Hans: 用于生成响应的 Perplexity AI 模型。 + form: form + default: "llama-3.1-sonar-small-128k-online" + options: + - value: llama-3.1-sonar-small-128k-online + label: + en_US: llama-3.1-sonar-small-128k-online + zh_Hans: llama-3.1-sonar-small-128k-online + - value: llama-3.1-sonar-large-128k-online + label: + en_US: llama-3.1-sonar-large-128k-online + zh_Hans: llama-3.1-sonar-large-128k-online + - value: llama-3.1-sonar-huge-128k-online + label: + en_US: llama-3.1-sonar-huge-128k-online + zh_Hans: llama-3.1-sonar-huge-128k-online + - name: max_tokens + type: number + required: false + label: + en_US: Max Tokens + zh_Hans: 最大令牌数 + pt_BR: Máximo de Tokens + human_description: + en_US: The maximum number of tokens to generate in the response. + zh_Hans: 在响应中生成的最大令牌数。 + pt_BR: O número máximo de tokens a serem gerados na resposta. + form: form + default: 4096 + min: 1 + max: 4096 + - name: temperature + type: number + required: false + label: + en_US: Temperature + zh_Hans: 温度 + pt_BR: Temperatura + human_description: + en_US: Controls randomness in the output. Lower values make the output more focused and deterministic. + zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。 + form: form + default: 0.7 + min: 0 + max: 1 + - name: top_k + type: number + required: false + label: + en_US: Top K + zh_Hans: 取样数量 + human_description: + en_US: The number of top results to consider for response generation. + zh_Hans: 用于生成响应的顶部结果数量。 + form: form + default: 5 + min: 1 + max: 100 + - name: top_p + type: number + required: false + label: + en_US: Top P + zh_Hans: Top P + human_description: + en_US: Controls diversity via nucleus sampling. + zh_Hans: 通过核心采样控制多样性。 + form: form + default: 1 + min: 0.1 + max: 1 + step: 0.1 + - name: presence_penalty + type: number + required: false + label: + en_US: Presence Penalty + zh_Hans: 存在惩罚 + human_description: + en_US: Positive values penalize new tokens based on whether they appear in the text so far. + zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。 + form: form + default: 0 + min: -1.0 + max: 1.0 + step: 0.1 + - name: frequency_penalty + type: number + required: false + label: + en_US: Frequency Penalty + zh_Hans: 频率惩罚 + human_description: + en_US: Positive values penalize new tokens based on their existing frequency in the text so far. 
+ zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。 + form: form + default: 1 + min: 0.1 + max: 1.0 + step: 0.1 + - name: return_citations + type: boolean + required: false + label: + en_US: Return Citations + zh_Hans: 返回引用 + human_description: + en_US: Whether to return citations in the response. + zh_Hans: 是否在响应中返回引用。 + form: form + default: true + - name: search_domain_filter + type: string + required: false + label: + en_US: Search Domain Filter + zh_Hans: 搜索域过滤器 + human_description: + en_US: Domain to filter the search results. + zh_Hans: 用于过滤搜索结果的域名。 + form: form + default: "" + - name: search_recency_filter + type: select + required: false + label: + en_US: Search Recency Filter + zh_Hans: 搜索时间过滤器 + human_description: + en_US: Filter for search results based on recency. + zh_Hans: 基于时间筛选搜索结果。 + form: form + default: "month" + options: + - value: day + label: + en_US: Day + zh_Hans: 天 + - value: week + label: + en_US: Week + zh_Hans: 周 + - value: month + label: + en_US: Month + zh_Hans: 月 + - value: year + label: + en_US: Year + zh_Hans: 年 diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 05cd171b87..ea3a477c30 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -11,11 +11,10 @@ class PubMedProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 58811d65e6..fedfdbd859 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -51,17 +51,12 @@ class PubMedAPIWrapper(BaseModel): try: # Retrieve the top-k results for the query docs = [ - f"Published: {result['pub_date']}\nTitle: {result['title']}\n" - f"Summary: {result['summary']}" + f"Published: {result['pub_date']}\nTitle: {result['title']}\n" f"Summary: {result['summary']}" for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH]) ] # Join the results and limit the character count - return ( - "\n\n".join(docs)[:self.doc_content_chars_max] - if docs - else "No good PubMed Result was found" - ) + return "\n\n".join(docs)[: self.doc_content_chars_max] if docs else "No good PubMed Result was found" except Exception as ex: return f"PubMed exception: {ex}" @@ -91,13 +86,7 @@ class PubMedAPIWrapper(BaseModel): return articles def retrieve_article(self, uid: str, webenv: str) -> dict: - url = ( - self.base_url_efetch - + "db=pubmed&retmode=xml&id=" - + uid - + "&webenv=" - + webenv - ) + url = self.base_url_efetch + "db=pubmed&retmode=xml&id=" + uid + "&webenv=" + webenv retry = 0 while True: @@ -108,10 +97,7 @@ class PubMedAPIWrapper(BaseModel): if e.code == 429 and retry < self.max_retry: # Too Many Requests error # wait for an exponentially increasing amount of time - print( - f"Too Many Requests, " - f"waiting for {self.sleep_time:.2f} seconds..." 
-                )
+                print(f"Too Many Requests, " f"waiting for {self.sleep_time:.2f} seconds...")
                 time.sleep(self.sleep_time)
                 self.sleep_time *= 2
                 retry += 1
@@ -125,27 +111,21 @@ class PubMedAPIWrapper(BaseModel):
         if "<ArticleTitle>" in xml_text and "</ArticleTitle>" in xml_text:
             start_tag = "<ArticleTitle>"
             end_tag = "</ArticleTitle>"
-            title = xml_text[
-                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
-            ]
+            title = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)]
 
         # Get abstract
         abstract = ""
         if "<AbstractText>" in xml_text and "</AbstractText>" in xml_text:
             start_tag = "<AbstractText>"
             end_tag = "</AbstractText>"
-            abstract = xml_text[
-                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
-            ]
+            abstract = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)]
 
         # Get publication date
         pub_date = ""
         if "<PubDate>" in xml_text and "</PubDate>" in xml_text:
             start_tag = "<PubDate>"
             end_tag = "</PubDate>"
-            pub_date = xml_text[
-                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
-            ]
+            pub_date = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)]
 
         # Return article as dictionary
         article = {
@@ -182,6 +162,7 @@ class PubmedQueryRun(BaseModel):
 
 class PubMedInput(BaseModel):
     query: str = Field(..., description="Search query.")
 
+
 class PubMedSearchTool(BuiltinTool):
     """
     Tool for performing a search using PubMed search engine.
@@ -198,14 +179,13 @@ class PubMedSearchTool(BuiltinTool):
         Returns:
             ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
         """
-        query = tool_parameters.get('query', '')
+        query = tool_parameters.get("query", "")
 
         if not query:
-            return self.create_text_message('Please input query')
+            return self.create_text_message("Please input query")
 
         tool = PubmedQueryRun(args_schema=PubMedInput)
         result = tool._run(query)
 
         return self.create_text_message(self.summary(user_id=user_id, content=result))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py
index 9fa7d01265..8466b9a26b 100644
--- a/api/core/tools/provider/builtin/qrcode/qrcode.py
+++ b/api/core/tools/provider/builtin/qrcode/qrcode.py
@@ -8,9 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 
 class QRCodeProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
-            QRCodeGeneratorTool().invoke(user_id='',
-                                         tool_parameters={
-                                             'content': 'Dify 123 😊'
-                                         })
+            QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
index 5eede98f5e..cac59f76d8 100644
--- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
+++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
@@ -13,43 +13,44 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 class QRCodeGeneratorTool(BuiltinTool):
     error_correction_levels: dict[str, int] = {
-        'L': ERROR_CORRECT_L,  # <=7%
-        'M': ERROR_CORRECT_M,  # <=15%
-        'Q': ERROR_CORRECT_Q,  # <=25%
-        'H': ERROR_CORRECT_H,  # <=30%
+        "L": ERROR_CORRECT_L,  # <=7%
+        "M": ERROR_CORRECT_M,  # <=15%
+        "Q": ERROR_CORRECT_Q,  # <=25%
+        "H": ERROR_CORRECT_H,  # <=30%
     }
 
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> 
Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get text content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get border size - border = tool_parameters.get('border', 0) + border = tool_parameters.get("border", 0) if border < 0 or border > 100: - return self.create_text_message('Invalid parameter border') + return self.create_text_message("Invalid parameter border") # get error_correction - error_correction = tool_parameters.get('error_correction', '') + error_correction = tool_parameters.get("error_correction", "") if error_correction not in self.error_correction_levels.keys(): - return self.create_text_message('Invalid parameter error_correction') + return self.create_text_message("Invalid parameter error_correction") try: image = self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) - return self.create_blob_message(blob=image_bytes, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_blob_message( + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) except Exception: - logging.exception(f'Failed to generate QR code for content: {content}') - return self.create_text_message('Failed to generate QR code') + logging.exception(f"Failed to generate QR code for content: {content}") + return self.create_text_message("Failed to generate QR code") def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage: qr = QRCode( diff --git a/api/core/tools/provider/builtin/regex/regex.py b/api/core/tools/provider/builtin/regex/regex.py index d38ae1b292..c498105979 100644 --- a/api/core/tools/provider/builtin/regex/regex.py +++ b/api/core/tools/provider/builtin/regex/regex.py @@ -9,10 +9,10 @@ class RegexProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: RegexExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'content': '1+(2+3)*4', - 'expression': r'(\d+)', + "content": "1+(2+3)*4", + "expression": r"(\d+)", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.py b/api/core/tools/provider/builtin/regex/tools/regex_extract.py index 5d8f013d0d..786b469404 100644 --- a/api/core/tools/provider/builtin/regex/tools/regex_extract.py +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.py @@ -6,22 +6,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class RegexExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - content = tool_parameters.get('content', '').strip() + content = tool_parameters.get("content", "").strip() if not content: - return self.create_text_message('Invalid content') - expression = tool_parameters.get('expression', '').strip() + return self.create_text_message("Invalid content") + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid 
expression") try: result = re.findall(expression, content) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to extract result, error: {str(e)}') \ No newline at end of file + return self.create_text_message(f"Failed to extract result, error: {str(e)}") diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py index 6fa4f05acd..109bba8b2d 100644 --- a/api/core/tools/provider/builtin/searchapi/searchapi.py +++ b/api/core/tools/provider/builtin/searchapi/searchapi.py @@ -13,11 +13,8 @@ class SearchAPIProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearchApi dify", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "SearchApi dify", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index dd780aeadc..d632304a46 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -80,25 +81,29 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 81c67c51a9..1544061c08 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. 
@@ -50,7 +51,16 @@ class SearchAPI: if type == "text": if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): for item in res["jobs"]: - toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "content: " + item["description"] + "\n" + toret += ( + "title: " + + item["title"] + + "\n" + + "company_name: " + + item["company_name"] + + "content: " + + item["description"] + + "\n" + ) if toret == "": toret = "No good search result found" @@ -62,16 +72,18 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleJobsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] is_remote = tool_parameters.get("is_remote") google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") @@ -80,9 +92,11 @@ class GoogleJobsTool(BuiltinTool): ltype = 1 if is_remote else None - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 5d2657dddd..95a7aad736 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -68,25 +69,29 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleNewsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. 
""" - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 6345b33801..88def504fc 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -55,18 +56,20 @@ class SearchAPI: return toret + class YoutubeTranscriptsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - video_id = tool_parameters['video_id'] - language = tool_parameters.get('language', "en") + video_id = tool_parameters["video_id"] + language = tool_parameters.get("language", "en") - api_key = self.runtime.credentials['searchapi_api_key'] + api_key = self.runtime.credentials["searchapi_api_key"] result = SearchAPI(api_key).run(video_id, language=language) return self.create_text_message(text=result) diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index ab354003e6..b7bbcc60b1 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -13,12 +13,8 @@ class SearXNGProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearXNG", - "limit": 1, - "search_type": "general" - }, + user_id="", + tool_parameters={"query": "SearXNG", "limit": 1, "search_type": "general"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index dc835a8e8c..c5e339a108 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -23,18 +23,21 @@ class SearXNGSearchTool(BuiltinTool): ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. 
""" - host = self.runtime.credentials.get('searxng_base_url') + host = self.runtime.credentials.get("searxng_base_url") if not host: - raise Exception('SearXNG api is required') + raise Exception("SearXNG api is required") - response = requests.get(host, params={ - "q": tool_parameters.get('query'), - "format": "json", - "categories": tool_parameters.get('search_type', 'general') - }) + response = requests.get( + host, + params={ + "q": tool_parameters.get("query"), + "format": "json", + "categories": tool_parameters.get("search_type", "general"), + }, + ) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') + raise Exception(f"Error {response.status_code}: {response.text}") res = response.json().get("results", []) if not res: diff --git a/api/core/tools/provider/builtin/serper/serper.py b/api/core/tools/provider/builtin/serper/serper.py index 2a42109373..cb1d090a9d 100644 --- a/api/core/tools/provider/builtin/serper/serper.py +++ b/api/core/tools/provider/builtin/serper/serper.py @@ -13,11 +13,8 @@ class SerperProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/serper/tools/serper_search.py b/api/core/tools/provider/builtin/serper/tools/serper_search.py index 24facaf4ec..7baebbf958 100644 --- a/api/core/tools/provider/builtin/serper/tools/serper_search.py +++ b/api/core/tools/provider/builtin/serper/tools/serper_search.py @@ -9,7 +9,6 @@ SERPER_API_URL = "https://google.serper.dev/search" class SerperSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledgeGraph" in response: @@ -17,28 +16,19 @@ class SerperSearchTool(BuiltinTool): result["description"] = response["knowledgeGraph"].get("description", "") if "organic" in response: result["organic"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - params = { - "q": tool_parameters['query'], - "gl": "us", - "hl": "en" - } - headers = { - 'X-API-KEY': self.runtime.credentials['serperapi_api_key'], - 'Content-Type': 'application/json' - } - response = requests.get(url=SERPER_API_URL, params=params,headers=headers) + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + params = {"q": tool_parameters["query"], "gl": "us", "hl": "en"} + headers = {"X-API-KEY": self.runtime.credentials["serperapi_api_key"], "Content-Type": "application/json"} + response = requests.get(url=SERPER_API_URL, params=params, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) return self.create_json_message(valuable_res) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py index 0df78280df..37a0b0755b 100644 --- a/api/core/tools/provider/builtin/siliconflow/siliconflow.py +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py 
@@ -14,6 +14,4 @@ class SiliconflowProvider(BuiltinToolProviderController): response = requests.get(url, headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError( - "SiliconFlow API key is invalid" - ) + raise ToolProviderCredentialValidationError("SiliconFlow API key is invalid") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index ed9f4be574..1b846624bd 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -5,17 +5,13 @@ import requests from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -FLUX_URL = ( - "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" -) +FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" class FluxTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -36,9 +32,5 @@ class FluxTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index e8134a6565..d6a0b03d1b 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -12,11 +12,9 @@ SDURL = { class StableDiffusionTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -43,9 +41,5 @@ class StableDiffusionTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py index f47557f2ef..85e0de7675 100644 --- a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py +++ b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py @@ -7,25 +7,27 @@ from core.tools.tool.builtin_tool import BuiltinTool class SlackWebhookTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Incoming Webhooks - API Document: https://api.slack.com/messaging/webhooks + Incoming Webhooks + API Document: https://api.slack.com/messaging/webhooks """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return 
self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - webhook_url = tool_parameters.get('webhook_url', '') + webhook_url = tool_parameters.get("webhook_url", "") - if not webhook_url.startswith('https://hooks.slack.com/'): + if not webhook_url.startswith("https://hooks.slack.com/"): return self.create_text_message( - f'Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL') + f"Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL" + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { @@ -38,6 +40,7 @@ class SlackWebhookTool(BuiltinTool): return self.create_text_message("Text message was sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message through webhook. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message through webhook. {}".format(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py index cb8e69a59f..e0b1a58a3f 100644 --- a/api/core/tools/provider/builtin/spark/spark.py +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -29,12 +29,8 @@ class SparkProvider(BuiltinToolProviderController): # 0 success, pass else: - raise ToolProviderCredentialValidationError( - "image generate error, code:{}".format(code) - ) + raise ToolProviderCredentialValidationError("image generate error, code:{}".format(code)) except Exception as e: - raise ToolProviderCredentialValidationError( - "APPID APISecret APIKey is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("APPID APISecret APIKey is invalid. 
{}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index a977af2b76..81d9e8d941 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -class AssembleHeaderException(Exception): +class AssembleHeaderError(Exception): def __init__(self, msg): self.message = msg class Url: - def __init__(this, host, path, schema): - this.host = host - this.path = path - this.schema = schema + def __init__(self, host, path, schema): + self.host = host + self.path = path + self.schema = schema # calculate sha256 and encode to base64 @@ -35,49 +35,46 @@ def sha256base64(data): return digest -def parse_url(requset_url): - stidx = requset_url.index("://") - host = requset_url[stidx + 3 :] - schema = requset_url[: stidx + 3] +def parse_url(request_url): + stidx = request_url.index("://") + host = request_url[stidx + 3 :] + schema = request_url[: stidx + 3] edidx = host.index("/") if edidx <= 0: - raise AssembleHeaderException("invalid request url:" + requset_url) + raise AssembleHeaderError("invalid request url:" + request_url) path = host[edidx:] host = host[:edidx] u = Url(host, path, schema) return u -def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""): - u = parse_url(requset_url) + +def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): + u = parse_url(request_url) host = u.host path = u.path now = datetime.now() date = format_date_time(mktime(now.timetuple())) - signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( - host, date, method, path - ) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path) signature_sha = hmac.new( api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256, ).digest() signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' - - authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( - encoding="utf-8" + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' ) + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") values = {"host": host, "date": date, "authorization": authorization} - return requset_url + "?" + urlencode(values) + return request_url + "?" 
+ urlencode(values) def get_body(appid, text): body = { "header": {"app_id": appid, "uid": "123456789"}, - "parameter": { - "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} - }, + "parameter": {"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}}, "payload": {"message": {"text": [{"role": "user", "content": text}]}}, } return body @@ -85,13 +82,9 @@ def get_body(appid, text): def spark_response(text, appid, apikey, apisecret): host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" - url = assemble_ws_auth_url( - host, method="POST", api_key=apikey, api_secret=apisecret - ) + url = assemble_ws_auth_url(host, method="POST", api_key=apikey, api_secret=apisecret) content = get_body(appid, text) - response = requests.post( - url, json=content, headers={"content-type": "application/json"} - ).text + response = requests.post(url, json=content, headers={"content-type": "application/json"}).text return response @@ -105,19 +98,11 @@ class SparkImgGeneratorTool(BuiltinTool): invoke tools """ - if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( - "APPID" - ): + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get("APPID"): return self.create_text_message("APPID is required.") - if ( - "APISecret" not in self.runtime.credentials - or not self.runtime.credentials.get("APISecret") - ): + if "APISecret" not in self.runtime.credentials or not self.runtime.credentials.get("APISecret"): return self.create_text_message("APISecret is required.") - if ( - "APIKey" not in self.runtime.credentials - or not self.runtime.credentials.get("APIKey") - ): + if "APIKey" not in self.runtime.credentials or not self.runtime.credentials.get("APIKey"): return self.create_text_message("APIKey is required.") prompt = tool_parameters.get("prompt", "") @@ -130,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool): self.create_blob_message( blob=b64decode(image["base64_image"]), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) return result diff --git a/api/core/tools/provider/builtin/spider/spider.py b/api/core/tools/provider/builtin/spider/spider.py index 5bcc56a724..5959555318 100644 --- a/api/core/tools/provider/builtin/spider/spider.py +++ b/api/core/tools/provider/builtin/spider/spider.py @@ -8,13 +8,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class SpiderProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - app = Spider(api_key=credentials['spider_api_key']) - app.scrape_url(url='https://spider.cloud') + app = Spider(api_key=credentials["spider_api_key"]) + app.scrape_url(url="https://spider.cloud") except AttributeError as e: # Handle cases where NoneType is not iterable, which might indicate API issues - if 'NoneType' in str(e) and 'not iterable' in str(e): - raise ToolProviderCredentialValidationError('API is currently down, try again in 15 minutes', str(e)) + if "NoneType" in str(e) and "not iterable" in str(e): + raise ToolProviderCredentialValidationError("API is currently down, try again in 15 minutes", str(e)) else: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) except Exception as e: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An 
unexpected error occurred.", str(e)) diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index f0ed64867a..3972e560c4 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -65,9 +65,7 @@ class Spider: :return: The JSON response or the raw response stream if stream is True. """ headers = self._prepare_headers(content_type) - response = self._post_request( - f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream - ) + response = self._post_request(f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream) if stream: return response @@ -76,9 +74,7 @@ class Spider: else: self._handle_error(response, f"post to {endpoint}") - def api_get( - self, endpoint: str, stream: bool, content_type: str = "application/json" - ): + def api_get(self, endpoint: str, stream: bool, content_type: str = "application/json"): """ Send a GET request to the specified endpoint. @@ -86,9 +82,7 @@ class Spider: :return: The JSON decoded response. """ headers = self._prepare_headers(content_type) - response = self._get_request( - f"https://api.spider.cloud/v1/{endpoint}", headers, stream - ) + response = self._get_request(f"https://api.spider.cloud/v1/{endpoint}", headers, stream) if response.status_code == 200: return response.json() else: @@ -120,14 +114,12 @@ class Spider: # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: - params["return_format"] = "markdown" + params["return_format"] = "markdown" # Set limit to 1 params["limit"] = 1 - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def crawl_url( self, @@ -150,9 +142,7 @@ class Spider: if "return_format" not in params: params["return_format"] = "markdown" - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def links( self, @@ -168,9 +158,7 @@ class Spider: :param params: Optional parameters for the link retrieval request. :return: JSON response containing the links. """ - return self.api_post( - "links", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("links", {"url": url, **(params or {})}, stream, content_type) def extract_contacts( self, @@ -207,9 +195,7 @@ class Spider: :param params: Optional parameters to guide the labeling process. :return: JSON response with labeled data. """ - return self.api_post( - "pipeline/label", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("pipeline/label", {"url": url, **(params or {})}, stream, content_type) def _prepare_headers(self, content_type: str = "application/json"): return { @@ -230,10 +216,6 @@ class Spider: def _handle_error(self, response, action): if response.status_code in [402, 409, 500]: error_message = response.json().get("error", "Unknown error occurred") - raise Exception( - f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}" - ) + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception( - f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}" - ) + raise Exception(f"Unexpected error occurred while trying to {action}. 
Status code: {response.status_code}") diff --git a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py index 64bbcc10cc..20d2daef55 100644 --- a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py +++ b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py @@ -6,42 +6,44 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # initialize the app object with the api key - app = Spider(api_key=self.runtime.credentials['spider_api_key']) + app = Spider(api_key=self.runtime.credentials["spider_api_key"]) + + url = tool_parameters["url"] + mode = tool_parameters["mode"] - url = tool_parameters['url'] - mode = tool_parameters['mode'] - options = { - 'limit': tool_parameters.get('limit', 0), - 'depth': tool_parameters.get('depth', 0), - 'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [], - 'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [], - 'readability': tool_parameters.get('readability', False), + "limit": tool_parameters.get("limit", 0), + "depth": tool_parameters.get("depth", 0), + "blacklist": tool_parameters.get("blacklist", "").split(",") if tool_parameters.get("blacklist") else [], + "whitelist": tool_parameters.get("whitelist", "").split(",") if tool_parameters.get("whitelist") else [], + "readability": tool_parameters.get("readability", False), } result = "" try: - if mode == 'scrape': + if mode == "scrape": scrape_result = app.scrape_url( - url=url, + url=url, params=options, ) for i in scrape_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" - elif mode == 'crawl': + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" + elif mode == "crawl": crawl_result = app.crawl_url( - url=tool_parameters['url'], + url=tool_parameters["url"], params=options, ) for i in crawl_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" except Exception as e: - return self.create_text_message("An error occured", str(e)) + return self.create_text_message("An error occurred", str(e)) return self.create_text_message(result) diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py index b31d786178..f09d81ac27 100644 --- a/api/core/tools/provider/builtin/stability/stability.py +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz """ This class is responsible for providing the stability tool. """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ This method is responsible for validating the credentials. 
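[Editor's note: the `_validate_credentials` hunks throughout this patch series all reduce to one pattern — issue a single cheap authenticated request and raise a credential-validation error on any failure. A minimal standalone sketch of that pattern follows, assuming only `requests` and the public Stability account endpoint visible in the base.py hunk below; `CredentialValidationError` and `validate_api_key` are hypothetical stand-ins for Dify's ToolProviderCredentialValidationError and provider method, not actual Dify API.]

    # Minimal sketch: validate an API key with one side-effect-free GET,
    # raising on empty input or a non-2xx response.
    import requests

    class CredentialValidationError(Exception):
        pass

    def validate_api_key(api_key: str) -> bool:
        if not api_key:
            raise CredentialValidationError("API key is required.")
        response = requests.get(
            "https://api.stability.ai/v1/user/account",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=(5, 30),  # (connect, read) timeouts, matching the base.py hunk
        )
        if not response.ok:
            raise CredentialValidationError("Invalid API key.")
        return True

[A read-only account or status endpoint is the usual choice here because it is cheap, idempotent, and fails fast on a bad key, which is why the providers in these hunks probe endpoints like /v1/user/account rather than invoking a billable tool call.]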
diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py index a4788fd869..c3b7edbefa 100644 --- a/api/core/tools/provider/builtin/stability/tools/base.py +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -9,26 +9,23 @@ class BaseStabilityAuthorization: """ This method is responsible for validating the credentials. """ - api_key = credentials.get('api_key', '') + api_key = credentials.get("api_key", "") if not api_key: - raise ToolProviderCredentialValidationError('API key is required.') - + raise ToolProviderCredentialValidationError("API key is required.") + response = requests.get( - URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + URL("https://api.stability.ai") / "v1" / "user" / "account", headers=self.generate_authorization_headers(credentials), - timeout=(5, 30) + timeout=(5, 30), ) if not response.ok: - raise ToolProviderCredentialValidationError('Invalid API key.') + raise ToolProviderCredentialValidationError("Invalid API key.") return True - + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: """ This method is responsible for generating the authorization headers. """ - return { - 'Authorization': f'Bearer {credentials.get("api_key", "")}' - } - \ No newline at end of file + return {"Authorization": f'Bearer {credentials.get("api_key", "")}'} diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 41236f7b43..12b6cc3352 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): """ This class is responsible for providing the stable diffusion tool. """ + model_endpoint_map: dict[str, str] = { - 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + "sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "core": "https://api.stability.ai/v2beta/stable-image/generate/core", } def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -22,39 +23,34 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): Invoke the tool. 
""" payload = { - 'prompt': tool_parameters.get('prompt', ''), - 'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'), - 'mode': 'text-to-image', - 'seed': tool_parameters.get('seed', 0), - 'output_format': 'png', + "prompt": tool_parameters.get("prompt", ""), + "aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"), + "mode": "text-to-image", + "seed": tool_parameters.get("seed", 0), + "output_format": "png", } - model = tool_parameters.get('model', 'core') + model = tool_parameters.get("model", "core") - if model in ['sd3', 'sd3-turbo']: - payload['model'] = tool_parameters.get('model') + if model in ["sd3", "sd3-turbo"]: + payload["model"] = tool_parameters.get("model") - if not model == 'sd3-turbo': - payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + if not model == "sd3-turbo": + payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( - self.model_endpoint_map[tool_parameters.get('model', 'core')], + self.model_endpoint_map[tool_parameters.get("model", "core")], headers={ - 'accept': 'image/*', + "accept": "image/*", **self.generate_authorization_headers(self.runtime.credentials), }, - files={ - key: (None, str(value)) for key, value in payload.items() - }, - timeout=(5, 30) + files={key: (None, str(value)) for key, value in payload.items()}, + timeout=(5, 30), ) if not response.status_code == 200: raise Exception(response.text) - + return self.create_blob_message( - blob=response.content, meta={ - 'mime_type': 'image/png' - }, - save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 317d705f7c..abaa297cf3 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -15,4 +15,3 @@ class StableDiffusionProvider(BuiltinToolProviderController): ).validate_models() except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 4be9207d66..46137886bd 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -18,19 +18,17 @@ DRAW_TEXT_OPTIONS = { # Prompts "prompt": "", "negative_prompt": "", - # "styles": [], - # Seeds + # "styles": [], + # Seeds "seed": -1, "subseed": -1, "subseed_strength": 0, "seed_resize_from_h": -1, "seed_resize_from_w": -1, - # Samplers "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", - # Latent Space Options "batch_size": 1, "n_iter": 1, @@ -42,9 +40,9 @@ DRAW_TEXT_OPTIONS = { # "tiling": True, "do_not_save_samples": False, "do_not_save_grid": False, - # "eta": 0, - # "denoising_strength": 0.75, - # "s_min_uncond": 0, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, # "s_churn": 0, # "s_tmax": 0, # "s_tmin": 0, @@ -73,7 +71,6 @@ DRAW_TEXT_OPTIONS = { "hr_negative_prompt": "", # Task Options # "force_task_id": "", - # Script Options # "script_name": "", "script_args": [], @@ -82,131 +79,130 @@ DRAW_TEXT_OPTIONS = { 
"save_images": False, "alwayson_scripts": {}, # "infotext": "", - } class StableDiffusionTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # base url - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - return self.create_text_message('Please input base_url') + return self.create_text_message("Please input base_url") - if tool_parameters.get('model'): - self.runtime.credentials['model'] = tool_parameters['model'] + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] - model = self.runtime.credentials.get('model', None) + model = self.runtime.credentials.get("model", None) if not model: - return self.create_text_message('Please input model') - + return self.create_text_message("Please input model") + # set model try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'options') - response = post(url, data=json.dumps({ - 'sd_model_checkpoint': model - })) + url = str(URL(base_url) / "sdapi" / "v1" / "options") + response = post(url, data=json.dumps({"sd_model_checkpoint": model})) if response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") except Exception as e: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") # get image id and image variable - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") image_variable = self.get_default_image_variable() # Return text2img if there's no image ID or no image variable if not image_id or not image_variable: - return self.text2img(base_url=base_url,tool_parameters=tool_parameters) + return self.text2img(base_url=base_url, tool_parameters=tool_parameters) # Proceed with image-to-image generation - return self.img2img(base_url=base_url,tool_parameters=tool_parameters) + return self.img2img(base_url=base_url, tool_parameters=tool_parameters) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - validate models + validate models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - raise ToolProviderCredentialValidationError('Please input base_url') - model = self.runtime.credentials.get('model', None) + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) if not model: - raise ToolProviderCredentialValidationError('Please input model') + raise ToolProviderCredentialValidationError("Please input model") - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=10) if response.status_code == 404: # try draw a picture self._invoke( - user_id='test', + user_id="test", tool_parameters={ - 'prompt': 'a cat', - 'width': 1024, - 'height': 1024, - 'steps': 1, - 'lora': '', - } + "prompt": "a cat", + "width": 1024, + "height": 1024, + 
"steps": 1, + "lora": "", + }, ) elif response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to get models') + raise ToolProviderCredentialValidationError("Failed to get models") else: - models = [d['model_name'] for d in response.json()] + models = [d["model_name"] for d in response.json()] if len([d for d in models if d == model]) > 0: return self.create_text_message(json.dumps(models)) else: - raise ToolProviderCredentialValidationError(f'model {model} does not exist') + raise ToolProviderCredentialValidationError(f"model {model} does not exist") except Exception as e: - raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") def get_sd_models(self) -> list[str]: """ - get sd models + get sd models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=(2, 10)) if response.status_code != 200: return [] else: - return [d['model_name'] for d in response.json()] - except Exception as e: - return [] - - def get_sample_methods(self) -> list[str]: - """ - get sample method - """ - try: - base_url = self.runtime.credentials.get('base_url', None) - if not base_url: - return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') - response = get(url=api_url, timeout=(2, 10)) - if response.status_code != 200: - return [] - else: - return [d['name'] for d in response.json()] + return [d["model_name"] for d in response.json()] except Exception as e: return [] - def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def get_sample_methods(self) -> list[str]: """ - generate image + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d["name"] for d in response.json()] + except Exception as e: + return [] + + def img2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image """ # Fetch the binary data of the image image_variable = self.get_default_image_variable() image_binary = self.get_variable_file(image_variable.name) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") # Convert image to RGB and save as PNG try: @@ -220,14 +216,14 @@ class StableDiffusionTool(BuiltinTool): # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) # set image options - model = tool_parameters.get('model', '') + model = tool_parameters.get("model", "") draw_options_image = { - "init_images": [b64encode(image_binary).decode('utf-8')], + "init_images": [b64encode(image_binary).decode("utf-8")], "denoising_strength": 0.9, "restore_faces": False, "script_args": [], "override_settings": {"sd_model_checkpoint": model}, - "resize_mode":0, + "resize_mode": 0, "image_cfg_scale": 0, # "mask": None, "mask_blur_x": 4, @@ -247,136 +243,142 @@ class StableDiffusionTool(BuiltinTool): 
draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt + draw_options["prompt"] = prompt try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') + url = str(URL(base_url) / "sdapi" / "v1" / "img2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") - def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def text2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - generate image + generate image """ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt - draw_options['override_settings']['sd_model_checkpoint'] = model + draw_options["prompt"] = prompt + draw_options["override_settings"]["sd_model_checkpoint"] = model - try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') + url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") def get_runtime_parameters(self) -> list[ToolParameter]: parameters = [ - ToolParameter(name='prompt', - label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), - human_description=I18nObject( - en_US='Image prompt, you can check the official documentation of Stable Diffusion', - zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.STRING, - 
form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', - required=True), + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.", + required=True, + ), ] if len(self.list_default_image_variables()) != 0: parameters.append( - ToolParameter(name='image_id', - label=I18nObject(en_US='image_id', zh_Hans='image_id'), - human_description=I18nObject( - en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', - zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', - required=True, - options=[ToolParameterOption( - value=i.name, - label=I18nObject(en_US=i.name, zh_Hans=i.name) - ) for i in self.list_default_image_variables()]) + ToolParameter( + name="image_id", + label=I18nObject(en_US="image_id", zh_Hans="image_id"), + human_description=I18nObject( + en_US="Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.", + zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image id of the original image, you can leave this field empty if you want to generate a new image.", + required=True, + options=[ + ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name)) + for i in self.list_default_image_variables() + ], + ) ) - + if self.runtime.credentials: try: models = self.get_sd_models() if len(models) != 0: parameters.append( - ToolParameter(name='model', - label=I18nObject(en_US='Model', zh_Hans='Model'), - human_description=I18nObject( - en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=models[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in models]) + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion, you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion, you can check the official documentation of Stable Diffusion", + required=True, + 
default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) ) - + except: pass - + sample_methods = self.get_sample_methods() if len(sample_methods) != 0: parameters.append( - ToolParameter(name='sampler_name', - label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), - human_description=I18nObject( - en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=sample_methods[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in sample_methods]) + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods + ], ) + ) return parameters diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index de64c84997..9680c633cc 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -11,16 +11,15 @@ class StackExchangeProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "intitle": "Test", - "sort": "relevance", + "sort": "relevance", "order": "desc", "site": "stackoverflow", "accepted": True, - "pagesize": 1 + "pagesize": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py index f8e1710844..5345320095 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py @@ -17,7 +17,9 @@ class FetchAnsByStackExQuesIDInput(BaseModel): class FetchAnsByStackExQuesIDTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = FetchAnsByStackExQuesIDInput(**tool_parameters) params = { @@ -26,7 +28,7 @@ class FetchAnsByStackExQuesIDTool(BuiltinTool): "order": input.order, "sort": input.sort, "pagesize": input.pagesize, - "page": input.page + "page": input.page, } response = requests.get(f"https://api.stackexchange.com/2.3/questions/{input.id}/answers", params=params) @@ -34,4 +36,4 @@ class 
FetchAnsByStackExQuesIDTool(BuiltinTool): if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py index 8436433c32..4a25a808ad 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py @@ -9,26 +9,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class SearchStackExQuestionsInput(BaseModel): intitle: str = Field(..., description="The search query.") - sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") + sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") order: str = Field(..., description="asc or desc") site: str = Field(..., description="The Stack Exchange site.") tagged: str = Field(None, description="Semicolon-separated tags to include.") nottagged: str = Field(None, description="Semicolon-separated tags to exclude.") - accepted: bool = Field(..., description="true for only accepted answers, false otherwise") + accepted: bool = Field(..., description="true for only accepted answers, false otherwise") pagesize: int = Field(..., description="Number of results per page") class SearchStackExQuestionsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = SearchStackExQuestionsInput(**tool_parameters) params = { "intitle": input.intitle, "sort": input.sort, - "order": input.order, + "order": input.order, "site": input.site, "accepted": input.accepted, - "pagesize": input.pagesize + "pagesize": input.pagesize, } if input.tagged: params["tagged"] = input.tagged @@ -40,4 +42,4 @@ class SearchStackExQuestionsTool(BuiltinTool): if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.py b/api/core/tools/provider/builtin/stepfun/stepfun.py index e809b04546..b24f730c95 100644 --- a/api/core/tools/provider/builtin/stepfun/stepfun.py +++ b/api/core/tools/provider/builtin/stepfun/stepfun.py @@ -13,13 +13,12 @@ class StepfunProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "cute girl, blue eyes, white hair, anime style", "size": "1024x1024", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py index 5e544aada6..0b92b122bf 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.py +++ 
b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -9,64 +9,67 @@ from core.tools.tool.builtin_tool import BuiltinTool class StepfunTool(BuiltinTool): - """ Stepfun Image Generation Tool """ - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """Stepfun Image Generation Tool""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = self.runtime.credentials.get('stepfun_base_url', None) - if not base_url: - base_url = None - else: - base_url = str(URL(base_url) / 'v1') + base_url = self.runtime.credentials.get("stepfun_base_url", "https://api.stepfun.com") + base_url = str(URL(base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['stepfun_api_key'], + api_key=self.runtime.credentials["stepfun_api_key"], base_url=base_url, ) extra_body = {} - model = tool_parameters.get('model', 'step-1x-medium') + model = tool_parameters.get("model", "step-1x-medium") if not model: - return self.create_text_message('Please input model name') + return self.create_text_message("Please input model name") # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") - seed = tool_parameters.get('seed', 0) + seed = tool_parameters.get("seed", 0) if seed > 0: - extra_body['seed'] = seed - steps = tool_parameters.get('steps', 0) + extra_body["seed"] = seed + steps = tool_parameters.get("steps", 0) if steps > 0: - extra_body['steps'] = steps - negative_prompt = tool_parameters.get('negative_prompt', '') + extra_body["steps"] = steps + negative_prompt = tool_parameters.get("negative_prompt", "") if negative_prompt: - extra_body['negative_prompt'] = negative_prompt + extra_body["negative_prompt"] = negative_prompt # call openapi stepfun model response = client.images.generate( prompt=prompt, model=model, - size=tool_parameters.get('size', '1024x1024'), - n=tool_parameters.get('n', 1), - extra_body= extra_body + size=tool_parameters.get("size", "1024x1024"), + n=tool_parameters.get("n", 1), + extra_body=extra_body, ) print(response) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index e376d99d6b..a702b0a74e 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -13,7 +13,7 @@ class TavilyProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", "search_depth": "basic", @@ -22,9 +22,8 @@ class TavilyProvider(BuiltinToolProviderController): "include_raw_content": False, "max_results": 5, "include_domains": 
"", - "exclude_domains": "" + "exclude_domains": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py index 0200df3c8a..ca6d8633e4 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py @@ -36,15 +36,23 @@ class TavilySearch: """ params["api_key"] = self.api_key - if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None': - params['exclude_domains'] = params['exclude_domains'].split() + if ( + "exclude_domains" in params + and isinstance(params["exclude_domains"], str) + and params["exclude_domains"] != "None" + ): + params["exclude_domains"] = params["exclude_domains"].split() else: - params['exclude_domains'] = [] - if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None': - params['include_domains'] = params['include_domains'].split() + params["exclude_domains"] = [] + if ( + "include_domains" in params + and isinstance(params["include_domains"], str) + and params["include_domains"] != "None" + ): + params["include_domains"] = params["include_domains"].split() else: - params['include_domains'] = [] - + params["include_domains"] = [] + response = requests.post(f"{TAVILY_API_URL}/search", json=params) response.raise_for_status() return response.json() @@ -91,9 +99,7 @@ class TavilySearchTool(BuiltinTool): A tool for searching Tavily using a given query. """ - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Tavily search tool with the given user ID and tool parameters. 
@@ -115,4 +121,4 @@ class TavilySearchTool(BuiltinTool): if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") else: - return self.create_text_message(text=results) \ No newline at end of file + return self.create_text_message(text=results) diff --git a/api/core/tools/provider/builtin/tianditu/tianditu.py b/api/core/tools/provider/builtin/tianditu/tianditu.py index 1f96be06b0..cb7d7bd8bb 100644 --- a/api/core/tools/provider/builtin/tianditu/tianditu.py +++ b/api/core/tools/provider/builtin/tianditu/tianditu.py @@ -12,10 +12,12 @@ class TiandituProvider(BuiltinToolProviderController): runtime={ "credentials": credentials, } - ).invoke(user_id='', - tool_parameters={ - 'content': '北京', - 'specify': '156110000', - }) + ).invoke( + user_id="", + tool_parameters={ + "content": "北京", + "specify": "156110000", + }, + ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py index 484a3768c8..690a0aed6f 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py +++ b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py @@ -8,26 +8,26 @@ from core.tools.tool.builtin_tool import BuiltinTool class GeocoderTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = 'http://api.tianditu.gov.cn/geocoder' - - keyword = tool_parameters.get('keyword', '') + base_url = "http://api.tianditu.gov.cn/geocoder" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + params = { - 'keyWord': keyword, + "keyWord": keyword, } - - result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json() + + result = requests.get(base_url + "?ds=" + json.dumps(params, ensure_ascii=False) + "&tk=" + tk).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py index 08a5b8ef42..798dd94d33 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py +++ b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py @@ -8,38 +8,51 @@ from core.tools.tool.builtin_tool import BuiltinTool class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/v2/search' - - keyword = tool_parameters.get('keyword', '') + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/v2/search" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - baseAddress = 
tool_parameters.get('baseAddress', '') + return self.create_text_message("Invalid parameter keyword") + + baseAddress = tool_parameters.get("baseAddress", "") if not baseAddress: - return self.create_text_message('Invalid parameter baseAddress') - - tk = self.runtime.credentials['tianditu_api_key'] - - base_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': baseAddress,}, ensure_ascii=False) + '&tk=' + tk).json() - + return self.create_text_message("Invalid parameter baseAddress") + + tk = self.runtime.credentials["tianditu_api_key"] + + base_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": baseAddress, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + params = { - 'keyWord': keyword, - 'queryRadius': 5000, - 'queryType': 3, - 'pointLonlat': base_coords['location']['lon'] + ',' + base_coords['location']['lat'], - 'start': 0, - 'count': 100, + "keyWord": keyword, + "queryRadius": 5000, + "queryType": 3, + "pointLonlat": base_coords["location"]["lon"] + "," + base_coords["location"]["lat"], + "start": 0, + "count": 100, } - - result = requests.get(base_url + '?postStr=' + json.dumps(params, ensure_ascii=False) + '&type=query&tk=' + tk).json() + + result = requests.get( + base_url + "?postStr=" + json.dumps(params, ensure_ascii=False) + "&type=query&tk=" + tk + ).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index ecac4404ca..aeaef08805 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -8,29 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/staticimage' - - keyword = tool_parameters.get('keyword', '') - if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - - keyword_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': keyword,}, ensure_ascii=False) + '&tk=' + tk).json() - coords = keyword_coords['location']['lon'] + ',' + keyword_coords['location']['lat'] - - result = requests.get(base_url + '?center=' + coords + '&markers=' + coords + '&width=400&height=300&zoom=14&tk=' + tk).content - return self.create_blob_message(blob=result, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/staticimage" + + keyword = tool_parameters.get("keyword", "") + if not keyword: + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + + keyword_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": keyword, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + coords = keyword_coords["location"]["lon"] + "," + keyword_coords["location"]["lat"] + + result = requests.get( + base_url + "?center=" + coords + "&markers=" + coords + 
"&width=400&height=300&zoom=14&tk=" + tk + ).content + + return self.create_blob_message( + blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 833ae194ef..e4df8d616c 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -9,9 +9,8 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CurrentTimeTool().invoke( - user_id='', + user_id="", tool_parameters={}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 90c01665e6..cc38739c16 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -8,21 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool class CurrentTimeTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get timezone - tz = tool_parameters.get('timezone', 'UTC') - fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z' - if tz == 'UTC': - return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}') - + tz = tool_parameters.get("timezone", "UTC") + fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" + if tz == "UTC": + return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + try: tz = pytz_timezone(tz) except: - return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime(fm)}') \ No newline at end of file + return self.create_text_message(f"Invalid timezone: {tz}") + return self.create_text_message(f"{datetime.now(tz).strftime(fm)}") diff --git a/api/core/tools/provider/builtin/time/tools/weekday.py b/api/core/tools/provider/builtin/time/tools/weekday.py index 4461cb5a32..b327e54e17 100644 --- a/api/core/tools/provider/builtin/time/tools/weekday.py +++ b/api/core/tools/provider/builtin/time/tools/weekday.py @@ -7,25 +7,26 @@ from core.tools.tool.builtin_tool import BuiltinTool class WeekdayTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Calculate the day of the week for a given date + Calculate the day of the week for a given date """ - year = tool_parameters.get('year') - month = tool_parameters.get('month') - day = tool_parameters.get('day') + year = tool_parameters.get("year") + month = tool_parameters.get("month") + day = tool_parameters.get("day") date_obj = self.convert_datetime(year, month, day) if not date_obj: - return self.create_text_message(f'Invalid date: Year {year}, Month {month}, Day {day}.') + return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.") weekday_name = calendar.day_name[date_obj.weekday()] month_name = calendar.month_name[month] readable_date = 
f"{month_name} {date_obj.day}, {date_obj.year}" - return self.create_text_message(f'{readable_date} is {weekday_name}.') + return self.create_text_message(f"{readable_date} is {weekday_name}.") @staticmethod def convert_datetime(year, month, day) -> datetime | None: diff --git a/api/core/tools/provider/builtin/trello/tools/create_board.py b/api/core/tools/provider/builtin/trello/tools/create_board.py index 2655602afa..5a61d22157 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_board.py @@ -22,19 +22,15 @@ class CreateBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_name = tool_parameters.get("name") if not (api_key and token and board_name): return self.create_text_message("Missing required parameters: API key, token, or board name.") url = "https://api.trello.com/1/boards/" - query_params = { - 'name': board_name, - 'key': api_key, - 'token': token - } + query_params = {"name": board_name, "key": api_key, "token": token} try: response = requests.post(url, params=query_params) @@ -43,5 +39,6 @@ class CreateBoardTool(BuiltinTool): return self.create_text_message("Failed to create board") board = response.json() - return self.create_text_message(text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}") - + return self.create_text_message( + text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py index f5b156cb44..26f12864c3 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py @@ -22,20 +22,16 @@ class CreateListOnBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('id') - list_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("id") + list_name = tool_parameters.get("name") if not (api_key and token and board_id and list_name): return self.create_text_message("Missing required parameters: API key, token, board ID, or list name.") url = f"https://api.trello.com/1/boards/{board_id}/lists" - params = { - 'name': list_name, - 'key': api_key, - 'token': token - } + params = {"name": list_name, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -44,5 +40,6 @@ class CreateListOnBoardTool(BuiltinTool): return self.create_text_message("Failed to create list") new_list = response.json() - return self.create_text_message(text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}.") - + return self.create_text_message( + text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}." 
+ ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py index 74b73b40e5..dfc013a6b8 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py @@ -22,15 +22,15 @@ class CreateNewCardOnBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") # Ensure required parameters are present - if 'name' not in tool_parameters or 'idList' not in tool_parameters: + if "name" not in tool_parameters or "idList" not in tool_parameters: return self.create_text_message("Missing required parameters: name or idList.") url = "https://api.trello.com/1/cards" - params = {**tool_parameters, 'key': api_key, 'token': token} + params = {**tool_parameters, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -39,5 +39,6 @@ class CreateNewCardOnBoardTool(BuiltinTool): except requests.exceptions.RequestException as e: return self.create_text_message("Failed to create card") - return self.create_text_message(text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}.") - + return self.create_text_message( + text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/delete_board.py b/api/core/tools/provider/builtin/trello/tools/delete_board.py index 29df3fda2d..9dbd8f78d5 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_board.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_board.py @@ -22,9 +22,9 @@ class DeleteBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,4 +38,3 @@ class DeleteBoardTool(BuiltinTool): return self.create_text_message("Failed to delete board") return self.create_text_message(text=f"Board with ID {board_id} deleted successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/delete_card.py b/api/core/tools/provider/builtin/trello/tools/delete_card.py index 2ced5f6c14..960c3055fe 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_card.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_card.py @@ -22,9 +22,9 @@ class DeleteCardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") @@ -38,4 +38,3 @@ class DeleteCardByIdTool(BuiltinTool): return self.create_text_message("Failed to delete card") return self.create_text_message(text=f"Card with ID {card_id} has been successfully deleted.") - diff --git a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py index f9d554c6fb..0c5ed9ea85 100644 --- a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py +++ b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py @@ -28,9 +28,7 @@ class FetchAllBoardsTool(BuiltinTool): token = self.runtime.credentials.get("trello_api_token") if not (api_key and token): - return self.create_text_message( - "Missing Trello API key or token in credentials." - ) + return self.create_text_message("Missing Trello API key or token in credentials.") # Including board filter in the request if provided board_filter = tool_parameters.get("boards", "open") @@ -48,7 +46,5 @@ class FetchAllBoardsTool(BuiltinTool): return self.create_text_message("No boards found in Trello.") # Creating a string with both board names and IDs - boards_info = ", ".join( - [f"{board['name']} (ID: {board['id']})" for board in boards] - ) + boards_info = ", ".join([f"{board['name']} (ID: {board['id']})" for board in boards]) return self.create_text_message(text=f"Boards: {boards_info}") diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py index 5678d8f8d7..03510f1964 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py @@ -22,9 +22,9 @@ class GetBoardActionsTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,6 +38,7 @@ class GetBoardActionsTool(BuiltinTool): except requests.exceptions.RequestException as e: return self.create_text_message("Failed to retrieve board actions") - actions_summary = "\n".join([f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions]) + actions_summary = "\n".join( + [f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions] + ) return self.create_text_message(text=f"Actions for Board ID {board_id}:\n{actions_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py index ee6cb065e5..5b41b128d0 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py @@ -22,9 +22,9 @@ class GetBoardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -63,4 +63,3 @@ class GetBoardByIdTool(BuiltinTool): f"Background Color: {board['prefs']['backgroundColor']}" ) return details - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py index 1abb688750..e3bed2e6e6 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py @@ -22,9 +22,9 @@ class GetBoardCardsTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +40,3 @@ class GetBoardCardsTool(BuiltinTool): cards_summary = "\n".join([f"{card['name']} (ID: {card['id']})" for card in cards]) return self.create_text_message(text=f"Cards for Board ID {board_id}:\n{cards_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py index 375ead5b1d..4d8854747c 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py @@ -22,10 +22,10 @@ class GetFilteredBoardCardsTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') - filter = tool_parameters.get('filter') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") + filter = tool_parameters.get("filter") if not (api_key and token and board_id and filter): return self.create_text_message("Missing required parameters: API key, token, board ID, or filter.") @@ -40,5 +40,6 @@ class GetFilteredBoardCardsTool(BuiltinTool): return self.create_text_message("Failed to retrieve filtered cards") card_details = "\n".join([f"{card['name']} (ID: {card['id']})" for card in filtered_cards]) - return self.create_text_message(text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}") - + return self.create_text_message( + text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py index 7b9b9cf24b..ca8aa9c2d5 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py @@ -22,9 +22,9 @@ class GetListsFromBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +40,3 @@ class GetListsFromBoardTool(BuiltinTool): lists_info = "\n".join([f"{list['name']} (ID: {list['id']})" for list in lists]) return self.create_text_message(text=f"Lists on Board ID {board_id}:\n{lists_info}") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_board.py b/api/core/tools/provider/builtin/trello/tools/update_board.py index 7ad6ac2e64..62681eea6b 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_board.py +++ b/api/core/tools/provider/builtin/trello/tools/update_board.py @@ -22,9 +22,9 @@ class UpdateBoardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.pop('boardId', None) + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.pop("boardId", None) if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -33,8 +33,8 @@ class UpdateBoardByIdTool(BuiltinTool): # Removing parameters not intended for update action or with None value params = {k: v for k, v in tool_parameters.items() if v is not None} - params['key'] = api_key - params['token'] = token + params["key"] = api_key + params["token"] = token try: response = requests.put(url, params=params) @@ -44,4 +44,3 @@ class UpdateBoardByIdTool(BuiltinTool): updated_board = response.json() return self.create_text_message(text=f"Board '{updated_board['name']}' updated successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_card.py b/api/core/tools/provider/builtin/trello/tools/update_card.py index 417344350c..26113f1229 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_card.py +++ b/api/core/tools/provider/builtin/trello/tools/update_card.py @@ -22,17 +22,17 @@ class UpdateCardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") # Constructing the URL and the payload for the PUT request url = f"https://api.trello.com/1/cards/{card_id}" - params = {k: v for k, v in tool_parameters.items() if v is not None and k != 'id'} - params.update({'key': api_key, 'token': token}) + params = {k: v for k, v in tool_parameters.items() if v is not None and k != "id"} + params.update({"key": api_key, "token": token}) try: response = requests.put(url, params=params) diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index 84ecd20803..e0dca50ec9 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ b/api/core/tools/provider/builtin/trello/trello.py @@ -9,17 +9,17 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class TrelloProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: """Validate Trello API credentials by making a test API call. - + Args: credentials (dict[str, Any]): The Trello API credentials to validate. - + Raises: ToolProviderCredentialValidationError: If the credentials are invalid. """ api_key = credentials.get("trello_api_key") token = credentials.get("trello_api_token") url = f"https://api.trello.com/1/members/me?key={api_key}&token={token}" - + try: response = requests.get(url) response.raise_for_status() # Raises an HTTPError for bad responses @@ -32,4 +32,3 @@ class TrelloProvider(BuiltinToolProviderController): except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors raise ToolProviderCredentialValidationError("Error validating Trello credentials") - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 1c52589956..822d0c0ebd 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -32,17 +32,14 @@ class TwilioAPIWrapper(BaseModel): must be empty. """ - @field_validator('client', mode='before') + @field_validator("client", mode="before") @classmethod def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: from twilio.rest import Client except ImportError: - raise ImportError( - "Could not import twilio python package. " - "Please install it with `pip install twilio`." - ) + raise ImportError("Could not import twilio python package. 
" "Please install it with `pip install twilio`.") account_sid = values.get("account_sid") auth_token = values.get("auth_token") values["from_number"] = values.get("from_number") @@ -91,9 +88,7 @@ class SendMessageTool(BuiltinTool): if to_number.startswith("whatsapp:"): from_number = f"whatsapp: {from_number}" - twilio = TwilioAPIWrapper( - account_sid=account_sid, auth_token=auth_token, from_number=from_number - ) + twilio = TwilioAPIWrapper(account_sid=account_sid, auth_token=auth_token, from_number=from_number) # Sending the message through Twilio result = twilio.run(message, to_number) diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 06f276053a..b1d100aad9 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -14,7 +14,7 @@ class TwilioProvider(BuiltinToolProviderController): account_sid = credentials["account_sid"] auth_token = credentials["auth_token"] from_number = credentials["from_number"] - + # Initialize twilio client client = Client(account_sid, auth_token) @@ -27,4 +27,3 @@ class TwilioProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index ab1fd71df5..84724e921a 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -13,13 +13,13 @@ class VannaProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "model": "chinook", "db_type": "SQLite", "url": "https://vanna.ai/Chinook.sqlite", - "query": "What are the top 10 customers by sales?" 
+ "query": "What are the top 10 customers by sales?", }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py index 1506ac0c9d..8e1b097776 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py @@ -1 +1 @@ -VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file +VECTORIZER_ICON_PNG = 
"iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC" diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index c6ec198034..4bd601c0bd 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -10,65 +10,60 @@ from core.tools.tool.builtin_tool import BuiltinTool class VectorizerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - api_key_name = self.runtime.credentials.get('api_key_name', None) - api_key_value = self.runtime.credentials.get('api_key_value', None) - mode = tool_parameters.get('mode', 'test') - if mode == 'production': - mode = 'preview' + api_key_name = self.runtime.credentials.get("api_key_name", None) + api_key_value = 
self.runtime.credentials.get("api_key_value", None) + mode = tool_parameters.get("mode", "test") + if mode == "production": + mode = "preview" if not api_key_name or not api_key_value: - raise ToolProviderCredentialValidationError('Please input api key name and value') + raise ToolProviderCredentialValidationError("Please input api key name and value") - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") if not image_id: - return self.create_text_message('Please input image id') - - if image_id.startswith('__test_'): + return self.create_text_message("Please input image id") + + if image_id.startswith("__test_"): image_binary = b64decode(VECTORIZER_ICON_PNG) else: - image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + image_binary = self.get_variable_file(self.VariableKey.IMAGE) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") response = post( - 'https://vectorizer.ai/api/v1/vectorize', - files={ - 'image': image_binary - }, - data={ - 'mode': mode - } if mode == 'test' else {}, - auth=(api_key_name, api_key_value), - timeout=30 + "https://vectorizer.ai/api/v1/vectorize", + files={"image": image_binary}, + data={"mode": mode} if mode == "test" else {}, + auth=(api_key_name, api_key_value), + timeout=30, ) if response.status_code != 200: raise Exception(response.text) - + return [ - self.create_text_message('the vectorized svg is saved as an image.'), - self.create_blob_message(blob=response.content, - meta={'mime_type': 'image/svg+xml'}) + self.create_text_message("the vectorized svg is saved as an image."), + self.create_blob_message(blob=response.content, meta={"mime_type": "image/svg+xml"}), ] - + def get_runtime_parameters(self) -> list[ToolParameter]: """ override the runtime parameters """ return [ ToolParameter.get_simple_instance( - name='image_id', - llm_description=f'the image id that you want to vectorize, \ + name="image_id", + llm_description=f"the image id that you want to vectorize, \ and the image id should be specified in \ - {[i.name for i in self.list_default_image_variables()]}', + {[i.name for i in self.list_default_image_variables()]}", type=ToolParameter.ToolParameterType.SELECT, required=True, - options=[i.name for i in self.list_default_image_variables()] + options=[i.name for i in self.list_default_image_variables()], ) ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 3f89a83500..3b868572f9 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -13,12 +13,8 @@ class VectorizerProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "mode": "test", - "image_id": "__test_123" - }, + user_id="", + tool_parameters={"mode": "test", "image_id": "__test_123"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 3d098e6768..12670b4b8b 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -6,23 +6,24 
@@ from core.tools.tool.builtin_tool import BuiltinTool


 class WebscraperTool(BuiltinTool):
-    def _invoke(self,
-               user_id: str,
-               tool_parameters: dict[str, Any],
-        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
         try:
-            url = tool_parameters.get('url', '')
-            user_agent = tool_parameters.get('user_agent', '')
+            url = tool_parameters.get("url", "")
+            user_agent = tool_parameters.get("user_agent", "")
             if not url:
-                return self.create_text_message('Please input url')
+                return self.create_text_message("Please input url")

             # get webpage
             result = self.get_url(url, user_agent=user_agent)

-            if tool_parameters.get('generate_summary'):
+            if tool_parameters.get("generate_summary"):
                 # summarize and return
                 return self.create_text_message(self.summary(user_id=user_id, content=result))
             else:
diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py
index 1e60fdb293..3c51393ac6 100644
--- a/api/core/tools/provider/builtin/webscraper/webscraper.py
+++ b/api/core/tools/provider/builtin/webscraper/webscraper.py
@@ -13,12 +13,11 @@ class WebscraperProvider(BuiltinToolProviderController):
                     "credentials": credentials,
                 }
             ).invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
-                    'url': 'https://www.google.com',
-                    'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
+                    "url": "https://www.google.com",
+                    "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ",
                 },
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/websearch/tools/job_search.py b/api/core/tools/provider/builtin/websearch/tools/job_search.py
index 9128305922..293f4f6329 100644
--- a/api/core/tools/provider/builtin/websearch/tools/job_search.py
+++ b/api/core/tools/provider/builtin/websearch/tools/job_search.py
@@ -50,14 +50,16 @@ class SerplyApi:
         for job in jobs[:10]:
             try:
                 string.append(
-                    "\n".join([
-                        f"Position: {job['position']}",
-                        f"Employer: {job['employer']}",
-                        f"Location: {job['location']}",
-                        f"Link: {job['link']}",
-                        f"""Highest: {", ".join(list(job["highlights"]))}""",
-                        "---",
-                    ])
+                    "\n".join(
+                        [
+                            f"Position: {job['position']}",
+                            f"Employer: {job['employer']}",
+                            f"Location: {job['location']}",
+                            f"Link: {job['link']}",
+                            f"""Highlights: {", ".join(list(job["highlights"]))}""",
+                            "---",
+                        ]
+                    )
                 )
             except KeyError:
                 continue
diff --git a/api/core/tools/provider/builtin/websearch/tools/news_search.py b/api/core/tools/provider/builtin/websearch/tools/news_search.py
index e9c0744f05..9b5482fe18 100644
--- a/api/core/tools/provider/builtin/websearch/tools/news_search.py
+++ b/api/core/tools/provider/builtin/websearch/tools/news_search.py
@@ -53,13 +53,15 @@ class SerplyApi:
                 r = requests.get(entry["link"])
                 final_link = r.history[-1].headers["Location"]
                 string.append(
-                    "\n".join([
-                        f"Title: {entry['title']}",
-                        f"Link: {final_link}",
-                        f"Source: {entry['source']['title']}",
-                        f"Published: {entry['published']}",
-                        "---",
-                    ])
+                    "\n".join(
+                        [
+                            f"Title: {entry['title']}",
+                            f"Link: {final_link}",
+                            f"Source: {entry['source']['title']}",
+                            f"Published: {entry['published']}",
+                            "---",
+                        ]
+                    )
                 )
             except KeyError:
                 continue
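The news_search.py hunk above carries over final_link = r.history[-1].headers["Location"], which raises IndexError whenever the fetched link involves no redirect at all. A minimal defensive sketch, assuming only the requests library (the function name is illustrative, not part of this PR):

    import requests

    def resolve_final_link(link: str) -> str:
        # requests follows redirects by default; Response.url already holds
        # the final URL after the last hop, so indexing history is unneeded.
        r = requests.get(link, timeout=10)
        return r.url
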
diff --git a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py
index 0030a03c06..798d059b51 100644
--- a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py
+++ b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py
@@ -55,14 +55,16 @@ class SerplyApi:
                 link = article["link"]
                 authors = [author["name"] for author in article["author"]["authors"]]
                 string.append(
-                    "\n".join([
-                        f"Title: {article['title']}",
-                        f"Link: {link}",
-                        f"Description: {article['description']}",
-                        f"Cite: {article['cite']}",
-                        f"Authors: {', '.join(authors)}",
-                        "---",
-                    ])
+                    "\n".join(
+                        [
+                            f"Title: {article['title']}",
+                            f"Link: {link}",
+                            f"Description: {article['description']}",
+                            f"Cite: {article['cite']}",
+                            f"Authors: {', '.join(authors)}",
+                            "---",
+                        ]
+                    )
                 )
             except KeyError:
                 continue
diff --git a/api/core/tools/provider/builtin/websearch/tools/web_search.py b/api/core/tools/provider/builtin/websearch/tools/web_search.py
index 4f57c27caf..fe363ac7a4 100644
--- a/api/core/tools/provider/builtin/websearch/tools/web_search.py
+++ b/api/core/tools/provider/builtin/websearch/tools/web_search.py
@@ -49,12 +49,14 @@ class SerplyApi:
         for result in results:
             try:
                 string.append(
-                    "\n".join([
-                        f"Title: {result['title']}",
-                        f"Link: {result['link']}",
-                        f"Description: {result['description'].strip()}",
-                        "---",
-                    ])
+                    "\n".join(
+                        [
+                            f"Title: {result['title']}",
+                            f"Link: {result['link']}",
+                            f"Description: {result['description'].strip()}",
+                            "---",
+                        ]
+                    )
                 )
             except KeyError:
                 continue
diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py
index fb44b70f4e..545d9f4f8d 100644
--- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py
+++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py
@@ -8,41 +8,41 @@ from core.tools.utils.uuid_utils import is_valid_uuid


 class WecomGroupBotTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        content = tool_parameters.get('content', '')
+        content = tool_parameters.get("content", "")
         if not content:
-            return self.create_text_message('Invalid parameter content')
+            return self.create_text_message("Invalid parameter content")

-        hook_key = tool_parameters.get('hook_key', '')
+        hook_key = tool_parameters.get("hook_key", "")
         if not is_valid_uuid(hook_key):
-            return self.create_text_message(
-                f'Invalid parameter hook_key ${hook_key}, not a valid UUID')
+            return self.create_text_message(f"Invalid parameter hook_key {hook_key}, not a valid UUID")

-        message_type = tool_parameters.get('message_type', 'text')
-        if message_type == 'markdown':
+        message_type = tool_parameters.get("message_type", "text")
+        if message_type == "markdown":
             payload = {
-                "msgtype": 'markdown',
+                "msgtype": "markdown",
                 "markdown": {
                     "content": content,
-                }
+                },
             }
         else:
             payload = {
-                "msgtype": 'text',
+                "msgtype": "text",
                 "text": {
                     "content": content,
-                }
+                },
             }
-        api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send'
+        api_url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send"
         headers = {
-            'Content-Type': 'application/json',
+            "Content-Type": "application/json",
         }
         params = {
-            'key': hook_key,
+            "key": hook_key,
         }

         try:
@@ -51,6 +51,7 @@ class WecomGroupBotTool(BuiltinTool):
                 return self.create_text_message("Text message 
sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 0796cd2392..67efcf0954 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -83,7 +83,6 @@ class WikipediaQueryRun: class WikiPediaSearchTool(BuiltinTool): - def _invoke( self, user_id: str, diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index f8038714a5..178bf7b0ce 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -11,11 +11,10 @@ class WikiPediaProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "misaka mikoto", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 8cb9c10ddf..9dc5bed824 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -8,29 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class WolframAlphaTool(BuiltinTool): - _base_url = 'https://api.wolframalpha.com/v2/query' + _base_url = "https://api.wolframalpha.com/v2/query" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - appid = self.runtime.credentials.get('appid', '') + return self.create_text_message("Please input query") + appid = self.runtime.credentials.get("appid", "") if not appid: - raise ToolProviderCredentialValidationError('Please input appid') - - params = { - 'appid': appid, - 'input': query, - 'includepodid': 'Result', - 'format': 'plaintext', - 'output': 'json' - } + raise ToolProviderCredentialValidationError("Please input appid") + + params = {"appid": appid, "input": query, "includepodid": "Result", "format": "plaintext", "output": "json"} finished = False result = None @@ -45,34 +40,33 @@ class WolframAlphaTool(BuiltinTool): response_data = response.json() except Exception as e: raise ToolInvokeError(str(e)) - - if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: - query_result = response_data.get('queryresult', {}) - if query_result.get('error'): - if 'msg' in query_result['error']: - if query_result['error']['msg'] == 'Invalid appid': - raise ToolProviderCredentialValidationError('Invalid appid') - raise ToolInvokeError('Failed to invoke tool') - - if 'didyoumeans' in 
response_data['queryresult']:
-                # get the most likely interpretation
-                query = ''
-                max_score = 0
-                for didyoumean in response_data['queryresult']['didyoumeans']:
-                    if float(didyoumean['score']) > max_score:
-                        query = didyoumean['val']
-                        max_score = float(didyoumean['score'])
-                params['input'] = query
+            if "success" not in response_data["queryresult"] or response_data["queryresult"]["success"] != True:
+                query_result = response_data.get("queryresult", {})
+                if query_result.get("error"):
+                    if "msg" in query_result["error"]:
+                        if query_result["error"]["msg"] == "Invalid appid":
+                            raise ToolProviderCredentialValidationError("Invalid appid")
+                raise ToolInvokeError("Failed to invoke tool")
+
+            if "didyoumeans" in response_data["queryresult"]:
+                # get the most likely interpretation
+                query = ""
+                max_score = 0
+                for didyoumean in response_data["queryresult"]["didyoumeans"]:
+                    if float(didyoumean["score"]) > max_score:
+                        query = didyoumean["val"]
+                        max_score = float(didyoumean["score"])
+
+                params["input"] = query
             else:
                 finished = True
-                if 'souces' in response_data['queryresult']:
-                    return self.create_link_message(response_data['queryresult']['sources']['url'])
-                elif 'pods' in response_data['queryresult']:
-                    result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext']
+                if "sources" in response_data["queryresult"]:
+                    return self.create_link_message(response_data["queryresult"]["sources"]["url"])
+                elif "pods" in response_data["queryresult"]:
+                    result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"]

         if not finished or not result:
-            return self.create_text_message('No result found')
+            return self.create_text_message("No result found")

         return self.create_text_message(result)
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
index ef1aac7ff2..7be288b538 100644
--- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
+++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
@@ -13,11 +13,10 @@ class GoogleProvider(BuiltinToolProviderController):
                     "credentials": credentials,
                 }
             ).invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
                     "query": "1+2+....+111",
                 },
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
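The two wolframalpha hunks above only reflow the query logic; note that the old 'souces' key could never match the API's 'sources' field, so the link branch was dead code until the spelling fix applied in the + line. A minimal sketch of the v2 query this tool builds, assuming a hypothetical YOUR_APPID credential:

    import requests

    params = {
        "appid": "YOUR_APPID",  # hypothetical placeholder, not a real AppID
        "input": "integrate x^2",
        "includepodid": "Result",
        "format": "plaintext",
        "output": "json",
    }
    response = requests.get("https://api.wolframalpha.com/v2/query", params=params, timeout=20)
    queryresult = response.json()["queryresult"]
    if queryresult.get("success"):
        # the Result pod's first subpod carries the plaintext answer
        print(queryresult["pods"][0]["subpods"][0]["plaintext"])
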
diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py
index cf511ea894..f044fbe540 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py
@@ -10,27 +10,28 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class YahooFinanceAnalyticsTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-        -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        symbol = tool_parameters.get('symbol', '')
+        symbol = tool_parameters.get("symbol", "")
         if not symbol:
-            return self.create_text_message('Please input symbol')
-
+            return self.create_text_message("Please input symbol")
+
         time_range = [None, None]
-        start_date = tool_parameters.get('start_date', '')
+        start_date = tool_parameters.get("start_date", "")
         if start_date:
             time_range[0] = start_date
         else:
-            time_range[0] = '1800-01-01'
+            time_range[0] = "1800-01-01"

-        end_date = tool_parameters.get('end_date', '')
+        end_date = tool_parameters.get("end_date", "")
         if end_date:
             time_range[1] = end_date
         else:
-            time_range[1] = datetime.now().strftime('%Y-%m-%d')
+            time_range[1] = datetime.now().strftime("%Y-%m-%d")

         stock_data = download(symbol, start=time_range[0], end=time_range[1])
         max_segments = min(15, len(stock_data))
@@ -41,30 +42,29 @@ class YahooFinanceAnalyticsTool(BuiltinTool):
             end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data)
             segment_data = stock_data.iloc[start_idx:end_idx]
             segment_summary = {
-                'Start Date': segment_data.index[0],
-                'End Date': segment_data.index[-1],
-                'Average Close': segment_data['Close'].mean(),
-                'Average Volume': segment_data['Volume'].mean(),
-                'Average Open': segment_data['Open'].mean(),
-                'Average High': segment_data['High'].mean(),
-                'Average Low': segment_data['Low'].mean(),
-                'Average Adj Close': segment_data['Adj Close'].mean(),
-                'Max Close': segment_data['Close'].max(),
-                'Min Close': segment_data['Close'].min(),
-                'Max Volume': segment_data['Volume'].max(),
-                'Min Volume': segment_data['Volume'].min(),
-                'Max Open': segment_data['Open'].max(),
-                'Min Open': segment_data['Open'].min(),
-                'Max High': segment_data['High'].max(),
-                'Min High': segment_data['High'].min(),
+                "Start Date": segment_data.index[0],
+                "End Date": segment_data.index[-1],
+                "Average Close": segment_data["Close"].mean(),
+                "Average Volume": segment_data["Volume"].mean(),
+                "Average Open": segment_data["Open"].mean(),
+                "Average High": segment_data["High"].mean(),
+                "Average Low": segment_data["Low"].mean(),
+                "Average Adj Close": segment_data["Adj Close"].mean(),
+                "Max Close": segment_data["Close"].max(),
+                "Min Close": segment_data["Close"].min(),
+                "Max Volume": segment_data["Volume"].max(),
+                "Min Volume": segment_data["Volume"].min(),
+                "Max Open": segment_data["Open"].max(),
+                "Min Open": segment_data["Open"].min(),
+                "Max High": segment_data["High"].max(),
+                "Min High": segment_data["High"].min(),
             }
-
+
             summary_data.append(segment_summary)

         summary_df = pd.DataFrame(summary_data)
-
+
         try:
             return self.create_text_message(str(summary_df.to_dict()))
         except (HTTPError, ReadTimeout):
-            return self.create_text_message('There is a internet connection problem. Please try again later.')
-
\ No newline at end of file
+            return self.create_text_message("There is an internet connection problem. Please try again later.")
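The analytics.py hunk above splits the downloaded price history into at most 15 row segments and aggregates each one into a summary row. A self-contained sketch of that segmentation pattern on synthetic data (column set reduced for brevity; the real tool summarizes the full OHLCV columns):

    import pandas as pd

    stock_data = pd.DataFrame({"Close": range(100)})  # stand-in for the yfinance download

    max_segments = min(15, len(stock_data))
    rows_per_segment = len(stock_data) // max_segments
    summary_data = []
    for i in range(max_segments):
        start_idx = i * rows_per_segment
        # the last segment absorbs any leftover rows
        end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data)
        segment = stock_data.iloc[start_idx:end_idx]
        summary_data.append({"Average Close": segment["Close"].mean()})

    print(pd.DataFrame(summary_data).to_dict())
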
diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py
index 4f2922ef3e..ff820430f9 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/news.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/news.py
@@ -8,40 +8,39 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class YahooFinanceSearchTickerTool(BuiltinTool):
-    def _invoke(self,user_id: str, tool_parameters: dict[str, Any]) \
-        -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        '''
-            invoke tools
-        '''
-
-        query = tool_parameters.get('symbol', '')
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+
+        query = tool_parameters.get("symbol", "")
         if not query:
-            return self.create_text_message('Please input symbol')
-
+            return self.create_text_message("Please input symbol")
+
         try:
             return self.run(ticker=query, user_id=user_id)
         except (HTTPError, ReadTimeout):
-            return self.create_text_message('There is a internet connection problem. Please try again later.')
+            return self.create_text_message("There is an internet connection problem. Please try again later.")

     def run(self, ticker: str, user_id: str) -> ToolInvokeMessage:
         company = yfinance.Ticker(ticker)
         try:
             if company.isin is None:
-                return self.create_text_message(f'Company ticker {ticker} not found.')
+                return self.create_text_message(f"Company ticker {ticker} not found.")
         except (HTTPError, ReadTimeout, ConnectionError):
-            return self.create_text_message(f'Company ticker {ticker} not found.')
+            return self.create_text_message(f"Company ticker {ticker} not found.")

         links = []
         try:
-            links = [n['link'] for n in company.news if n['type'] == 'STORY']
+            links = [n["link"] for n in company.news if n["type"] == "STORY"]
         except (HTTPError, ReadTimeout, ConnectionError):
             if not links:
-                return self.create_text_message(f'There is nothing about {ticker} ticker')
+                return self.create_text_message(f"There is nothing about {ticker} ticker")

         if not links:
-            return self.create_text_message(f'No news found for company that searched with {ticker} ticker.')
-
-        result = '\n\n'.join([
-            self.get_url(link) for link in links
-        ])
+            return self.create_text_message(f"No news found for company that searched with {ticker} ticker.")
+
+        result = "\n\n".join([self.get_url(link) for link in links])

         return self.create_text_message(self.summary(user_id=user_id, content=result))
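The news.py hunk above collects story links from yfinance's Ticker.news list before fetching and summarizing each page. A short sketch of just that extraction step ("MSFT" is an arbitrary example symbol; .get() guards entries that lack a type field):

    import yfinance

    company = yfinance.Ticker("MSFT")
    story_links = [item["link"] for item in company.news if item.get("type") == "STORY"]
    print(story_links[:5])
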
diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py
index 262fff3b25..dfc7e46047 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py
@@ -8,19 +8,20 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class YahooFinanceSearchTickerTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-        -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        query = tool_parameters.get('symbol', '')
+        query = tool_parameters.get("symbol", "")
         if not query:
-            return self.create_text_message('Please input symbol')
-
+            return self.create_text_message("Please input symbol")
+
         try:
             return self.create_text_message(self.run(ticker=query))
         except (HTTPError, ReadTimeout):
-            return self.create_text_message('There is a internet connection problem. Please try again later.')
-
+            return self.create_text_message("There is an internet connection problem. Please try again later.")
+
     def run(self, ticker: str) -> str:
-        return str(Ticker(ticker).info)
\ No newline at end of file
+        return str(Ticker(ticker).info)
diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py
index 96dbc6c3d0..8d82084e76 100644
--- a/api/core/tools/provider/builtin/yahoo/yahoo.py
+++ b/api/core/tools/provider/builtin/yahoo/yahoo.py
@@ -11,11 +11,10 @@ class YahooFinanceProvider(BuiltinToolProviderController):
                     "credentials": credentials,
                 }
             ).invoke(
-                user_id='',
+                user_id="",
                 tool_parameters={
                     "ticker": "MSFT",
                 },
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
-
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py
index 7a9b9fce4a..95dec2eac9 100644
--- a/api/core/tools/provider/builtin/youtube/tools/videos.py
+++ b/api/core/tools/provider/builtin/youtube/tools/videos.py
@@ -8,60 +8,67 @@ from core.tools.tool.builtin_tool import BuiltinTool


 class YoutubeVideosAnalyticsTool(BuiltinTool):
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-        -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        channel = tool_parameters.get('channel', '')
+        channel = tool_parameters.get("channel", "")
         if not channel:
-            return self.create_text_message('Please input symbol')
-
+            return self.create_text_message("Please input channel")
+
         time_range = [None, None]
-        start_date = tool_parameters.get('start_date', '')
+        start_date = tool_parameters.get("start_date", "")
         if start_date:
             time_range[0] = start_date
         else:
-            time_range[0] = '1800-01-01'
+            time_range[0] = "1800-01-01"

-        end_date = tool_parameters.get('end_date', '')
+        end_date = tool_parameters.get("end_date", "")
         if end_date:
             time_range[1] = end_date
         else:
-            time_range[1] = datetime.now().strftime('%Y-%m-%d')
+            time_range[1] = datetime.now().strftime("%Y-%m-%d")

-        if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']:
-            return self.create_text_message('Please input api key')
+        if "google_api_key" not in self.runtime.credentials or not self.runtime.credentials["google_api_key"]:
+            return self.create_text_message("Please input api key")

-        youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key'])
+        youtube = build("youtube", "v3", developerKey=self.runtime.credentials["google_api_key"])

         # try to get channel id
-        search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute()
-        channel_id = search_results['items'][0]['id']['channelId']
+        search_results = youtube.search().list(q=channel, type="channel", order="relevance", part="id").execute()
+        channel_id = search_results["items"][0]["id"]["channelId"]

         start_date, end_date = time_range
-        start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ')
-        end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ')
+        start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ")
+        end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ")

         # get videos
-        time_range_videos = youtube.search().list(
-            part='snippet', channelId=channel_id, order='date', type='video',
-            publishedAfter=start_date,
-            publishedBefore=end_date
- 
).execute() + time_range_videos = ( + youtube.search() + .list( + part="snippet", + channelId=channel_id, + order="date", + type="video", + publishedAfter=start_date, + publishedBefore=end_date, + ) + .execute() + ) def extract_video_data(video_list): data = [] - for video in video_list['items']: - video_id = video['id']['videoId'] - video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute() - title = video_info['items'][0]['snippet']['title'] - views = video_info['items'][0]['statistics']['viewCount'] - data.append({'Title': title, 'Views': views}) + for video in video_list["items"]: + video_id = video["id"]["videoId"] + video_info = youtube.videos().list(part="snippet,statistics", id=video_id).execute() + title = video_info["items"][0]["snippet"]["title"] + views = video_info["items"][0]["statistics"]["viewCount"] + data.append({"Title": title, "Views": views}) return data summary = extract_video_data(time_range_videos) - + return self.create_text_message(str(summary)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 83a4fccb32..aad876491c 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -11,7 +11,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "channel": "TOKYO GIRLS COLLECTION", "start_date": "2020-01-01", @@ -20,4 +20,3 @@ class YahooFinanceProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index bcf41c90ed..6b64dd1b4e 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -22,34 +22,36 @@ class BuiltinToolProviderController(ToolProviderController): if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: super().__init__(**data) return - + # load provider yaml - provider = self.__class__.__module__.split('.')[-1] - yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') + provider = self.__class__.__module__.split(".")[-1] + yaml_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, f"{provider}.yaml") try: provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') + raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") - if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: + if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None: # set credentials name - for credential_name in provider_yaml['credentials_for_provider']: - provider_yaml['credentials_for_provider'][credential_name]['name'] = credential_name + for credential_name in provider_yaml["credentials_for_provider"]: + provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name - super().__init__(**{ - 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml.get('credentials_for_provider', None), - }) + super().__init__( + **{ + "identity": provider_yaml["identity"], + 
"credentials_schema": provider_yaml.get("credentials_for_provider", None), + } + ) def _get_builtin_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ if self.tools: return self.tools - + provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") # get all the yaml files in the tool path @@ -62,155 +64,161 @@ class BuiltinToolProviderController(ToolProviderController): # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) + module_name=f"core.tools.provider.builtin.{provider}.tools.{tool_name}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "builtin", provider, "tools", f"{tool_name}.py" + ), + parent_type=BuiltinTool, + ) tool["identity"]["provider"] = provider tools.append(assistant_tool_class(**tool)) self.tools = tools return tools - + def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ if not self.credentials_schema: return {} - + return self.credentials_schema.copy() def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ return self._get_builtin_tools() - + def get_tool(self, tool_name: str) -> Tool: """ - returns the tool that the provider can provide + returns the tool that the provider can provide """ return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def need_credentials(self) -> bool: """ - returns whether the provider needs credentials + returns whether the provider needs credentials - :return: whether the provider needs credentials + :return: whether the provider needs credentials """ - return self.credentials_schema is not None and \ - len(self.credentials_schema) != 0 + return self.credentials_schema is not None and len(self.credentials_schema) != 0 @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN @property def tool_labels(self) -> list[str]: """ - returns the labels of the provider + returns the labels of the provider - :return: labels of the provider + :return: labels of the provider """ label_enums = self._get_tool_labels() return [default_tool_label_dict[label].name for label 
in label_enums] def _get_tool_labels(self) -> list[ToolLabelEnum]: """ - returns the labels of the provider + returns the labels of the provider """ return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ - validate the parameters of the tool and set the default value if needed + validate the parameters of the tool and set the default value if needed - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param tool_parameters: the parameters of the tool """ tool_parameters_schema = self.get_parameters(tool_name) - + tool_parameters_need_to_validate: dict[str, ToolParameter] = {} for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter for parameter in tool_parameters: if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - + raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + # check type parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - + raise ToolParameterValidationError(f"parameter {parameter} should be number") + if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be greater than {parameter_schema.min}" + ) + if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be less than {parameter_schema.max}" + ) + elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - + raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + 
tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - + raise ToolParameterValidationError(f"parameter {parameter} is required") + # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) + default_value = ToolParameterConverter.cast_parameter_by_type( + parameter_schema.default, parameter_schema.type + ) tool_parameters[parameter] = default_value - + def validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ # validate credentials format self.validate_credentials_format(credentials) @@ -221,9 +229,9 @@ class BuiltinToolProviderController(ToolProviderController): @abstractmethod def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ pass diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index ef1ace9c7c..f4008eedce 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -21,162 +21,174 @@ class ToolProviderController(BaseModel, ABC): def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ return self.credentials_schema.copy() - + @abstractmethod def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ pass @abstractmethod def get_tool(self, tool_name: str) -> Tool: """ - returns a tool that the provider can provide + returns a tool that the provider can provide - :return: tool + :return: tool """ pass def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN def validate_parameters(self, tool_id: 
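validate_parameters above walks a declared schema twice: once over the supplied values to type-check them, then once over the leftover schema entries to reject missing required parameters and backfill defaults. The same two-pass shape in isolation, with a simplified stand-in for the real ToolParameter class (an illustrative sketch, not the project's API):

from dataclasses import dataclass
from typing import Any


@dataclass
class ParamSchema:
    name: str
    type: type
    required: bool = False
    default: Any = None


def validate(schema: list[ParamSchema], params: dict[str, Any]) -> None:
    pending = {s.name: s for s in schema}
    # Pass 1: every supplied value must be declared and well-typed.
    for name, value in params.items():
        spec = pending.pop(name, None)
        if spec is None:
            raise ValueError(f"parameter {name} not found")
        if not isinstance(value, spec.type):
            raise ValueError(f"parameter {name} should be {spec.type.__name__}")
    # Pass 2: whatever was not supplied is either required (error) or defaulted.
    for name, spec in pending.items():
        if spec.required:
            raise ValueError(f"parameter {name} is required")
        if spec.default is not None:
            params[name] = spec.default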
diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py
index ef1ace9c7c..f4008eedce 100644
--- a/api/core/tools/provider/tool_provider.py
+++ b/api/core/tools/provider/tool_provider.py
@@ -21,162 +21,174 @@ class ToolProviderController(BaseModel, ABC):
 
     def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
         """
-            returns the credentials schema of the provider
+        returns the credentials schema of the provider
 
-            :return: the credentials schema
+        :return: the credentials schema
         """
         return self.credentials_schema.copy()
-
+
     @abstractmethod
     def get_tools(self) -> list[Tool]:
         """
-            returns a list of tools that the provider can provide
+        returns a list of tools that the provider can provide
 
-            :return: list of tools
+        :return: list of tools
         """
         pass
 
     @abstractmethod
     def get_tool(self, tool_name: str) -> Tool:
         """
-            returns a tool that the provider can provide
+        returns a tool that the provider can provide
 
-            :return: tool
+        :return: tool
         """
         pass
 
     def get_parameters(self, tool_name: str) -> list[ToolParameter]:
         """
-            returns the parameters of the tool
+        returns the parameters of the tool
 
-            :param tool_name: the name of the tool, defined in `get_tools`
-            :return: list of parameters
+        :param tool_name: the name of the tool, defined in `get_tools`
+        :return: list of parameters
         """
         tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
         if tool is None:
-            raise ToolNotFoundError(f'tool {tool_name} not found')
+            raise ToolNotFoundError(f"tool {tool_name} not found")
         return tool.parameters
 
     @property
     def provider_type(self) -> ToolProviderType:
         """
-            returns the type of the provider
+        returns the type of the provider
 
-            :return: type of the provider
+        :return: type of the provider
         """
         return ToolProviderType.BUILT_IN
 
     def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
         """
-            validate the parameters of the tool and set the default value if needed
+        validate the parameters of the tool and set the default value if needed
 
-            :param tool_name: the name of the tool, defined in `get_tools`
-            :param tool_parameters: the parameters of the tool
+        :param tool_name: the name of the tool, defined in `get_tools`
+        :param tool_parameters: the parameters of the tool
         """
         tool_parameters_schema = self.get_parameters(tool_name)
-
+
         tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
         for parameter in tool_parameters_schema:
             tool_parameters_need_to_validate[parameter.name] = parameter
 
         for parameter in tool_parameters:
             if parameter not in tool_parameters_need_to_validate:
-                raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
-
+                raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
+
             # check type
             parameter_schema = tool_parameters_need_to_validate[parameter]
             if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
                 if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-
+                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+
             elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
                 if not isinstance(tool_parameters[parameter], int | float):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be number')
-
+                    raise ToolParameterValidationError(f"parameter {parameter} should be number")
+
                 if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
-
+                    raise ToolParameterValidationError(
+                        f"parameter {parameter} should be greater than {parameter_schema.min}"
+                    )
+
                 if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
-
+                    raise ToolParameterValidationError(
+                        f"parameter {parameter} should be less than {parameter_schema.max}"
+                    )
+
             elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
                 if not isinstance(tool_parameters[parameter], bool):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
-
+                    raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
+
             elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
                 if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-
+                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+
                 options = parameter_schema.options
                 if not isinstance(options, list):
-                    raise ToolParameterValidationError(f'parameter {parameter} options should be list')
-
+                    raise ToolParameterValidationError(f"parameter {parameter} options should be list")
+
                 if tool_parameters[parameter] not in [x.value for x in options]:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
-
+                    raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
+
             tool_parameters_need_to_validate.pop(parameter)
 
         for parameter in tool_parameters_need_to_validate:
             parameter_schema = tool_parameters_need_to_validate[parameter]
             if parameter_schema.required:
-                raise ToolParameterValidationError(f'parameter {parameter} is required')
-
+                raise ToolParameterValidationError(f"parameter {parameter} is required")
+
             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
-                                                                                           parameter_schema.type)
+                tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
+                    parameter_schema.default, parameter_schema.type
+                )
 
     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
         """
-            validate the format of the credentials of the provider and set the default value if needed
+        validate the format of the credentials of the provider and set the default value if needed
 
-            :param credentials: the credentials of the tool
+        :param credentials: the credentials of the tool
         """
         credentials_schema = self.credentials_schema
         if credentials_schema is None:
             return
-
+
         credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
         for credential_name in credentials_schema:
             credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
 
         for credential_name in credentials:
             if credential_name not in credentials_need_to_validate:
-                raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
-
+                raise ToolProviderCredentialValidationError(
+                    f"credential {credential_name} not found in provider {self.identity.name}"
+                )
+
             # check type
             credential_schema = credentials_need_to_validate[credential_name]
-            if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
-                credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
+            if (
+                credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT
+                or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT
+            ):
                 if not isinstance(credentials[credential_name], str):
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
-
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
+
             elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                 if not isinstance(credentials[credential_name], str):
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
-
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
+
                 options = credential_schema.options
                 if not isinstance(options, list):
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
-
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
+
                 if credentials[credential_name] not in [x.value for x in options]:
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
-
+                    raise ToolProviderCredentialValidationError(
+                        f"credential {credential_name} should be one of {options}"
+                    )
+
             credentials_need_to_validate.pop(credential_name)
 
         for credential_name in credentials_need_to_validate:
             credential_schema = credentials_need_to_validate[credential_name]
             if credential_schema.required:
-                raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
-
+                raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
+
             # the credential is not set currently, set the default value if needed
             if credential_schema.default is not None:
                 default_value = credential_schema.default
                 # parse default value into the correct type
-                if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
+                if (
+                    credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
+                    or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT
+                    or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT
+                ):
                     default_value = str(default_value)
 
                 credentials[credential_name] = default_value
-
\ No newline at end of file
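Both controllers above test numbers with isinstance(value, int | float). Since Python 3.10 (PEP 604), isinstance accepts a union type directly, and it behaves exactly like the older tuple spelling; a quick illustration:

# Python 3.10+: X | Y unions are accepted by isinstance() directly.
assert isinstance(3, int | float)
assert isinstance(3.5, int | float)
assert not isinstance("3", int | float)
assert isinstance(3, (int, float))  # the older tuple form is equivalent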
diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py
index f14abac767..25eaf6a66a 100644
--- a/api/core/tools/provider/workflow_tool_provider.py
+++ b/api/core/tools/provider/workflow_tool_provider.py
@@ -30,29 +30,25 @@ class WorkflowToolProviderController(ToolProviderController):
     provider_id: str
 
     @classmethod
-    def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
+    def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
         app = db_provider.app
         if not app:
-            raise ValueError('app not found')
+            raise ValueError("app not found")
 
-        controller = WorkflowToolProviderController(**{
-            'identity': {
-                'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
-                'name': db_provider.label,
-                'label': {
-                    'en_US': db_provider.label,
-                    'zh_Hans': db_provider.label
+        controller = WorkflowToolProviderController(
+            **{
+                "identity": {
+                    "author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
+                    "name": db_provider.label,
+                    "label": {"en_US": db_provider.label, "zh_Hans": db_provider.label},
+                    "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
+                    "icon": db_provider.icon,
                 },
-                'description': {
-                    'en_US': db_provider.description,
-                    'zh_Hans': db_provider.description
-                },
-                'icon': db_provider.icon,
-            },
-            'credentials_schema': {},
-            'provider_id': db_provider.id or '',
-        })
+                "credentials_schema": {},
+                "provider_id": db_provider.id or "",
+            }
+        )
 
         # init tools
 
@@ -66,25 +62,23 @@ class WorkflowToolProviderController(ToolProviderController):
 
     def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
         """
-            get db provider tool
-            :param db_provider: the db provider
-            :param app: the app
-            :return: the tool
+        get db provider tool
+        :param db_provider: the db provider
+        :param app: the app
+        :return: the tool
         """
-        workflow: Workflow = db.session.query(Workflow).filter(
-            Workflow.app_id == db_provider.app_id,
-            Workflow.version == db_provider.version
-        ).first()
+        workflow: Workflow = (
+            db.session.query(Workflow)
+            .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
+            .first()
+        )
 
         if not workflow:
-            raise ValueError('workflow not found')
+            raise ValueError("workflow not found")
 
         # fetch start node
         graph: dict = workflow.graph_dict
         features_dict: dict = workflow.features_dict
-        features = WorkflowAppConfigManager.convert_features(
-            config_dict=features_dict,
-            app_mode=AppMode.WORKFLOW
-        )
+        features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
 
         parameters = db_provider.parameter_configurations
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
 
@@ -101,51 +95,34 @@ class WorkflowToolProviderController(ToolProviderController):
             parameter_type = None
             options = None
             if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
-                raise ValueError(f'unsupported variable type {variable.type}')
+                raise ValueError(f"unsupported variable type {variable.type}")
             parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
 
             if variable.type == VariableEntityType.SELECT and variable.options:
                 options = [
-                    ToolParameterOption(
-                        value=option,
-                        label=I18nObject(
-                            en_US=option,
-                            zh_Hans=option
-                        )
-                    ) for option in variable.options
+                    ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
+                    for option in variable.options
                 ]
 
             workflow_tool_parameters.append(
                 ToolParameter(
                     name=parameter.name,
-                    label=I18nObject(
-                        en_US=variable.label,
-                        zh_Hans=variable.label
-                    ),
-                    human_description=I18nObject(
-                        en_US=parameter.description,
-                        zh_Hans=parameter.description
-                    ),
+                    label=I18nObject(en_US=variable.label, zh_Hans=variable.label),
+                    human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
                     type=parameter_type,
                     form=parameter.form,
                     llm_description=parameter.description,
                     required=variable.required,
                     options=options,
-                    default=variable.default
+                    default=variable.default,
                 )
             )
         elif features.file_upload:
             workflow_tool_parameters.append(
                 ToolParameter(
                     name=parameter.name,
-                    label=I18nObject(
-                        en_US=parameter.name,
-                        zh_Hans=parameter.name
-                    ),
-                    human_description=I18nObject(
-                        en_US=parameter.description,
-                        zh_Hans=parameter.description
-                    ),
+                    label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
+                    human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
                     type=ToolParameter.ToolParameterType.FILE,
                     llm_description=parameter.description,
                     required=False,
@@ -153,53 +130,51 @@ class WorkflowToolProviderController(ToolProviderController):
                 )
             )
         else:
-            raise ValueError('variable not found')
+            raise ValueError("variable not found")
 
         return WorkflowTool(
             identity=ToolIdentity(
-                author=user.name if user else '',
+                author=user.name if user else "",
                 name=db_provider.name,
-                label=I18nObject(
-                    en_US=db_provider.label,
-                    zh_Hans=db_provider.label
-                ),
+                label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
                 provider=self.provider_id,
                 icon=db_provider.icon,
             ),
             description=ToolDescription(
-                human=I18nObject(
-                    en_US=db_provider.description,
-                    zh_Hans=db_provider.description
-                ),
+                human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
                 llm=db_provider.description,
             ),
             parameters=workflow_tool_parameters,
             is_team_authorization=True,
             workflow_app_id=app.id,
             workflow_entities={
-                'app': app,
-                'workflow': workflow,
+                "app": app,
+                "workflow": workflow,
             },
             version=db_provider.version,
             workflow_call_depth=0,
-            label=db_provider.label
+            label=db_provider.label,
         )
 
     def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
         """
-            fetch tools from database
+        fetch tools from database
 
-            :param user_id: the user id
-            :param tenant_id: the tenant id
-            :return: the tools
+        :param user_id: the user id
+        :param tenant_id: the tenant id
+        :return: the tools
         """
         if self.tools is not None:
             return self.tools
 
-        db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.app_id == self.provider_id,
-        ).first()
+        db_providers: WorkflowToolProvider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(
+                WorkflowToolProvider.tenant_id == tenant_id,
+                WorkflowToolProvider.app_id == self.provider_id,
+            )
+            .first()
+        )
 
         if not db_providers:
             return []
 
@@ -210,10 +185,10 @@ class WorkflowToolProviderController(ToolProviderController):
 
     def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
         """
-            get tool by name
+        get tool by name
 
-            :param tool_name: the name of the tool
-            :return: the tool
+        :param tool_name: the name of the tool
+        :return: the tool
         """
         if self.tools is None:
             return None
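The from_db path above turns each workflow form variable into a ToolParameter, translating variable types through VARIABLE_TO_PARAMETER_TYPE_MAPPING and expanding select options into localized ToolParameterOption entries. The shape of that translation, with hypothetical simplified types standing in for the real entities (the mapping keys below are illustrative, not the project's actual enum values):

from dataclasses import dataclass

# Hypothetical stand-in for VARIABLE_TO_PARAMETER_TYPE_MAPPING.
TYPE_MAPPING = {"text-input": "string", "number": "number", "select": "select"}


@dataclass
class Variable:
    name: str
    type: str
    options: list[str] | None = None


def to_parameter(var: Variable) -> dict:
    if var.type not in TYPE_MAPPING:
        raise ValueError(f"unsupported variable type {var.type}")
    options = None
    if var.type == "select" and var.options:
        # Each option doubles as its value and its en_US/zh_Hans label, as in the diff.
        options = [{"value": o, "label": {"en_US": o, "zh_Hans": o}} for o in var.options]
    return {"name": var.name, "type": TYPE_MAPPING[var.type], "options": options}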
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index 38f10032e2..bf336b48f3 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -12,8 +12,8 @@ from core.tools.errors import ToolInvokeError, ToolParameterValidationError, Too
 from core.tools.tool.tool import Tool
 
 API_TOOL_DEFAULT_TIMEOUT = (
-    int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')),
-    int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60'))
+    int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
+    int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
 )
 
@@ -24,31 +24,32 @@ class ApiTool(Tool):
     Api tool
     """
 
-    def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
+    def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
         """
-            fork a new tool with meta data
+        fork a new tool with meta data
 
-            :param meta: the meta data of a tool call processing, tenant_id is required
-            :return: the new tool
+        :param meta: the meta data of a tool call processing, tenant_id is required
+        :return: the new tool
         """
         return self.__class__(
             identity=self.identity.model_copy() if self.identity else None,
             parameters=self.parameters.copy() if self.parameters else None,
             description=self.description.model_copy() if self.description else None,
             api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
-            runtime=Tool.Runtime(**runtime)
+            runtime=Tool.Runtime(**runtime),
         )
 
-    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any],
-                             format_only: bool = False) -> str:
+    def validate_credentials(
+        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+    ) -> str:
         """
-            validate the credentials for Api tool
+        validate the credentials for Api tool
         """
-        # assemble validate request and request parameters
+        # assemble validate request and request parameters
         headers = self.assembling_request(parameters)
 
         if format_only:
-            return ''
+            return ""
 
         response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
         # validate response
@@ -61,30 +62,30 @@ class ApiTool(Tool):
         headers = {}
         credentials = self.runtime.credentials or {}
 
-        if 'auth_type' not in credentials:
-            raise ToolProviderCredentialValidationError('Missing auth_type')
+        if "auth_type" not in credentials:
+            raise ToolProviderCredentialValidationError("Missing auth_type")
 
-        if credentials['auth_type'] == 'api_key':
-            api_key_header = 'api_key'
+        if credentials["auth_type"] == "api_key":
+            api_key_header = "api_key"
 
-            if 'api_key_header' in credentials:
-                api_key_header = credentials['api_key_header']
+            if "api_key_header" in credentials:
+                api_key_header = credentials["api_key_header"]
 
-            if 'api_key_value' not in credentials:
-                raise ToolProviderCredentialValidationError('Missing api_key_value')
-            elif not isinstance(credentials['api_key_value'], str):
-                raise ToolProviderCredentialValidationError('api_key_value must be a string')
+            if "api_key_value" not in credentials:
+                raise ToolProviderCredentialValidationError("Missing api_key_value")
+            elif not isinstance(credentials["api_key_value"], str):
+                raise ToolProviderCredentialValidationError("api_key_value must be a string")
 
-            if 'api_key_header_prefix' in credentials:
-                api_key_header_prefix = credentials['api_key_header_prefix']
-                if api_key_header_prefix == 'basic' and credentials['api_key_value']:
-                    credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
-                elif api_key_header_prefix == 'bearer' and credentials['api_key_value']:
-                    credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
-                elif api_key_header_prefix == 'custom':
+            if "api_key_header_prefix" in credentials:
+                api_key_header_prefix = credentials["api_key_header_prefix"]
+                if api_key_header_prefix == "basic" and credentials["api_key_value"]:
+                    credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
+                elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
+                    credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
+                elif api_key_header_prefix == "custom":
                     pass
 
-            headers[api_key_header] = credentials['api_key_value']
+            headers[api_key_header] = credentials["api_key_value"]
 
         needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
         for parameter in needed_parameters:
@@ -98,13 +99,13 @@ class ApiTool(Tool):
 
     def validate_and_parse_response(self, response: httpx.Response) -> str:
         """
-            validate the response
+        validate the response
         """
         if isinstance(response, httpx.Response):
             if response.status_code >= 400:
                 raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
             if not response.content:
-                return 'Empty response from the tool, please check your parameters and try again.'
+                return "Empty response from the tool, please check your parameters and try again."
             try:
                 response = response.json()
                 try:
@@ -114,21 +115,22 @@ class ApiTool(Tool):
             except Exception as e:
                 return response.text
         else:
-            raise ValueError(f'Invalid response type {type(response)}')
+            raise ValueError(f"Invalid response type {type(response)}")
 
     @staticmethod
     def get_parameter_value(parameter, parameters):
-        if parameter['name'] in parameters:
-            return parameters[parameter['name']]
-        elif parameter.get('required', False):
+        if parameter["name"] in parameters:
+            return parameters[parameter["name"]]
+        elif parameter.get("required", False):
             raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
         else:
-            return (parameter.get('schema', {}) or {}).get('default', '')
+            return (parameter.get("schema", {}) or {}).get("default", "")
 
-    def do_http_request(self, url: str, method: str, headers: dict[str, Any],
-                        parameters: dict[str, Any]) -> httpx.Response:
+    def do_http_request(
+        self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
+    ) -> httpx.Response:
         """
-            do http request depending on api bundle
+        do http request depending on api bundle
         """
         method = method.lower()
 
@@ -138,29 +140,30 @@ class ApiTool(Tool):
         cookies = {}
 
         # check parameters
-        for parameter in self.api_bundle.openapi.get('parameters', []):
+        for parameter in self.api_bundle.openapi.get("parameters", []):
             value = self.get_parameter_value(parameter, parameters)
-            if parameter['in'] == 'path':
-                path_params[parameter['name']] = value
+            if parameter["in"] == "path":
+                path_params[parameter["name"]] = value
 
-            elif parameter['in'] == 'query':
-                if value !='': params[parameter['name']] = value
+            elif parameter["in"] == "query":
+                if value != "":
+                    params[parameter["name"]] = value
 
-            elif parameter['in'] == 'cookie':
-                cookies[parameter['name']] = value
+            elif parameter["in"] == "cookie":
+                cookies[parameter["name"]] = value
 
-            elif parameter['in'] == 'header':
-                headers[parameter['name']] = value
+            elif parameter["in"] == "header":
+                headers[parameter["name"]] = value
 
         # check if there is a request body and handle it
-        if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None:
+        if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
             # handle json request body
-            if 'content' in self.api_bundle.openapi['requestBody']:
-                for content_type in self.api_bundle.openapi['requestBody']['content']:
-                    headers['Content-Type'] = content_type
-                    body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
-                    required = body_schema.get('required', [])
-                    properties = body_schema.get('properties', {})
+            if "content" in self.api_bundle.openapi["requestBody"]:
+                for content_type in self.api_bundle.openapi["requestBody"]["content"]:
+                    headers["Content-Type"] = content_type
+                    body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
+                    required = body_schema.get("required", [])
+                    properties = body_schema.get("properties", {})
                     for name, property in properties.items():
                         if name in parameters:
                             # convert type
@@ -169,63 +172,71 @@ class ApiTool(Tool):
                             raise ToolParameterValidationError(
                                 f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                             )
-                        elif 'default' in property:
-                            body[name] = property['default']
+                        elif "default" in property:
+                            body[name] = property["default"]
                         else:
                             body[name] = None
                     break
 
         # replace path parameters
         for name, value in path_params.items():
-            url = url.replace(f'{{{name}}}', f'{value}')
+            url = url.replace(f"{{{name}}}", f"{value}")
 
         # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
-        if 'Content-Type' in headers:
-            if headers['Content-Type'] == 'application/json':
+        if "Content-Type" in headers:
+            if headers["Content-Type"] == "application/json":
                 body = json.dumps(body)
-            elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
+            elif headers["Content-Type"] == "application/x-www-form-urlencoded":
                 body = urlencode(body)
             else:
                 body = body
 
-        if method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
-            response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body,
-                                                   timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
+        if method in ("get", "head", "post", "put", "delete", "patch"):
+            response = getattr(ssrf_proxy, method)(
+                url,
+                params=params,
+                headers=headers,
+                cookies=cookies,
+                data=body,
+                timeout=API_TOOL_DEFAULT_TIMEOUT,
+                follow_redirects=True,
+            )
             return response
         else:
-            raise ValueError(f'Invalid http method {self.method}')
+            raise ValueError(f"Invalid http method {self.method}")
 
-    def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]],
-                                      max_recursive=10) -> Any:
+    def _convert_body_property_any_of(
+        self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
+    ) -> Any:
         if max_recursive <= 0:
             raise Exception("Max recursion depth reached")
         for option in any_of or []:
             try:
-                if 'type' in option:
+                if "type" in option:
                     # Attempt to convert the value based on the type.
-                    if option['type'] == 'integer' or option['type'] == 'int':
+                    if option["type"] == "integer" or option["type"] == "int":
                         return int(value)
-                    elif option['type'] == 'number':
-                        if '.' in str(value):
+                    elif option["type"] == "number":
+                        if "." in str(value):
                             return float(value)
                         else:
                             return int(value)
-                    elif option['type'] == 'string':
+                    elif option["type"] == "string":
                         return str(value)
-                    elif option['type'] == 'boolean':
-                        if str(value).lower() in ['true', '1']:
+                    elif option["type"] == "boolean":
+                        if str(value).lower() in ["true", "1"]:
                             return True
-                        elif str(value).lower() in ['false', '0']:
+                        elif str(value).lower() in ["false", "0"]:
                             return False
                         else:
                             continue  # Not a boolean, try next option
-                    elif option['type'] == 'null' and not value:
+                    elif option["type"] == "null" and not value:
                         return None
                     else:
                         continue  # Unsupported type, try next option
-                elif 'anyOf' in option and isinstance(option['anyOf'], list):
+                elif "anyOf" in option and isinstance(option["anyOf"], list):
                     # Recursive call to handle nested anyOf
-                    return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1)
+                    return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
             except ValueError:
                 continue  # Conversion failed, try next option
         # If no option succeeded, you might want to return the value as is or raise an error
@@ -233,23 +244,23 @@ class ApiTool(Tool):
 
     def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
         try:
-            if 'type' in property:
-                if property['type'] == 'integer' or property['type'] == 'int':
+            if "type" in property:
+                if property["type"] == "integer" or property["type"] == "int":
                     return int(value)
-                elif property['type'] == 'number':
+                elif property["type"] == "number":
                     # check if it is a float
-                    if '.' in str(value):
+                    if "." in str(value):
                         return float(value)
                     else:
                         return int(value)
-                elif property['type'] == 'string':
+                elif property["type"] == "string":
                     return str(value)
-                elif property['type'] == 'boolean':
+                elif property["type"] == "boolean":
                     return bool(value)
-                elif property['type'] == 'null':
+                elif property["type"] == "null":
                     if value is None:
                         return None
-                elif property['type'] == 'object' or property['type'] == 'array':
+                elif property["type"] == "object" or property["type"] == "array":
                     if isinstance(value, str):
                         try:
                             # an array str like '[1,2]' also can convert to list [1,2] through json.loads
@@ -264,8 +275,8 @@ class ApiTool(Tool):
                     return value
                 else:
                     raise ValueError(f"Invalid type {property['type']} for property {property}")
-            elif 'anyOf' in property and isinstance(property['anyOf'], list):
-                return self._convert_body_property_any_of(property, value, property['anyOf'])
+            elif "anyOf" in property and isinstance(property["anyOf"], list):
+                return self._convert_body_property_any_of(property, value, property["anyOf"])
         except ValueError as e:
             return value
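_convert_body_property_any_of above tries each anyOf branch of an OpenAPI schema in turn, coercing the raw value until one branch accepts it, and recurses into nested anyOf entries under a depth cap. The core idea, reduced to plain types as a standalone sketch:

def coerce_any_of(value, any_of: list[dict], depth: int = 10):
    if depth <= 0:
        raise RecursionError("max anyOf depth reached")
    for option in any_of:
        try:
            t = option.get("type")
            if t in ("integer", "int"):
                return int(value)
            elif t == "number":
                return float(value) if "." in str(value) else int(value)
            elif t == "string":
                return str(value)
            elif t == "null" and not value:
                return None
            elif "anyOf" in option:
                # Nested anyOf: recurse with a reduced depth budget.
                return coerce_any_of(value, option["anyOf"], depth - 1)
        except ValueError:
            continue  # this branch cannot represent the value; try the next one
    return value  # fall back to the raw value when no branch matches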
diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py
index ad7a88838b..8edaf7c0e6 100644
--- a/api/core/tools/tool/builtin_tool.py
+++ b/api/core/tools/tool/builtin_tool.py
@@ -1,4 +1,3 @@
-
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 from core.tools.entities.tool_entities import ToolProviderType
@@ -16,40 +15,38 @@ Please summarize the text you got.
 
 class BuiltinTool(Tool):
     """
-        Builtin tool
+    Builtin tool
 
-        :param meta: the meta data of a tool call processing
+    :param meta: the meta data of a tool call processing
     """
 
-    def invoke_model(
-            self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]
-    ) -> LLMResult:
+    def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
         """
-            invoke model
+        invoke model
 
-            :param model_config: the model config
-            :param prompt_messages: the prompt messages
-            :param stop: the stop words
-            :return: the model result
+        :param model_config: the model config
+        :param prompt_messages: the prompt messages
+        :param stop: the stop words
+        :return: the model result
         """
         # invoke model
         return ModelInvocationUtils.invoke(
             user_id=user_id,
             tenant_id=self.runtime.tenant_id,
-            tool_type='builtin',
+            tool_type="builtin",
             tool_name=self.identity.name,
             prompt_messages=prompt_messages,
         )
-
+
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.BUILT_IN
-
+
     def get_max_tokens(self) -> int:
         """
-            get max tokens
+        get max tokens
 
-            :param model_config: the model config
-            :return: the max tokens
+        :param model_config: the model config
+        :return: the max tokens
         """
         return ModelInvocationUtils.get_max_llm_context_tokens(
             tenant_id=self.runtime.tenant_id,
@@ -57,39 +54,34 @@ class BuiltinTool(Tool):
 
     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
         """
-            get prompt tokens
+        get prompt tokens
 
-            :param prompt_messages: the prompt messages
-            :return: the tokens
+        :param prompt_messages: the prompt messages
+        :return: the tokens
         """
-        return ModelInvocationUtils.calculate_tokens(
-            tenant_id=self.runtime.tenant_id,
-            prompt_messages=prompt_messages
-        )
+        return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
 
     def summary(self, user_id: str, content: str) -> str:
         max_tokens = self.get_max_tokens()
 
-        if self.get_prompt_tokens(prompt_messages=[
-            UserPromptMessage(content=content)
-        ]) < max_tokens * 0.6:
+        if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6:
             return content
-
+
         def get_prompt_tokens(content: str) -> int:
-            return self.get_prompt_tokens(prompt_messages=[
-                SystemPromptMessage(content=_SUMMARY_PROMPT),
-                UserPromptMessage(content=content)
-            ])
-
+            return self.get_prompt_tokens(
+                prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)]
+            )
+
         def summarize(content: str) -> str:
-            summary = self.invoke_model(user_id=user_id, prompt_messages=[
-                SystemPromptMessage(content=_SUMMARY_PROMPT),
-                UserPromptMessage(content=content)
-            ], stop=[])
+            summary = self.invoke_model(
+                user_id=user_id,
+                prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)],
+                stop=[],
+            )
 
             return summary.message.content
 
-        lines = content.split('\n')
+        lines = content.split("\n")
         new_lines = []
         # split long line into multiple lines
         for i in range(len(lines)):
@@ -100,8 +92,8 @@ class BuiltinTool(Tool):
                 new_lines.append(line)
             elif get_prompt_tokens(line) > max_tokens * 0.7:
                 while get_prompt_tokens(line) > max_tokens * 0.7:
-                    new_lines.append(line[:int(max_tokens * 0.5)])
-                    line = line[int(max_tokens * 0.5):]
+                    new_lines.append(line[: int(max_tokens * 0.5)])
+                    line = line[int(max_tokens * 0.5) :]
                 new_lines.append(line)
             else:
                 new_lines.append(line)
@@ -125,17 +117,15 @@ class BuiltinTool(Tool):
                 summary = summarize(message)
                 summaries.append(summary)
 
-        result = '\n'.join(summaries)
+        result = "\n".join(summaries)
 
-        if self.get_prompt_tokens(prompt_messages=[
-            UserPromptMessage(content=result)
-        ]) > max_tokens * 0.7:
+        if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7:
             return self.summary(user_id=user_id, content=result)
-
+
         return result
-
+
     def get_url(self, url: str, user_agent: str = None) -> str:
         """
-            get url
+        get url
         """
-        return get_url(url, user_agent=user_agent)
\ No newline at end of file
+        return get_url(url, user_agent=user_agent)
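The summary method above is a map-reduce summarizer: split the text into chunks that fit a token budget, summarize each chunk, join the results, and recurse while the joined text is still too large. A compact sketch of that control flow, with a caller-supplied summarize function and a crude len()-based budget standing in for real token counting (both simplifications of what the diff does):

from typing import Callable


def shrink(text: str, summarize: Callable[[str], str], budget: int = 4000) -> str:
    if len(text) < budget * 0.6:
        return text  # already fits comfortably, nothing to do
    half = budget // 2
    chunks = [text[i : i + half] for i in range(0, len(text), half)]
    merged = "\n".join(summarize(chunk) for chunk in chunks)
    # Recurse until the merged summary fits; assumes summarize() shortens its input.
    return shrink(merged, summarize, budget) if len(merged) > budget * 0.7 else merged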
diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
index 7cb7c033bb..e76af6fe70 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
@@ -8,20 +8,17 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.rerank.rerank_model import RerankModelRunner
-from core.rag.retrieval.retrival_methods import RetrievalMethod
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
 default_retrieval_model = {
-    'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
-    'reranking_enable': False,
-    'reranking_model': {
-        'reranking_provider_name': '',
-        'reranking_model_name': ''
-    },
-    'top_k': 2,
-    'score_threshold_enabled': False
+    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+    "reranking_enable": False,
+    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+    "top_k": 2,
+    "score_threshold_enabled": False,
 }
 
 
@@ -31,6 +28,7 @@ class DatasetMultiRetrieverToolInput(BaseModel):
 
 class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying multi dataset."""
+
     name: str = "dataset_"
     args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
     description: str = "dataset multi retriever and rerank. "
@@ -38,27 +36,26 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
     reranking_provider_name: str
     reranking_model_name: str
 
-
     @classmethod
     def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
         return cls(
-            name=f"dataset_{tenant_id.replace('-', '_')}",
-            tenant_id=tenant_id,
-            dataset_ids=dataset_ids,
-            **kwargs
+            name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
         )
 
     def _run(self, query: str) -> str:
         threads = []
         all_documents = []
         for dataset_id in self.dataset_ids:
-            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': dataset_id,
-                'query': query,
-                'all_documents': all_documents,
-                'hit_callbacks': self.hit_callbacks
-            })
+            retrieval_thread = threading.Thread(
+                target=self._retriever,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),
+                    "dataset_id": dataset_id,
+                    "query": query,
+                    "all_documents": all_documents,
+                    "hit_callbacks": self.hit_callbacks,
+                },
+            )
             threads.append(retrieval_thread)
             retrieval_thread.start()
         for thread in threads:
@@ -69,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             tenant_id=self.tenant_id,
             provider=self.reranking_provider_name,
             model_type=ModelType.RERANK,
-            model=self.reranking_model_name
+            model=self.reranking_model_name,
         )
 
         rerank_runner = RerankModelRunner(rerank_model_instance)
@@ -80,62 +77,61 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
 
             document_score_list = {}
             for item in all_documents:
-                if item.metadata.get('score'):
-                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
+                if item.metadata.get("score"):
+                    document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
 
             document_context_list = []
-            index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+            index_node_ids = [document.metadata["doc_id"] for document in all_documents]
             segments = DocumentSegment.query.filter(
                 DocumentSegment.dataset_id.in_(self.dataset_ids),
                 DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.status == 'completed',
+                DocumentSegment.status == "completed",
                 DocumentSegment.enabled == True,
-                DocumentSegment.index_node_id.in_(index_node_ids)
+                DocumentSegment.index_node_id.in_(index_node_ids),
             ).all()
 
             if segments:
                 index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
-                sorted_segments = sorted(segments,
-                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
-                                                                                           float('inf')))
+                sorted_segments = sorted(
+                    segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
+                )
                 for segment in sorted_segments:
                     if segment.answer:
-                        document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
+                        document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
                     else:
                         document_context_list.append(segment.get_sign_content())
                 if self.return_resource:
                     context_list = []
                     resource_number = 1
                     for segment in sorted_segments:
-                        dataset = Dataset.query.filter_by(
-                            id=segment.dataset_id
+                        dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
+                        document = Document.query.filter(
+                            Document.id == segment.document_id,
+                            Document.enabled == True,
+                            Document.archived == False,
                         ).first()
-                        document = Document.query.filter(Document.id == segment.document_id,
-                                                         Document.enabled == True,
-                                                         Document.archived == False,
-                                                         ).first()
                         if dataset and document:
                             source = {
-                                'position': resource_number,
-                                'dataset_id': dataset.id,
-                                'dataset_name': dataset.name,
-                                'document_id': document.id,
-                                'document_name': document.name,
-                                'data_source_type': document.data_source_type,
-                                'segment_id': segment.id,
-                                'retriever_from': self.retriever_from,
-                                'score': document_score_list.get(segment.index_node_id, None)
+                                "position": resource_number,
+                                "dataset_id": dataset.id,
+                                "dataset_name": dataset.name,
+                                "document_id": document.id,
+                                "document_name": document.name,
+                                "data_source_type": document.data_source_type,
+                                "segment_id": segment.id,
+                                "retriever_from": self.retriever_from,
+                                "score": document_score_list.get(segment.index_node_id, None),
                             }
 
-                            if self.retriever_from == 'dev':
-                                source['hit_count'] = segment.hit_count
-                                source['word_count'] = segment.word_count
-                                source['segment_position'] = segment.position
-                                source['index_node_hash'] = segment.index_node_hash
+                            if self.retriever_from == "dev":
+                                source["hit_count"] = segment.hit_count
+                                source["word_count"] = segment.word_count
+                                source["segment_position"] = segment.position
+                                source["index_node_hash"] = segment.index_node_hash
                             if segment.answer:
-                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                                source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
                             else:
-                                source['content'] = segment.content
+                                source["content"] = segment.content
                             context_list.append(source)
                         resource_number += 1
 
@@ -144,13 +140,18 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
 
             return str("\n".join(document_context_list))
 
-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
-                   hit_callbacks: list[DatasetIndexToolCallbackHandler]):
+    def _retriever(
+        self,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        all_documents: list,
+        hit_callbacks: list[DatasetIndexToolCallbackHandler],
+    ):
         with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == self.tenant_id,
-                Dataset.id == dataset_id
-            ).first()
+            dataset = (
+                db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
+            )
 
             if not dataset:
                 return []
@@ -163,27 +164,29 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
 
             if dataset.indexing_technique == "economy":
                 # use keyword table query
-                documents = RetrievalService.retrieve(retrival_method='keyword_search',
-                                                      dataset_id=dataset.id,
-                                                      query=query,
-                                                      top_k=self.top_k
-                                                      )
+                documents = RetrievalService.retrieve(
+                    retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
+                )
                 if documents:
                     all_documents.extend(documents)
             else:
                 if self.top_k > 0:
                     # retrieval source
-                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
-                                                          dataset_id=dataset.id,
-                                                          query=query,
-                                                          top_k=self.top_k,
-                                                          score_threshold=retrieval_model.get('score_threshold', .0)
-                                                          if retrieval_model['score_threshold_enabled'] else None,
-                                                          reranking_model=retrieval_model.get('reranking_model', None)
-                                                          if retrieval_model['reranking_enable'] else None,
-                                                          reranking_mode=retrieval_model.get('reranking_mode')
-                                                          if retrieval_model.get('reranking_mode') else 'reranking_model',
-                                                          weights=retrieval_model.get('weights', None),
-                                                          )
+                    documents = RetrievalService.retrieve(
+                        retrieval_method=retrieval_model["search_method"],
+                        dataset_id=dataset.id,
+                        query=query,
+                        top_k=self.top_k,
+                        score_threshold=retrieval_model.get("score_threshold", 0.0)
+                        if retrieval_model["score_threshold_enabled"]
+                        else None,
+                        reranking_model=retrieval_model.get("reranking_model", None)
+                        if retrieval_model["reranking_enable"]
+                        else None,
+                        reranking_mode=retrieval_model.get("reranking_mode")
+                        if retrieval_model.get("reranking_mode")
+                        else "reranking_model",
+                        weights=retrieval_model.get("weights", None),
+                    )
 
-                    all_documents.extend(documents)
\ No newline at end of file
+                    all_documents.extend(documents)
diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
index 62e97a0230..dad8c77357 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
@@ -9,6 +9,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
 
 class DatasetRetrieverBaseTool(BaseModel, ABC):
     """Tool for querying a Dataset."""
+
     name: str = "dataset"
     description: str = "use this to retrieve a dataset. "
     tenant_id: str
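Both retriever tools fall back to the same default_retrieval_model dict and then translate it into RetrievalService.retrieve() keyword arguments, passing score_threshold and reranking_model only when their enable flags are set. That translation, extracted into a small helper (a hypothetical refactoring for illustration, not part of the diff):

def retrieval_kwargs(config: dict) -> dict:
    # Gate each optional knob on its corresponding *_enabled / *_enable flag.
    return {
        "retrieval_method": config.get("search_method", "semantic_search"),
        "score_threshold": config.get("score_threshold", 0.0) if config["score_threshold_enabled"] else None,
        "reranking_model": config.get("reranking_model") if config["reranking_enable"] else None,
        "reranking_mode": config.get("reranking_mode") or "reranking_model",
        "weights": config.get("weights"),
    }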
" dataset_id: str - @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") return cls( name=f"dataset_{dataset.id.replace('-', '_')}", tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, - **kwargs + **kwargs, ) def _run(self, query: str) -> str: - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == self.dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + ) if not dataset: - return '' + return "" for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) @@ -63,27 +58,29 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrival_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) return str("\n".join([document.page_content for document in documents])) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) else: documents = [] @@ -92,25 +89,26 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in documents] - segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - 
DocumentSegment.index_node_id.in_(index_node_ids) - ).all() + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: @@ -118,36 +116,36 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): resource_number = 1 for segment in sorted_segments: context = {} - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) - + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) \ No newline at end of file + return str("\n".join(document_context_list)) diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 1170e1b7a5..3c9295c493 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -17,16 +17,17 @@ from core.tools.tool.tool import Tool class 
DatasetRetrieverTool(Tool): - retrival_tool: DatasetRetrieverBaseTool + retrieval_tool: DatasetRetrieverBaseTool @staticmethod - def get_dataset_tools(tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler - ) -> list['DatasetRetrieverTool']: + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: """ get dataset tool """ @@ -42,29 +43,29 @@ class DatasetRetrieverTool(Tool): # Agent only support SINGLE mode original_retriever_mode = retrieve_config.retrieve_strategy retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE - retrival_tools = feature.to_dataset_retriever_tool( + retrieval_tools = feature.to_dataset_retriever_tool( tenant_id=tenant_id, dataset_ids=dataset_ids, retrieve_config=retrieve_config, return_resource=return_resource, invoke_from=invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode - # convert retrival tools to Tools + # convert retrieval tools to Tools tools = [] - for retrival_tool in retrival_tools: + for retrieval_tool in retrieval_tools: tool = DatasetRetrieverTool( - retrival_tool=retrival_tool, - identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')), + retrieval_tool=retrieval_tool, + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), parameters=[], is_team_authorization=True, - description=ToolDescription( - human=I18nObject(en_US='', zh_Hans=''), - llm=retrival_tool.description), - runtime=DatasetRetrieverTool.Runtime() + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + runtime=DatasetRetrieverTool.Runtime(), ) tools.append(tool) @@ -73,16 +74,18 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters(self) -> list[ToolParameter]: return [ - ToolParameter(name='query', - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Query for the dataset to be used to retrieve the dataset.', - required=True, - default=''), + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + ), ] - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL @@ -90,12 +93,12 @@ class DatasetRetrieverTool(Tool): """ invoke dataset retriever tool """ - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - return self.create_text_message(text='please input query') + return self.create_text_message(text="please input query") # invoke dataset retriever tool - result = self.retrival_tool._run(query=query) + result = self.retrieval_tool._run(query=query) return self.create_text_message(text=result) diff --git a/api/core/tools/tool/tool.py 
b/api/core/tools/tool/tool.py index d990131b5f..d9e9a0faad 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -35,15 +35,16 @@ class Tool(BaseModel, ABC): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - @field_validator('parameters', mode='before') + @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: return v or [] class Runtime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing """ + def __init__(self, **data: Any): super().__init__(**data) if not self.runtime_parameters: @@ -62,15 +63,15 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VARIABLE_KEY(Enum): - IMAGE = 'image' + class VariableKey(Enum): + IMAGE = "image" - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, @@ -82,22 +83,22 @@ class Tool(BaseModel, ABC): @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ def load_variables(self, variables: ToolRuntimeVariablePool): """ - load variables from database + load variables from database - :param conversation_id: the conversation id + :param conversation_id: the conversation id """ self.variables = variables def set_image_variable(self, variable_name: str, image_key: str) -> None: """ - set an image variable + set an image variable """ if not self.variables: return @@ -106,7 +107,7 @@ class Tool(BaseModel, ABC): def set_text_variable(self, variable_name: str, text: str) -> None: """ - set a text variable + set a text variable """ if not self.variables: return @@ -115,10 +116,10 @@ class Tool(BaseModel, ABC): def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ - get a variable + get a variable - :param name: the name of the variable - :return: the variable + :param name: the name of the variable + :return: the variable """ if not self.variables: return None @@ -134,21 +135,21 @@ class Tool(BaseModel, ABC): def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: """ - get the default image variable + get the default image variable - :return: the image variable + :return: the image variable """ if not self.variables: return None - return self.get_variable(self.VARIABLE_KEY.IMAGE) + return self.get_variable(self.VariableKey.IMAGE) def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ - get a variable file + get a variable file - :param name: the name of the variable - :return: the variable file + :param name: the name of the variable + :return: the variable file """ variable = self.get_variable(name) if not variable: @@ -167,9 +168,9 @@ class Tool(BaseModel, ABC): def list_variables(self) -> list[ToolRuntimeVariable]: """ - list all variables + list all variables - :return: the variables + :return: the variables """ if not self.variables: return [] @@ -178,9 +179,9 @@ class Tool(BaseModel, ABC): def 
list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ - list all image variables + list all image variables - :return: the image variables + :return: the image variables """ if not self.variables: return [] @@ -188,7 +189,7 @@ class Tool(BaseModel, ABC): result = [] for variable in self.variables.pool: - if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): + if variable.name.startswith(self.VariableKey.IMAGE.value): result.append(variable) return result @@ -220,38 +221,42 @@ class Tool(BaseModel, ABC): result = deepcopy(tool_parameters) for parameter in self.parameters or []: if parameter.name in tool_parameters: - result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type) + result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( + tool_parameters[parameter.name], parameter.type + ) return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ - validate the credentials + validate the credentials - :param credentials: the credentials - :param parameters: the parameters + :param credentials: the credentials + :param parameters: the parameters """ pass def get_runtime_parameters(self) -> list[ToolParameter]: """ - get the runtime parameters + get the runtime parameters - interface for developer to dynamic change the parameters of a tool depends on the variables pool + interface for developer to dynamic change the parameters of a tool depends on the variables pool - :return: the runtime parameters + :return: the runtime parameters """ return self.parameters or [] def get_all_runtime_parameters(self) -> list[ToolParameter]: """ - get all runtime parameters + get all runtime parameters - :return: all runtime parameters + :return: all runtime parameters """ parameters = self.parameters or [] parameters = parameters.copy() @@ -281,67 +286,49 @@ class Tool(BaseModel, ABC): return parameters - def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: """ - create an image message + create an image message - :param image: the url of the image - :return: the image message + :param image: the url of the image + :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, - save_as=save_as) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, - message='', - meta={ - 'file_var': file_var - }, - save_as='') - - def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: - """ - create a link message - - :param link: the url of the link - :return: the link message - """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, - save_as=save_as) - - def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: - """ - create a text message - - :param text: the text - :return: the text message - """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, 
- message=text, - save_as=save_as + type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as="" ) - def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: """ - create a blob message + create a link message - :param blob: the blob - :return: the blob message + :param link: the url of the link + :return: the link message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=blob, meta=meta, - save_as=save_as - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as) + + def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: + """ + create a text message + + :param text: the text + :return: the text message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as) + + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = "") -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :return: the blob message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) def create_json_message(self, object: dict) -> ToolInvokeMessage: """ - create a json message + create a json message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, - message=object - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 12e498e76d..ad0c7fc631 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,7 +1,7 @@ import json import logging from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union from core.file.file_obj import FileTransferMethod, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType @@ -13,22 +13,25 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) + class WorkflowTool(Tool): workflow_app_id: str version: str workflow_entities: dict[str, Any] workflow_call_depth: int + thread_pool_id: Optional[str] = None label: str """ Workflow tool. 
""" + def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ return ToolProviderType.WORKFLOW @@ -36,7 +39,7 @@ class WorkflowTool(Tool): self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke the tool + invoke the tool """ app = self._get_app(app_id=self.workflow_app_id) workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) @@ -45,32 +48,31 @@ class WorkflowTool(Tool): tool_parameters, files = self._transform_args(tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() result = generator.generate( - app_model=app, - workflow=workflow, - user=self._get_user(user_id), - args={ - 'inputs': tool_parameters, - 'files': files - }, + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, + workflow_thread_pool_id=self.thread_pool_id, ) - data = result.get('data', {}) + data = result.get("data", {}) + + if data.get("error"): + raise Exception(data.get("error")) - if data.get('error'): - raise Exception(data.get('error')) - result = [] - outputs = data.get('outputs', {}) + outputs = data.get("outputs", {}) outputs, files = self._extract_files(outputs) for file in files: result.append(self.create_file_var_message(file)) - + result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) result.append(self.create_json_message(outputs)) @@ -78,7 +80,7 @@ class WorkflowTool(Tool): def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ - get the user by user id + get the user by user id """ user = db.session.query(EndUser).filter(EndUser.id == user_id).first() @@ -86,16 +88,16 @@ class WorkflowTool(Tool): user = db.session.query(Account).filter(Account.id == user_id).first() if not user: - raise ValueError('user not found') + raise ValueError("user not found") return user - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=deepcopy(self.identity), @@ -106,45 +108,44 @@ class WorkflowTool(Tool): workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, version=self.version, - label=self.label + label=self.label, ) - + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ - get the workflow by app id and version + get the workflow by app id and version """ if not version: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version != 'draft' - ).order_by(Workflow.created_at.desc()).first() + workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .order_by(Workflow.created_at.desc()) + .first() + ) else: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version == version - ).first() + workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, 
Workflow.version == version).first()

         if not workflow:
-            raise ValueError('workflow not found or not published')
+            raise ValueError("workflow not found or not published")

         return workflow
-
+
     def _get_app(self, app_id: str) -> App:
         """
-            get the app by app id
+        get the app by app id
         """
         app = db.session.query(App).filter(App.id == app_id).first()
         if not app:
-            raise ValueError('app not found')
+            raise ValueError("app not found")

         return app
-
+
     def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
         """
-            transform the tool parameters
+        transform the tool parameters

-            :param tool_parameters: the tool parameters
-            :return: tool_parameters, files
+        :param tool_parameters: the tool parameters
+        :return: tool_parameters, files
         """
         parameter_rules = self.get_all_runtime_parameters()
         parameters_result = {}
@@ -157,15 +158,15 @@
                         file_var_list = [FileVar(**f) for f in file]
                         for file_var in file_var_list:
                             file_dict = {
-                                'transfer_method': file_var.transfer_method.value,
-                                'type': file_var.type.value,
+                                "transfer_method": file_var.transfer_method.value,
+                                "type": file_var.type.value,
                             }
                             if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
-                                file_dict['tool_file_id'] = file_var.related_id
+                                file_dict["tool_file_id"] = file_var.related_id
                             elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
-                                file_dict['upload_file_id'] = file_var.related_id
+                                file_dict["upload_file_id"] = file_var.related_id
                             elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
-                                file_dict['url'] = file_var.preview_url
+                                file_dict["url"] = file_var.preview_url

                             files.append(file_dict)
                 except Exception as e:
@@ -174,13 +175,13 @@
             parameters_result[parameter.name] = tool_parameters.get(parameter.name)

         return parameters_result, files
-
+
     def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
         """
-            extract files from the result
+        extract files from the result

-            :param result: the result
-            :return: the result, files
+        :param result: the result
+        :return: the result, files
         """
         files = []
         result = {}
@@ -188,7 +189,7 @@
             if isinstance(value, list):
                 has_file = False
                 for item in value:
-                    if isinstance(item, dict) and item.get('__variant') == 'FileVar':
+                    if isinstance(item, dict) and item.get("__variant") == "FileVar":
                         try:
                             files.append(FileVar(**item))
                             has_file = True
@@ -199,4 +200,4 @@
             result[key] = value

-        return result, files
\ No newline at end of file
+        return result, files
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 0e15151aa4..9a6a49d8f4 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -33,12 +33,17 @@ class ToolEngine:
     """
     Tool runtime engine take care of the tool executions.
     """
+
     @staticmethod
     def agent_invoke(
-        tool: Tool, tool_parameters: Union[str, dict],
-        user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
+        tool: Tool,
+        tool_parameters: Union[str, dict],
+        user_id: str,
+        tenant_id: str,
+        message: Message,
+        invoke_from: InvokeFrom,
         agent_tool_callback: DifyAgentCallbackHandler,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
         """
        Agent invokes the tool with the given arguments.
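In the hunk below, `agent_invoke` accepts `tool_parameters` either as a dict or as a bare string from the LLM: when a string arrives and the tool exposes exactly one LLM-form parameter, the string is wrapped into a dict keyed by that parameter's name; otherwise a `ValueError` is raised. A minimal sketch of that dispatch, outside the patch — the helper name `wrap_llm_parameters` and its arguments are illustrative, not identifiers from this codebase:

    from typing import Union

    # Hypothetical stand-alone version of the wrapping logic reformatted below.
    def wrap_llm_parameters(raw: Union[str, dict], llm_param_names: list[str]) -> dict:
        if isinstance(raw, dict):
            return raw  # already structured, nothing to do
        if len(llm_param_names) == 1:
            # a bare string is taken as the value of the single LLM-form parameter
            return {llm_param_names[0]: raw}
        raise ValueError(f"tool_parameters should be a dict, but got a string: {raw}")

For example, `wrap_llm_parameters("hello", ["query"])` yields `{"query": "hello"}`.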
@@ -47,40 +52,30 @@ class ToolEngine:
         if isinstance(tool_parameters, str):
             # check if this tool has only one parameter
             parameters = [
-                parameter for parameter in tool.get_runtime_parameters() or []
+                parameter
+                for parameter in tool.get_runtime_parameters() or []
                 if parameter.form == ToolParameter.ToolParameterForm.LLM
             ]
             if parameters and len(parameters) == 1:
-                tool_parameters = {
-                    parameters[0].name: tool_parameters
-                }
+                tool_parameters = {parameters[0].name: tool_parameters}
             else:
                 raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")

         # invoke the tool
         try:
             # hit the callback handler
-            agent_tool_callback.on_tool_start(
-                tool_name=tool.identity.name,
-                tool_inputs=tool_parameters
-            )
+            agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)

             meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
             response = ToolFileMessageTransformer.transform_tool_invoke_messages(
-                messages=response,
-                user_id=user_id,
-                tenant_id=tenant_id,
-                conversation_id=message.conversation_id
+                messages=response, user_id=user_id, tenant_id=tenant_id, conversation_id=message.conversation_id
             )

             # extract binary data from tool invoke message
             binary_files = ToolEngine._extract_tool_response_binary(response)
             # create message file
             message_files = ToolEngine._create_message_files(
-                tool_messages=binary_files,
-                agent_message=message,
-                invoke_from=invoke_from,
-                user_id=user_id
+                tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
             )

             plain_text = ToolEngine._convert_tool_response_to_str(response)
@@ -91,7 +86,7 @@ class ToolEngine:
                 tool_inputs=tool_parameters,
                 tool_outputs=plain_text,
                 message_id=message.id,
-                trace_manager=trace_manager
+                trace_manager=trace_manager,
             )

             # transform tool invoke message to get LLM friendly message
@@ -99,14 +94,10 @@ class ToolEngine:
         except ToolProviderCredentialValidationError as e:
             error_response = "Please check your tool provider credentials"
             agent_tool_callback.on_tool_error(e)
-        except (
-            ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
-        ) as e:
+        except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
             error_response = f"there is not a tool named {tool.identity.name}"
             agent_tool_callback.on_tool_error(e)
-        except (
-            ToolParameterValidationError
-        ) as e:
+        except ToolParameterValidationError as e:
             error_response = f"tool parameters validation error: {e}, please check your tool parameters"
             agent_tool_callback.on_tool_error(e)
         except ToolInvokeError as e:
@@ -124,23 +115,24 @@ class ToolEngine:
         return error_response, [], ToolInvokeMeta.error_instance(error_response)

     @staticmethod
-    def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
-                        user_id: str,
-                        workflow_tool_callback: DifyWorkflowCallbackHandler,
-                        workflow_call_depth: int,
-                        ) -> list[ToolInvokeMessage]:
+    def workflow_invoke(
+        tool: Tool,
+        tool_parameters: Mapping[str, Any],
+        user_id: str,
+        workflow_tool_callback: DifyWorkflowCallbackHandler,
+        workflow_call_depth: int,
+        thread_pool_id: Optional[str] = None,
+    ) -> list[ToolInvokeMessage]:
         """
         Workflow invokes the tool with the given arguments.
""" try: # hit the callback handler - workflow_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 + tool.thread_pool_id = thread_pool_id if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} @@ -157,21 +149,24 @@ class ToolEngine: except Exception as e: workflow_tool_callback.on_tool_error(e) raise e - + @staticmethod - def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ - -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: + def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: """ Invoke the tool with the given arguments. """ started_at = datetime.now(timezone.utc) - meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ - 'tool_name': tool.identity.name, - 'tool_provider': tool.identity.provider, - 'tool_provider_type': tool.tool_provider_type().value, - 'tool_parameters': deepcopy(tool.runtime.runtime_parameters), - 'tool_icon': tool.identity.icon - }) + meta = ToolInvokeMeta( + time_cost=0.0, + error=None, + tool_config={ + "tool_name": tool.identity.name, + "tool_provider": tool.identity.provider, + "tool_provider_type": tool.tool_provider_type().value, + "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_icon": tool.identity.icon, + }, + ) try: response = tool.invoke(user_id, tool_parameters) except Exception as e: @@ -182,20 +177,22 @@ class ToolEngine: meta.time_cost = (ended_at - started_at).total_seconds() return meta, response - + @staticmethod def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: """ Handle tool response """ - result = '' + result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + elif ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." @@ -203,7 +200,7 @@ class ToolEngine: result += f"tool response: {response.message}." 
return result - + @staticmethod def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: """ @@ -212,52 +209,59 @@ class ToolEngine: result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): mimetype = None - if response.meta.get('mime_type'): - mimetype = response.meta.get('mime_type') + if response.meta.get("mime_type"): + mimetype = response.meta.get("mime_type") else: try: url = URL(response.message) extension = url.suffix - guess_type_result, _ = guess_type(f'a{extension}') + guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result except Exception: pass - + if not mimetype: - mimetype = 'image/jpeg' - - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'image/jpeg'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and 'mime_type' in response.meta: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + mimetype = "image/jpeg" + + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "image/jpeg"), url=response.message, save_as=response.save_as, - )) + ) + ) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream"), + url=response.message, + save_as=response.save_as, + ) + ) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and "mime_type" in response.meta: + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream") + if response.meta + else "octet/stream", + url=response.message, + save_as=response.save_as, + ) + ) return result - + @staticmethod def _create_message_files( - tool_messages: list[ToolInvokeMessageBinary], - agent_message: Message, - invoke_from: InvokeFrom, - user_id: str + tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str ) -> list[tuple[Any, str]]: """ Create message file @@ -268,29 +272,29 @@ class ToolEngine: result = [] for message in tool_messages: - file_type = 'bin' - if 'image' in message.mimetype: - file_type = 'image' - elif 'video' in message.mimetype: - file_type = 'video' - elif 'audio' in message.mimetype: - file_type = 'audio' - elif 'text' in message.mimetype: - file_type = 'text' - elif 'pdf' in message.mimetype: - file_type = 'pdf' - elif 'zip' in message.mimetype: - file_type = 'archive' + file_type = "bin" + if "image" in message.mimetype: + file_type = "image" + elif "video" in message.mimetype: + file_type = "video" + elif "audio" in message.mimetype: + file_type = "audio" + elif "text" in message.mimetype: + file_type = "text" + elif "pdf" in message.mimetype: + file_type = "pdf" + elif "zip" in message.mimetype: + file_type = "archive" # ... 
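# Illustrative note, not from the patch: the mimetype checks above are substring
# matches taken in order, so "application/zip" maps to "archive", "text/csv"
# maps to "text", and anything unmatched keeps the "bin" default assigned
# before the chain.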
message_file = MessageFile( message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE.value, - belongs_to='assistant', + belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), created_by=user_id, ) @@ -298,11 +302,8 @@ class ToolEngine: db.session.commit() db.session.refresh(message_file) - result.append(( - message_file.id, - message.save_as - )) + result.append((message_file.id, message.save_as)) db.session.close() - return result \ No newline at end of file + return result diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f9f7c7d78a..ad3b9c7328 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -27,24 +27,24 @@ class ToolFileManager: sign file to get a temporary url """ base_url = dify_config.FILES_URL - file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @staticmethod def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -62,9 +62,9 @@ class ToolFileManager: """ create file """ - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, file_binary) tool_file = ToolFile( @@ -90,10 +90,10 @@ class ToolFileManager: response = get(file_url) response.raise_for_status() blob = response.content - mimetype = guess_type(file_url)[0] or 'octet/stream' - extension = guess_extension(mimetype) or '.bin' + mimetype = guess_type(file_url)[0] or "octet/stream" + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, blob) tool_file = ToolFile( @@ -166,13 +166,12 @@ class ToolFileManager: # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = 
message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] else: tool_file_id = None - tool_file: ToolFile = ( db.session.query(ToolFile) .filter( @@ -216,4 +215,4 @@ class ToolFileManager: # init tool_file_parser from core.file.tool_file_parser import tool_file_manager -tool_file_manager['manager'] = ToolFileManager +tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 97788a7a07..2a5a2944ef 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -15,7 +15,7 @@ class ToolLabelManager: """ tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] return list(set(tool_labels)) - + @classmethod def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): """ @@ -26,20 +26,20 @@ class ToolLabelManager: if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id == provider_id - ).delete() + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: - db.session.add(ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - )) + db.session.add( + ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + ) + ) db.session.commit() @@ -53,12 +53,16 @@ class ToolLabelManager: elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) return [label.label_name for label in labels] @@ -75,22 +79,20 @@ class ToolLabelManager: """ if not tool_providers: return {} - + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError('Unsupported tool type') - + raise ValueError("Unsupported tool type") + provider_ids = [controller.provider_id for controller in tool_providers] - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id.in_(provider_ids) - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) - tool_labels = { - label.tool_id: [] for label in labels - } + tool_labels = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) - return tool_labels \ No newline at end of file + return tool_labels diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4a0188af49..a3303797e1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -25,7 +25,6 @@ from 
core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter -from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -42,29 +41,29 @@ class ToolManager: @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: """ - get the builtin provider + get the builtin provider - :param provider: the name of the provider - :return: the provider + :param provider: the name of the provider + :return: the provider """ if len(cls._builtin_providers) == 0: # init the builtin providers cls.load_builtin_providers_cache() if provider not in cls._builtin_providers: - raise ToolProviderNotFoundError(f'builtin provider {provider} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") return cls._builtin_providers[provider] @classmethod def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ - get the builtin tool + get the builtin tool - :param provider: the name of the provider - :param tool_name: the name of the tool + :param provider: the name of the provider + :param tool_name: the name of the tool - :return: the provider, the tool + :return: the provider, the tool """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) @@ -72,67 +71,76 @@ class ToolManager: return tool @classmethod - def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool( + cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool + get the tool - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": return cls.get_builtin_tool(provider_id, tool_name) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id) return api_provider.get_tool(tool_name) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod - def get_tool_runtime(cls, provider_type: str, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool_runtime( + cls, + provider_type: str, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = 
InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool runtime + get the tool runtime - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) # get credentials - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_id, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_id, + ) + .first() + ) if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") # decrypt the credentials credentials = builtin_provider.credentials @@ -141,17 +149,19 @@ class ToolManager: decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'runtime_parameters': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "runtime_parameters": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) @@ -159,40 +169,43 @@ class ToolManager: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'workflow': - workflow_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() - - if workflow_provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') - - controller = ToolTransformService.workflow_provider_to_controller( - db_provider=workflow_provider + return 
api_provider.get_tool(tool_name).fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "workflow": + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() ) - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + + return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: """ - init runtime parameter + init runtime parameter """ parameter_value = parameters.get(parameter_rule.name) if not parameter_value and parameter_value != 0: @@ -206,14 +219,17 @@ class ToolManager: options = [x.value for x in parameter_rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError( - f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" + ) return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_agent_tool_runtime( + cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER + ) -> Tool: """ - get the agent tool runtime + get the agent tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, @@ -221,7 +237,7 @@ class ToolManager: tool_name=agent_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT + tool_invoke_from=ToolInvokeFrom.AGENT, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -241,7 +257,7 @@ class ToolManager: tool_runtime=tool_entity, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, - identity_id=f'AGENT.{app_id}' + identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) @@ -249,9 +265,16 @@ class ToolManager: return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime( + cls, + tenant_id: str, + app_id: str, + node_id: str, + 
workflow_tool: "ToolEntity", + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: """ - get the workflow tool runtime + get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, @@ -259,7 +282,7 @@ class ToolManager: tool_name=workflow_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW + tool_invoke_from=ToolInvokeFrom.WORKFLOW, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -276,7 +299,7 @@ class ToolManager: tool_runtime=tool_entity, provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, - identity_id=f'WORKFLOW.{app_id}.{node_id}' + identity_id=f"WORKFLOW.{app_id}.{node_id}", ) if runtime_parameters: @@ -288,24 +311,30 @@ class ToolManager: @classmethod def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ - get the absolute path of the icon of the builtin provider + get the absolute path of the icon of the builtin provider - :param provider: the name of the provider + :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon + :return: the absolute path of the icon, the mime type of the icon """ # get provider provider_controller = cls.get_builtin_provider(provider) - absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', - provider_controller.identity.icon) + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "provider", + "builtin", + provider, + "_assets", + provider_controller.identity.icon, + ) # check if the icon exists if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or 'application/octet-stream' + mime_type = mime_type or "application/octet-stream" return absolute_path, mime_type @@ -326,23 +355,25 @@ class ToolManager: @classmethod def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ - list all the builtin providers + list all the builtin providers """ - for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): - if provider.startswith('__'): + for provider in listdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin")): + if provider.startswith("__"): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)): - if provider.startswith('__'): + if path.isdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin", provider)): + if provider.startswith("__"): continue # init provider try: provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), - parent_type=BuiltinToolProviderController) + module_name=f"core.tools.provider.builtin.{provider}.{provider}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "provider", "builtin", provider, f"{provider}.py" + ), + parent_type=BuiltinToolProviderController, + ) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider for tool in provider.get_tools(): @@ -350,7 
+381,7 @@ class ToolManager: yield provider except Exception as e: - logger.error(f'load builtin provider {provider} error: {e}') + logger.error(f"load builtin provider {provider} error: {e}") continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -368,11 +399,11 @@ class ToolManager: @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ - get the tool label + get the tool label - :param tool_name: the name of the tool + :param tool_name: the name of the tool - :return: the label of the tool + :return: the label of the tool """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers @@ -384,75 +415,78 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: + def user_list_providers( + cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral + ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} filters = [] if not typ: - filters.extend(['builtin', 'api', 'workflow']) + filters.extend(["builtin", "api", "workflow"]) else: filters.append(typ) - if 'builtin' in filters: - + if "builtin" in filters: # get builtin providers builtin_providers = cls.list_builtin_providers() # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + ) find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None + (x for x in db_builtin_providers if x.provider == provider), None ) # append builtin providers for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, - data=provider, - name_func=lambda x: x.identity.name + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name, ): continue user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + decrypt_credentials=False, ) result_providers[provider.identity.name] = user_provider # get db api providers - if 'api' in filters: - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). 
\ - filter(ApiToolProvider.tenant_id == tenant_id).all() + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) - api_provider_controllers = [{ - 'provider': provider, - 'controller': ToolTransformService.api_provider_to_controller(provider) - } for provider in db_api_providers] + api_provider_controllers = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] # get labels - labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) for api_provider_controller in api_provider_controllers: user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller['controller'], - db_provider=api_provider_controller['provider'], + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], decrypt_credentials=False, - labels=labels.get(api_provider_controller['controller'].provider_id, []) + labels=labels.get(api_provider_controller["controller"].provider_id, []), ) - result_providers[f'api_provider.{user_provider.name}'] = user_provider + result_providers[f"api_provider.{user_provider.name}"] = user_provider - if 'workflow' in filters: + if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \ - filter(WorkflowToolProvider.tenant_id == tenant_id).all() + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) workflow_provider_controllers = [] for provider in workflow_providers: @@ -471,32 +505,36 @@ class ToolManager: provider_controller=provider_controller, labels=labels.get(provider_controller.provider_id, []), ) - result_providers[f'workflow_provider.{user_provider.name}'] = user_provider + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod - def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: """ - get the api provider + get the api provider - :param provider_name: the name of the provider + :param provider_name: the name of the provider - :return: the provider controller, the credentials + :return: the provider controller, the credentials """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'api provider {provider_id} not found') + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") controller = ApiToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == 
"api_key" else ApiProviderAuthType.NONE, ) controller.load_bundled_tools(provider.tools) @@ -505,18 +543,22 @@ class ToolManager: @classmethod def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ - get api provider + get api provider """ """ get tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f"you have not added provider {provider}") try: credentials = json.loads(provider.credentials_str) or {} @@ -525,7 +567,7 @@ class ToolManager: # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -536,62 +578,62 @@ class ToolManager: try: icon = json.loads(provider.icon) except: - icon = { - "background": "#252525", - "content": "\ud83d\ude01" - } + icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder({ - 'schema_type': provider.schema_type, - 'schema': provider.schema, - 'tools': provider.tools, - 'icon': icon, - 'description': provider.description, - 'credentials': masked_credentials, - 'privacy_policy': provider.privacy_policy, - 'custom_disclaimer': provider.custom_disclaimer, - 'labels': labels, - }) + return jsonable_encoder( + { + "schema_type": provider.schema_type, + "schema": provider.schema, + "tools": provider.tools, + "icon": icon, + "description": provider.description, + "credentials": masked_credentials, + "privacy_policy": provider.privacy_policy, + "custom_disclaimer": provider.custom_disclaimer, + "labels": labels, + } + ) @classmethod def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: """ - get the tool icon + get the tool icon - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: """ provider_type = provider_type provider_id = provider_id - if provider_type == 'builtin': - return (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/" - + provider_id - + "/icon") - elif provider_type == 'api': + if provider_type == "builtin": + return ( + dify_config.CONSOLE_API_URL + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon" + ) + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.id == provider_id - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) return json.loads(provider.icon) except: - return { 
- "background": "#252525", - "content": "\ud83d\ude01" - } - elif provider_type == 'workflow': - provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + return {"background": "#252525", "content": "\ud83d\ude01"} + elif provider_type == "workflow": + provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") return json.loads(provider.icon) else: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index b213879e96..83600d21c1 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -56,12 +56,13 @@ class ToolConfigurationManager(BaseModel): if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field_name in credentials: if len(credentials[field_name]) > 6: - credentials[field_name] = \ - credentials[field_name][:2] + \ - '*' * (len(credentials[field_name]) - 4) + \ - credentials[field_name][-2:] + credentials[field_name] = ( + credentials[field_name][:2] + + "*" * (len(credentials[field_name]) - 4) + + credentials[field_name][-2:] + ) else: - credentials[field_name] = '*' * len(credentials[field_name]) + credentials[field_name] = "*" * len(credentials[field_name]) return credentials @@ -72,9 +73,9 @@ class ToolConfigurationManager(BaseModel): return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() if cached_credentials: @@ -95,16 +96,18 @@ class ToolConfigurationManager(BaseModel): def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() + class ToolParameterConfigurationManager(BaseModel): """ Tool parameter configuration manager """ + tenant_id: str tool_runtime: Tool provider_name: str @@ -152,15 +155,19 @@ class ToolParameterConfigurationManager(BaseModel): current_parameters = self._merge_parameters() for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: if len(parameters[parameter.name]) > 6: - parameters[parameter.name] = \ - parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) + \ - 
parameters[parameter.name][-2:] + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) else: - parameters[parameter.name] = '*' * len(parameters[parameter.name]) + parameters[parameter.name] = "*" * len(parameters[parameter.name]) return parameters @@ -176,7 +183,10 @@ class ToolParameterConfigurationManager(BaseModel): parameters = self._deep_copy(parameters) for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted @@ -191,10 +201,10 @@ class ToolParameterConfigurationManager(BaseModel): """ cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cached_parameters = cache.get() if cached_parameters: @@ -205,7 +215,10 @@ class ToolParameterConfigurationManager(BaseModel): has_secret_input = False for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: try: has_secret_input = True @@ -221,9 +234,9 @@ class ToolParameterConfigurationManager(BaseModel): def delete_tool_parameters_cache(self): cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cache.delete() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index e6b288868f..7bb026a383 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -17,8 +17,9 @@ class FeishuRequest: redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) return res.get("tenant_access_token") - def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, - params: dict = None): + def _send_request( + self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None + ): headers = { "Content-Type": "application/json", "user-agent": "Dify", @@ -42,10 +43,7 @@ class FeishuRequest: } """ url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token" - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} res = self._send_request(url, require_token=False, payload=payload) return res @@ -76,11 +74,7 @@ class FeishuRequest: def write_document(self, document_id: str, content: str, position: str = "start") -> dict: url = 
"https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" - payload = { - "document_id": document_id, - "content": content, - "position": position - } + payload = {"document_id": document_id, "content": content, "position": position} res = self._send_request(url, payload=payload) return res.get("data") diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 564b9d3e14..c4983ebc65 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -7,12 +7,12 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) + class ToolFileMessageTransformer: @classmethod - def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], - user_id: str, - tenant_id: str, - conversation_id: str) -> list[ToolInvokeMessage]: + def transform_tool_invoke_messages( + cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str + ) -> list[ToolInvokeMessage]: """ Transform tool message and handle file download """ @@ -27,78 +27,88 @@ class ToolFileMessageTransformer: # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_url=message.message + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) except Exception as e: logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: {message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=f"Failed to download image: {message.message}, you can try to download it yourself.", + meta=message.meta.copy() if message.meta is not None else {}, + save_as=message.save_as, + ) + ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') + mimetype = message.meta.get("mime_type", "octet/stream") # if message is str, encode it to bytes if isinstance(message.message, str): - message.message = message.message.encode('utf-8') + message.message = message.message.encode("utf-8") file = ToolFileManager.create_file_by_raw( - user_id=user_id, tenant_id=tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, - mimetype=mimetype + mimetype=mimetype, ) url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + if "image" in mimetype: + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + 
message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var = message.meta.get('file_var') + file_var = message.meta.get("file_var") if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) if file_var.type == FileType.IMAGE: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: result.append(message) diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e8ef47823..4e226810d6 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -1,7 +1,7 @@ """ - For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. +For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - Therefore, a model manager is needed to list/invoke/validate models. +Therefore, a model manager is needed to list/invoke/validate models. 
""" import json @@ -27,52 +27,49 @@ from models.tools import ToolModelInvoke class InvokeModelError(Exception): pass + class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, ) -> int: """ - get max llm context tokens of the model + get max llm context tokens of the model """ model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) if not schema: - raise InvokeModelError('No model schema found') + raise InvokeModelError("No model schema found") max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 - + return max_tokens @staticmethod - def calculate_tokens( - tenant_id: str, - prompt_messages: list[PromptMessage] - ) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: """ - calculate tokens from prompt messages and model parameters + calculate tokens from prompt messages and model parameters """ # get model instance model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + # get tokens tokens = model_instance.get_llm_num_tokens(prompt_messages) @@ -80,9 +77,7 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, - tool_type: str, tool_name: str, - prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context @@ -103,15 +98,16 @@ class ModelInvocationUtils: model_manager = ModelManager() # get model instance model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) # get prompt tokens prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) model_parameters = { - 'temperature': 0.8, - 'top_p': 0.8, + "temperature": 0.8, + "top_p": 0.8, } # create tool model invoke @@ -123,14 +119,14 @@ class ModelInvocationUtils: tool_name=tool_name, model_parameters=json.dumps(model_parameters), prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), - model_response='', + model_response="", prompt_tokens=prompt_tokens, answer_tokens=0, answer_unit_price=0, answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", ) db.session.add(tool_model_invoke) @@ -140,20 +136,24 @@ class ModelInvocationUtils: response: LLMResult = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=[], stop=[], stream=False, user=user_id, callbacks=[] + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], ) except InvokeRateLimitError as e: - raise InvokeModelError(f'Invoke rate limit error: {e}') + raise InvokeModelError(f"Invoke rate limit error: {e}") except InvokeBadRequestError as e: - raise 
InvokeModelError(f'Invoke bad request error: {e}') + raise InvokeModelError(f"Invoke bad request error: {e}") except InvokeConnectionError as e: - raise InvokeModelError(f'Invoke connection error: {e}') + raise InvokeModelError(f"Invoke connection error: {e}") except InvokeAuthorizationError as e: - raise InvokeModelError('Invoke authorization error') + raise InvokeModelError("Invoke authorization error") except InvokeServerUnavailableError as e: - raise InvokeModelError(f'Invoke server unavailable error: {e}') + raise InvokeModelError(f"Invoke server unavailable error: {e}") except Exception as e: - raise InvokeModelError(f'Invoke error: {e}') + raise InvokeModelError(f"Invoke error: {e}") # update tool model invoke tool_model_invoke.model_response = response.message.content diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index f711f7c9f3..654c9acaf9 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,3 @@ - import re import uuid from json import dumps as json_dumps @@ -16,54 +15,56 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} # set description to extra_info - extra_info['description'] = openapi['info'].get('description', '') + extra_info["description"] = openapi["info"].get("description", "") - if len(openapi['servers']) == 0: - raise ToolProviderNotFoundError('No server found in the openapi yaml.') + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") - server_url = openapi['servers'][0]['url'] + server_url = openapi["servers"][0]["url"] # list all interfaces interfaces = [] - for path, path_item in openapi['paths'].items(): - methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace'] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] for method in methods: if method in path_item: - interfaces.append({ - 'path': path, - 'method': method, - 'operation': path_item[method], - }) + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) # get all parameters bundles = [] for interface in interfaces: # convert parameters parameters = [] - if 'parameters' in interface['operation']: - for parameter in interface['operation']['parameters']: + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: tool_parameter = ToolParameter( - name=parameter['name'], - label=I18nObject( - en_US=parameter['name'], - zh_Hans=parameter['name'] - ), + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), human_description=I18nObject( - en_US=parameter.get('description', ''), - zh_Hans=parameter.get('description', '') + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, - required=parameter.get('required', False), + required=parameter.get("required", False), form=ToolParameter.ToolParameterForm.LLM, - llm_description=parameter.get('description'), - 
default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, ) - + # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) if typ: @@ -72,44 +73,40 @@ class ApiBasedToolSchemaParser: parameters.append(tool_parameter) # create tool bundle # check if there is a request body - if 'requestBody' in interface['operation']: - request_body = interface['operation']['requestBody'] - if 'content' in request_body: - for content_type, content in request_body['content'].items(): + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): # if there is a reference, get the reference and overwrite the content - if 'schema' not in content: + if "schema" not in content: continue - if '$ref' in content['schema']: + if "$ref" in content["schema"]: # get the reference root = openapi - reference = content['schema']['$ref'].split('/')[1:] + reference = content["schema"]["$ref"].split("/")[1:] for ref in reference: root = root[ref] # overwrite the content - interface['operation']['requestBody']['content'][content_type]['schema'] = root + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root # parse body parameters - if 'schema' in interface['operation']['requestBody']['content'][content_type]: - body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): tool = ToolParameter( name=name, - label=I18nObject( - en_US=name, - zh_Hans=name - ), + label=I18nObject(en_US=name, zh_Hans=name), human_description=I18nObject( - en_US=property.get('description', ''), - zh_Hans=property.get('description', '') + en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, required=name in required, form=ToolParameter.ToolParameterForm.LLM, - llm_description=property.get('description', ''), - default=property.get('default', None), + llm_description=property.get("description", ""), + default=property.get("default", None), ) # check if there is a type @@ -127,172 +124,176 @@ class ApiBasedToolSchemaParser: parameters_count[parameter.name] += 1 for name, count in parameters_count.items(): if count > 1: - warning['duplicated_parameter'] = f'Parameter {name} is duplicated.' + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." 
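# --- Editor's note: an illustrative sketch (not part of the diff) of the
# operationId fallback implemented just below; the helper name is
# hypothetical. OpenAPI operation ids must match ^[a-zA-Z0-9_-]{1,64}$,
# so the path is sanitized and a UUID stands in when nothing survives.
import re
import uuid

def fallback_operation_id(path: str, method: str) -> str:
    sanitized = re.sub(r"[^a-zA-Z0-9_-]", "", path.lstrip("/"))
    return f"{sanitized or uuid.uuid4()}_{method}"

assert fallback_operation_id("/pets/{petId}", "get") == "petspetId_get"
# --- end editor's note ---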
# check if there is a operation id, use $path_$method as operation id if not - if 'operationId' not in interface['operation']: + if "operationId" not in interface["operation"]: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = interface['path'] - if interface['path'].startswith('/'): - path = interface['path'][1:] + path = interface["path"] + if interface["path"].startswith("/"): + path = interface["path"][1:] # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = re.sub(r'[^a-zA-Z0-9_-]', '', path) + path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: path = str(uuid.uuid4()) - - interface['operation']['operationId'] = f'{path}_{interface["method"]}' - bundles.append(ApiToolBundle( - server_url=server_url + interface['path'], - method=interface['method'], - summary=interface['operation']['description'] if 'description' in interface['operation'] else - interface['operation'].get('summary', None), - operation_id=interface['operation']['operationId'], - parameters=parameters, - author='', - icon=None, - openapi=interface['operation'], - )) + interface["operation"]["operationId"] = f'{path}_{interface["method"]}' + + bundles.append( + ApiToolBundle( + server_url=server_url + interface["path"], + method=interface["method"], + summary=interface["operation"]["description"] + if "description" in interface["operation"] + else interface["operation"].get("summary", None), + operation_id=interface["operation"]["operationId"], + parameters=parameters, + author="", + icon=None, + openapi=interface["operation"], + ) + ) return bundles - + @staticmethod def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: parameter = parameter or {} typ = None - if 'type' in parameter: - typ = parameter['type'] - elif 'schema' in parameter and 'type' in parameter['schema']: - typ = parameter['schema']['type'] - - if typ == 'integer' or typ == 'number': + if "type" in parameter: + typ = parameter["type"] + elif "schema" in parameter and "type" in parameter["schema"]: + typ = parameter["schema"]["type"] + + if typ == "integer" or typ == "number": return ToolParameter.ToolParameterType.NUMBER - elif typ == 'boolean': + elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN - elif typ == 'string': + elif typ == "string": return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_yaml_to_tool_bundle( + yaml: str, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: """ - parse openapi yaml to tool bundle + parse openapi yaml to tool bundle - :param yaml: the yaml string - :return: the tool bundle + :param yaml: the yaml string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} openapi: dict = safe_load(yaml) if openapi is None: - raise ToolApiSchemaError('Invalid openapi yaml.') + raise ToolApiSchemaError("Invalid openapi yaml.") return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) - + @staticmethod def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict: """ - parse swagger to openapi + parse swagger to openapi - :param swagger: the swagger dict - :return: the openapi dict + :param swagger: the swagger dict + :return: the openapi dict 
""" # convert swagger to openapi - info = swagger.get('info', { - 'title': 'Swagger', - 'description': 'Swagger', - 'version': '1.0.0' - }) + info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) - servers = swagger.get('servers', []) + servers = swagger.get("servers", []) if len(servers) == 0: - raise ToolApiSchemaError('No server found in the swagger yaml.') + raise ToolApiSchemaError("No server found in the swagger yaml.") openapi = { - 'openapi': '3.0.0', - 'info': { - 'title': info.get('title', 'Swagger'), - 'description': info.get('description', 'Swagger'), - 'version': info.get('version', '1.0.0') + "openapi": "3.0.0", + "info": { + "title": info.get("title", "Swagger"), + "description": info.get("description", "Swagger"), + "version": info.get("version", "1.0.0"), }, - 'servers': swagger['servers'], - 'paths': {}, - 'components': { - 'schemas': {} - } + "servers": swagger["servers"], + "paths": {}, + "components": {"schemas": {}}, } # check paths - if 'paths' not in swagger or len(swagger['paths']) == 0: - raise ToolApiSchemaError('No paths found in the swagger yaml.') + if "paths" not in swagger or len(swagger["paths"]) == 0: + raise ToolApiSchemaError("No paths found in the swagger yaml.") # convert paths - for path, path_item in swagger['paths'].items(): - openapi['paths'][path] = {} + for path, path_item in swagger["paths"].items(): + openapi["paths"][path] = {} for method, operation in path_item.items(): - if 'operationId' not in operation: - raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.') - - if ('summary' not in operation or len(operation['summary']) == 0) and \ - ('description' not in operation or len(operation['description']) == 0): - warning['missing_summary'] = f'No summary or description found in operation {method} {path}.' - - openapi['paths'][path][method] = { - 'operationId': operation['operationId'], - 'summary': operation.get('summary', ''), - 'description': operation.get('description', ''), - 'parameters': operation.get('parameters', []), - 'responses': operation.get('responses', {}), + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." 
+ + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), } - if 'requestBody' in operation: - openapi['paths'][path][method]['requestBody'] = operation['requestBody'] + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # convert definitions - for name, definition in swagger['definitions'].items(): - openapi['components']['schemas'][name] = definition + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: """ - parse openapi plugin yaml to tool bundle + parse openapi plugin yaml to tool bundle - :param json: the json string - :return: the tool bundle + :param json: the json string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} try: openai_plugin = json_loads(json) - api = openai_plugin['api'] - api_url = api['url'] - api_type = api['type'] + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] except: - raise ToolProviderNotFoundError('Invalid openai plugin json.') - - if api_type != 'openapi': - raise ToolNotSupportedError('Only openapi is supported now.') - + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + # get openapi yaml - response = get(api_url, headers={ - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' - }, timeout=5) + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) if response.status_code != 200: - raise ToolProviderNotFoundError('cannot get openapi yaml from url.') - - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) - - @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: - """ - auto parse to tool bundle + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") - :param content: the content - :return: tools bundle, schema_type + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + + @staticmethod + def auto_parse_to_tool_bundle( + content: str, extra_info: dict = None, warning: dict = None + ) -> tuple[list[ApiToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :return: tools bundle, schema_type """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -301,7 +302,7 @@ class ApiBasedToolSchemaParser: loaded_content = None json_error = None yaml_error = None - + try: loaded_content = json_loads(content) except JSONDecodeError as e: @@ -313,34 +314,46 @@ class ApiBasedToolSchemaParser: except YAMLError as e: yaml_error = e if loaded_content is None: - raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. 
json error: {str(json_error)}, yaml error: {str(yaml_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}" + ) swagger_error = None openapi_error = None openapi_plugin_error = None schema_type = None - + try: - openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning) + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.OPENAPI.value return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - + # openai parse error, fallback to swagger try: - converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning) + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.SWAGGER.value - return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type except ToolApiSchemaError as e: swagger_error = e - + # swagger parse error, fallback to openai plugin try: - openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning) + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e - raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py index 6f88eeaa0a..6f7610651c 100644 --- a/api/core/tools/utils/tool_parameter_converter.py +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -7,16 +7,18 @@ class ToolParameterConverter: @staticmethod def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: - return 'string' + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + return "string" case ToolParameter.ToolParameterType.BOOLEAN: - return 'boolean' + return "boolean" case ToolParameter.ToolParameterType.NUMBER: - return 'number' + return "number" case _: raise ValueError(f"Unsupported parameter type {parameter_type}") @@ -26,11 +28,13 @@ class ToolParameterConverter: # convert tool parameter config to correct type try: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: + case ( + 
ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): if value is None: - return '' + return "" else: return value if isinstance(value, str) else str(value) @@ -41,9 +45,9 @@ class ToolParameterConverter: # Allowed YAML boolean value strings: https://yaml.org/type/bool.html # and also '0' for False and '1' for True match value.lower(): - case 'true' | 'yes' | 'y' | '1': + case "true" | "yes" | "y" | "1": return True - case 'false' | 'no' | 'n' | '0': + case "false" | "no" | "n" | "0": return False case _: return bool(value) @@ -53,8 +57,8 @@ class ToolParameterConverter: case ToolParameter.ToolParameterType.NUMBER: if isinstance(value, int) | isinstance(value, float): return value - elif isinstance(value, str) and value != '': - if '.' in value: + elif isinstance(value, str) and value != "": + if "." in value: return float(value) else: return int(value) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index a461328ae6..3639b5fff7 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -32,7 +32,7 @@ TEXT: def page_result(text: str, cursor: int, max_length: int) -> str: """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor: cursor + max_length] + return text[cursor : cursor + max_length] def get_url(url: str, user_agent: str = None) -> str: @@ -49,15 +49,15 @@ def get_url(url: str, user_agent: str = None) -> str: if response.status_code == 200: # check content-type - content_type = response.headers.get('Content-Type') + content_type = response.headers.get("Content-Type") if content_type: - main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() else: - content_disposition = response.headers.get('Content-Disposition', '') + content_disposition = response.headers.get("Content-Disposition", "") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - extension = re.search(r'\.(\w+)$', filename) + extension = re.search(r"\.(\w+)$", filename) if extension: main_content_type = mimetypes.guess_type(filename)[0] @@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str: # Detect encoding using chardet detected_encoding = chardet.detect(response.content) - encoding = detected_encoding['encoding'] + encoding = detected_encoding["encoding"] if encoding: try: content = response.content.decode(encoding) @@ -89,29 +89,29 @@ def get_url(url: str, user_agent: str = None) -> str: a = extract_using_readabilipy(content) - if not a['plain_text'] or not a['plain_text'].strip(): - return '' + if not a["plain_text"] or not a["plain_text"].strip(): + return "" res = FULL_TEMPLATE.format( - title=a['title'], - authors=a['byline'], - publish_date=a['date'], + title=a["title"], + authors=a["byline"], + publish_date=a["date"], top_image="", - text=a['plain_text'] if a['plain_text'] else "", + text=a["plain_text"] if a["plain_text"] else "", ) return res def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: f_html.write(html) f_html.close() html_path = f_html.name # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file 
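# --- Editor's note: a minimal sketch (not part of the diff) of the node
# round trip performed below. The temp-file path is hypothetical, and the
# JSON field names ("title", "byline", "textContent") are assumptions
# based on Readability.js output; the real code also chdirs into
# readabilipy's javascript/ directory before invoking node.
import json
import subprocess

html_path = "/tmp/example.html"          # hypothetical file written above
article_json_path = html_path + ".json"  # ExtractArticle.js output target
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
with open(article_json_path) as f:
    input_json = json.load(f)            # e.g. input_json.get("title")
# --- end editor's note ---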
article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") with chdir(jsdir): subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) @@ -129,7 +129,7 @@ def extract_using_readabilipy(html): "date": None, "content": None, "plain_content": None, - "plain_text": None + "plain_text": None, } # Populate article fields from readability fields where present if input_json: @@ -145,7 +145,7 @@ def extract_using_readabilipy(html): article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) if input_json.get("textContent"): article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) return article_json @@ -158,6 +158,7 @@ def find_module_path(module_name): return None + @contextmanager def chdir(path): """Change directory in context and return to original on exit""" @@ -172,12 +173,14 @@ def chdir(path): def extract_text_blocks_as_plain_text(paragraph_html): # Load article as DOM - soup = BeautifulSoup(paragraph_html, 'html.parser') + soup = BeautifulSoup(paragraph_html, "html.parser") # Select all lists - list_elements = soup.find_all(['ul', 'ol']) + list_elements = soup.find_all(["ul", "ol"]) # Prefix text in all list items with "* " and make lists paragraphs for list_element in list_elements: - plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) list_element.string = plain_items list_element.name = "p" # Select all text blocks @@ -189,8 +192,8 @@ def extract_text_blocks_as_plain_text(paragraph_html): def plain_text_leaf_node(element): - # Extract all text, stripped of any child HTML elements and normalise it - plain_text = normalise_text(element.get_text()) + # Extract all text, stripped of any child HTML elements and normalize it + plain_text = normalize_text(element.get_text()) if plain_text != "" and element.name == "li": plain_text = "* {}, ".format(plain_text) if plain_text == "": @@ -204,7 +207,7 @@ def plain_text_leaf_node(element): def plain_content(readability_content, content_digests, node_indexes): # Load article as DOM - soup = BeautifulSoup(readability_content, 'html.parser') + soup = BeautifulSoup(readability_content, "html.parser") # Make all elements plain elements = plain_elements(soup.contents, content_digests, node_indexes) if node_indexes: @@ -217,8 +220,7 @@ def plain_content(readability_content, content_digests, node_indexes): def plain_elements(elements, content_digests, node_indexes): # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) - for element in elements] + elements = [plain_element(element, content_digests, node_indexes) for element in elements] if content_digests: # Add content digest attribute to nodes elements = [add_content_digest(element) for element in elements] @@ -231,8 +233,8 @@ def plain_element(element, content_digests, node_indexes): # For leaf node elements, extract the text content, discarding any HTML tags # 1. Get element contents as text plain_text = element.get_text() - # 2. 
Normalise the extracted text string to a canonical representation - plain_text = normalise_text(plain_text) + # 2. Normalize the extracted text string to a canonical representation + plain_text = normalize_text(plain_text) # 3. Update element content to be plain text element.string = plain_text elif is_text(element): @@ -243,7 +245,7 @@ def plain_element(element, content_digests, node_indexes): element = type(element)("") else: plain_text = element.string - plain_text = normalise_text(plain_text) + plain_text = normalize_text(plain_text) element = type(element)(plain_text) else: # If not a leaf node or leaf type call recursively on child nodes, replacing @@ -258,21 +260,19 @@ def add_node_indexes(element, node_index="0"): # Add index to current element element["data-node-index"] = node_index # Add index to child elements - for local_idx, child in enumerate( - [c for c in element.contents if not is_text(c)], start=1): + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format( - stem=node_index, local=local_idx) + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) add_node_indexes(child, node_index=child_index) return element -def normalise_text(text): - """Normalise unicode and whitespace.""" - # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them +def normalize_text(text): + """Normalize unicode and whitespace.""" + # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them text = strip_control_characters(text) - text = normalise_unicode(text) - text = normalise_whitespace(text) + text = normalize_unicode(text) + text = normalize_whitespace(text) return text @@ -284,29 +284,35 @@ def strip_control_characters(text): # [Cn]: Other, Not Assigned # [Co]: Other, Private Use # [Cs]: Other, Surrogate - control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'} - retained_chars = ['\t', '\n', '\r', '\f'] + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] # Remove non-printing control characters - return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) -def normalise_unicode(text): - """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible.""" +def normalize_unicode(text): + """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" normal_form = "NFKC" text = unicodedata.normalize(normal_form, text) return text -def normalise_whitespace(text): +def normalize_whitespace(text): """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" text = regex.sub(r"\s+", " ", text) # Remove leading and trailing whitespace text = text.strip() return text + def is_leaf(element): - return (element.name in ['p', 'li']) + return element.name in ["p", "li"] def is_text(element): @@ -330,7 +336,7 @@ def content_digest(element): if trimmed_string == "": digest = "" else: - digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() else: contents = element.contents 
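# --- Editor's note: a minimal sketch (not part of the diff) of the digest
# scheme in content_digest above: a leaf hashes its trimmed text with
# SHA-256, while a parent hashes the concatenation of its children's
# non-empty hex digests. Helper names are hypothetical.
import hashlib

def leaf_digest(text: str) -> str:
    trimmed = text.strip()
    return hashlib.sha256(trimmed.encode("utf-8")).hexdigest() if trimmed else ""

def parent_digest(child_hexes: list[str]) -> str:
    h = hashlib.sha256()
    for child in filter(None, child_hexes):  # skip empty child digests
        h.update(child.encode("utf-8"))
    return h.hexdigest()
# --- end editor's note ---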
num_contents = len(contents) @@ -343,9 +349,8 @@ def content_digest(element): else: # Build content digest from the "non-empty" digests of child nodes digest = hashlib.sha256() - child_digests = list( - filter(lambda x: x != "", [content_digest(content) for content in contents])) + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) for child in child_digests: - digest.update(child.encode('utf-8')) + digest.update(child.encode("utf-8")) digest = digest.hexdigest() return digest diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index ff5505bbbf..94d9fd9eb9 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -10,27 +10,25 @@ class WorkflowToolConfigurationUtils: """ for configuration in configurations: if not WorkflowToolParameterConfiguration(**configuration): - raise ValueError('invalid parameter configuration') + raise ValueError("invalid parameter configuration") @classmethod def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: """ get workflow graph variables """ - nodes = graph.get('nodes', []) - start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None) + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) if not start_node: return [] - return [ - VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', []) - ] - + return [VariableEntity(**variable) for variable in start_node.get("data", {}).get("variables", [])] + @classmethod - def check_is_synced(cls, - variables: list[VariableEntity], - tool_configurations: list[WorkflowToolParameterConfiguration]) -> None: + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ) -> None: """ check is synced @@ -39,10 +37,10 @@ class WorkflowToolConfigurationUtils: variable_names = [variable.variable for variable in variables] if len(tool_configurations) != len(variables): - raise ValueError('parameter configuration mismatch, please republish the tool to update') - + raise ValueError("parameter configuration mismatch, please republish the tool to update") + for parameter in tool_configurations: if parameter.name not in variable_names: - raise ValueError('parameter configuration mismatch, please republish the tool to update') + raise ValueError("parameter configuration mismatch, please republish the tool to update") - return True \ No newline at end of file + return True diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index f751c43096..bcb061376d 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -18,12 +18,12 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any :return: an object of the YAML content """ try: - with open(file_path, encoding='utf-8') as yaml_file: + with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) return yaml_content if yaml_content else default_value except Exception as e: - raise YAMLError(f'Failed to load YAML file {file_path}: {e}') + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") except Exception as e: if ignore_error: return default_value diff --git a/api/core/workflow/callbacks/base_workflow_callback.py 
b/api/core/workflow/callbacks/base_workflow_callback.py index 6db8adf4c2..83086d1afc 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,116 +1,12 @@ from abc import ABC, abstractmethod -from typing import Any, Optional -from core.app.entities.queue_entities import AppQueueEvent -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self) -> None: + def on_event(self, event: GraphEngineEvent) -> None: """ - Workflow run started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - raise NotImplementedError - - @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any], - ) -> None: - """ - Publish iteration next - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - raise NotImplementedError - - @abstractmethod - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event + Published event """ raise NotImplementedError diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index 6bf0c11c7d..2a864dd7a8 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -8,8 +8,10 @@ class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + class 
BaseIterationNodeData(BaseNodeData): - start_node_id: str + start_node_id: Optional[str] = None + class BaseIterationState(BaseModel): iteration_node_id: str @@ -19,4 +21,4 @@ class BaseIterationState(BaseModel): class MetaData(BaseModel): pass - metadata: MetaData \ No newline at end of file + metadata: MetaData diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 025453567b..5353b99ed3 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,9 +1,9 @@ -from collections.abc import Mapping from enum import Enum from typing import Any, Optional from pydantic import BaseModel +from core.model_runtime.entities.llm_entities import LLMUsage from models import WorkflowNodeExecutionStatus @@ -12,27 +12,28 @@ class NodeType(Enum): Node Types. """ - START = 'start' - END = 'end' - ANSWER = 'answer' - LLM = 'llm' - KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' - IF_ELSE = 'if-else' - CODE = 'code' - TEMPLATE_TRANSFORM = 'template-transform' - QUESTION_CLASSIFIER = 'question-classifier' - HTTP_REQUEST = 'http-request' - TOOL = 'tool' - VARIABLE_AGGREGATOR = 'variable-aggregator' + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = 'variable-assigner' - LOOP = 'loop' - ITERATION = 'iteration' - PARAMETER_EXTRACTOR = 'parameter-extractor' - CONVERSATION_VARIABLE_ASSIGNER = 'assigner' + VARIABLE_ASSIGNER = "variable-assigner" + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # fake start node for iteration + PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" @classmethod - def value_of(cls, value: str) -> 'NodeType': + def value_of(cls, value: str) -> "NodeType": """ Get value of given node type. @@ -42,7 +43,7 @@ class NodeType(Enum): for node_type in cls: if node_type.value == value: return node_type - raise ValueError(f'invalid node type value {value}') + raise ValueError(f"invalid node type value {value}") class NodeRunMetadataKey(Enum): @@ -50,12 +51,16 @@ class NodeRunMetadataKey(Enum): Node Run Metadata Key. 
""" - TOTAL_TOKENS = 'total_tokens' - TOTAL_PRICE = 'total_price' - CURRENCY = 'currency' - TOOL_INFO = 'tool_info' - ITERATION_ID = 'iteration_id' - ITERATION_INDEX = 'iteration_index' + TOTAL_TOKENS = "total_tokens" + TOTAL_PRICE = "total_price" + CURRENCY = "currency" + TOOL_INFO = "tool_info" + ITERATION_ID = "iteration_id" + ITERATION_INDEX = "iteration_index" + PARALLEL_ID = "parallel_id" + PARALLEL_START_NODE_ID = "parallel_start_node_id" + PARENT_PARALLEL_ID = "parent_parallel_id" + PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" class NodeRunResult(BaseModel): @@ -65,11 +70,33 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs + inputs: Optional[dict[str, Any]] = None # node inputs + process_data: Optional[dict[str, Any]] = None # process data + outputs: Optional[dict[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed + + +class UserFrom(Enum): + """ + User from + """ + + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 19d9af2a61..1dfb1852f8 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -5,5 +5,6 @@ class VariableSelector(BaseModel): """ Variable Selector. """ + variable: str value_selector: list[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 8120b2ac78..b94b7f7198 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -2,6 +2,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Union +from pydantic import BaseModel, Field, model_validator from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory @@ -16,43 +17,48 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" -class VariablePool: - def __init__( - self, - system_variables: Mapping[SystemVariableKey, Any], - user_inputs: Mapping[str, Any], - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable] | None = None, - ) -> None: - # system variables - # for example: - # { - # 'query': 'abc', - # 'files': [] - # } +class VariablePool(BaseModel): + # Variable dictionary is a dictionary for looking up variables by their selector. + # The first element of the selector is the node id, it's the first-level key in the dictionary. + # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the + # elements of the selector except the first one. 
+ variable_dictionary: dict[str, dict[int, Segment]] = Field( + description="Variables mapping", default=defaultdict(dict) + ) - # Varaible dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict) + # TODO: This user inputs is not used for pool. + user_inputs: Mapping[str, Any] = Field( + description="User inputs", + ) - # TODO: This user inputs is not used for pool. - self.user_inputs = user_inputs + system_variables: Mapping[SystemVariableKey, Any] = Field( + description="System variables", + ) + environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list) + + conversation_variables: Sequence[Variable] | None = None + + @model_validator(mode="after") + def val_model_after(self): + """ + Append system variables + :return: + """ # Add system variables to the variable pool - self.system_variables = system_variables - for key, value in system_variables.items(): + for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool - for var in environment_variables: + for var in self.environment_variables or []: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) # Add conversation variables to the variable pool - for var in conversation_variables or []: + for var in self.conversation_variables or []: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + return self + def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. @@ -79,7 +85,7 @@ class VariablePool: v = factory.build_segment(value) hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]][hash_key] = v + self.variable_dictionary[selector[0]][hash_key] = v def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -97,7 +103,7 @@ class VariablePool: if len(selector) < 2: raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + value = self.variable_dictionary[selector[0]].get(hash_key) return value @@ -118,7 +124,7 @@ class VariablePool: if len(selector) < 2: raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + value = self.variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None def remove(self, selector: Sequence[str], /): @@ -134,7 +140,19 @@ class VariablePool: if not selector: return if len(selector) == 1: - self._variable_dictionary[selector[0]] = {} + self.variable_dictionary[selector[0]] = {} return hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]].pop(hash_key, None) + self.variable_dictionary[selector[0]].pop(hash_key, None) + + def remove_node(self, node_id: str, /): + """ + Remove all variables associated with a given node id. + + Args: + node_id (str): The node id to remove. 
+ + Returns: + None + """ + self.variable_dictionary.pop(node_id, None) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 9b35b8df8a..0a1eb57de4 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -46,13 +46,16 @@ class WorkflowRunState: current_iteration_state: Optional[BaseIterationState] - def __init__(self, workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int): + def __init__( + self, + workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int, + ): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id @@ -66,8 +69,7 @@ class WorkflowRunState: self.variable_pool = variable_pool self.total_tokens = 0 - self.workflow_nodes_and_results = [] - self.current_iteration_state = None self.workflow_node_steps = 1 - self.workflow_node_runs = [] \ No newline at end of file + self.workflow_node_runs = [] + self.current_iteration_state = None diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index fe79fadf66..07cbcd981e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,10 +1,8 @@ -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): - self.node_id = node_id - self.node_type = node_type - self.node_title = node_title + def __init__(self, node_instance: BaseNode, error: str): + self.node_instance = node_instance self.error = error - super().__init__(f"Node {node_title} run failed: {error}") + super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/condition_handlers/__init__.py b/api/core/workflow/graph_engine/condition_handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py new file mode 100644 index 0000000000..697392b2a3 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class RunConditionHandler(ABC): + def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition): + self.init_params = init_params + self.graph = graph + self.condition = condition + + @abstractmethod + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param 
previous_route_node_state: previous route node state + :return: bool + """ + raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py new file mode 100644 index 0000000000..af695df7d8 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -0,0 +1,25 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class BranchIdentifyRunConditionHandler(RunConditionHandler): + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.branch_identify: + raise Exception("Branch identify is required") + + run_result = previous_route_node_state.node_run_result + if not run_result: + return False + + if not run_result.edge_source_handle: + return False + + return self.condition.branch_identify == run_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py new file mode 100644 index 0000000000..eda5fe079c --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -0,0 +1,28 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.utils.condition.processor import ConditionProcessor + + +class ConditionRunConditionHandlerHandler(RunConditionHandler): + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.conditions: + return True + + # process condition + condition_processor = ConditionProcessor() + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions + ) + + # Apply the logical operator for the current case + compare_result = all(group_result) + + return compare_result diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py new file mode 100644 index 0000000000..1c9237d82f --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -0,0 +1,25 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler +from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params 
import GraphInitParams +from core.workflow.graph_engine.entities.run_condition import RunCondition + + +class ConditionManager: + @staticmethod + def get_condition_handler( + init_params: GraphInitParams, graph: Graph, run_condition: RunCondition + ) -> RunConditionHandler: + """ + Get condition handler + + :param init_params: init params + :param graph: graph + :param run_condition: run condition + :return: condition handler + """ + if run_condition.type == "branch_identify": + return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition) + else: + return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition) diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py new file mode 100644 index 0000000000..06dc4cb8f4 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/event.py @@ -0,0 +1,163 @@ +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class GraphEngineEvent(BaseModel): + pass + + +########################################### +# Graph Events +########################################### + + +class BaseGraphEvent(GraphEngineEvent): + pass + + +class GraphRunStartedEvent(BaseGraphEvent): + pass + + +class GraphRunSucceededEvent(BaseGraphEvent): + outputs: Optional[dict[str, Any]] = None + """outputs""" + + +class GraphRunFailedEvent(BaseGraphEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Node Events +########################################### + + +class BaseNodeEvent(GraphEngineEvent): + id: str = Field(..., description="node execution id") + node_id: str = Field(..., description="node id") + node_type: NodeType = Field(..., description="node type") + node_data: BaseNodeData = Field(..., description="node data") + route_node_state: RouteNodeState = Field(..., description="route node state") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class NodeRunStartedEvent(BaseNodeEvent): + predecessor_node_id: Optional[str] = None + """predecessor node id""" + + +class NodeRunStreamChunkEvent(BaseNodeEvent): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + + +class NodeRunRetrieverResourceEvent(BaseNodeEvent): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +class NodeRunSucceededEvent(BaseNodeEvent): + pass + + +class NodeRunFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + 
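+# Consumption sketch (illustrative; `engine` is an assumed, already initialized
+# GraphEngine instance): callers iterate the generator returned by
+# GraphEngine.run() and dispatch on the concrete event class, e.g.
+#
+#     for event in engine.run():
+#         if isinstance(event, NodeRunStartedEvent):
+#             print(f"node {event.node_id} started")
+#         elif isinstance(event, NodeRunFailedEvent):
+#             print(f"node {event.node_id} failed: {event.error}")
+#         elif isinstance(event, GraphRunSucceededEvent):
+#             print(f"graph finished with outputs: {event.outputs}")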
+########################################### +# Parallel Branch Events +########################################### + + +class BaseParallelBranchEvent(GraphEngineEvent): + parallel_id: str = Field(..., description="parallel id") + """parallel id""" + parallel_start_node_id: str = Field(..., description="parallel start node id") + """parallel start node id""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Iteration Events +########################################### + + +class BaseIterationEvent(GraphEngineEvent): + iteration_id: str = Field(..., description="iteration node execution id") + iteration_node_id: str = Field(..., description="iteration node id") + iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") + iteration_node_data: BaseNodeData = Field(..., description="node data") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + + +class IterationRunStartedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + predecessor_node_id: Optional[str] = None + + +class IterationRunNextEvent(BaseIterationEvent): + index: int = Field(..., description="index") + pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") + + +class IterationRunSucceededEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + +class IterationRunFailedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + error: str = Field(..., description="failed reason") + + +InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py new file mode 100644 index 0000000000..f1f677b8c1 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -0,0 +1,729 @@ +import uuid +from collections.abc import Mapping +from typing import Any, Optional, cast + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter +from core.workflow.nodes.answer.entities import 
AnswerStreamGenerateRoute +from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter +from core.workflow.nodes.end.entities import EndStreamParam + + +class GraphEdge(BaseModel): + source_node_id: str = Field(..., description="source node id") + target_node_id: str = Field(..., description="target node id") + run_condition: Optional[RunCondition] = None + """run condition""" + + +class GraphParallel(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") + start_from_node_id: str = Field(..., description="start from node id") + parent_parallel_id: Optional[str] = None + """parent parallel id""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id""" + end_to_node_id: Optional[str] = None + """end to node id""" + + +class Graph(BaseModel): + root_node_id: str = Field(..., description="root node id of the graph") + node_ids: list[str] = Field(default_factory=list, description="graph node ids") + node_id_config_mapping: dict[str, dict] = Field( + default_factory=dict, description="node configs mapping (node id: node config)" + ) + edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, description="graph edge mapping (source node id: edges)" + ) + reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, description="reverse graph edge mapping (target node id: edges)" + ) + parallel_mapping: dict[str, GraphParallel] = Field( + default_factory=dict, description="graph parallel mapping (parallel id: parallel)" + ) + node_parallel_mapping: dict[str, str] = Field( + default_factory=dict, description="graph node parallel mapping (node id: parallel id)" + ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes") + end_stream_param: EndStreamParam = Field(..., description="end stream param") + + @classmethod + def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": + """ + Init graph + + :param graph_config: graph config + :param root_node_id: root node id + :return: graph + """ + # edge configs + edge_configs = graph_config.get("edges") + if edge_configs is None: + edge_configs = [] + + edge_configs = cast(list, edge_configs) + + # reorganize edges mapping + edge_mapping: dict[str, list[GraphEdge]] = {} + reverse_edge_mapping: dict[str, list[GraphEdge]] = {} + target_edge_ids = set() + for edge_config in edge_configs: + source_node_id = edge_config.get("source") + if not source_node_id: + continue + + if source_node_id not in edge_mapping: + edge_mapping[source_node_id] = [] + + target_node_id = edge_config.get("target") + if not target_node_id: + continue + + if target_node_id not in reverse_edge_mapping: + reverse_edge_mapping[target_node_id] = [] + + target_edge_ids.add(target_node_id) + + # parse run condition + run_condition = None + if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": + run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) + + graph_edge = GraphEdge( + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition + ) + + edge_mapping[source_node_id].append(graph_edge) + reverse_edge_mapping[target_node_id].append(graph_edge) + + # node configs + node_configs = graph_config.get("nodes") + if not node_configs: + raise ValueError("Graph must have at least one node") + + node_configs = cast(list, node_configs) + + # 
fetch nodes that have no predecessor node + root_node_configs = [] + all_node_id_config_mapping: dict[str, dict] = {} + for node_config in node_configs: + node_id = node_config.get("id") + if not node_id: + continue + + if node_id not in target_edge_ids: + root_node_configs.append(node_config) + + all_node_id_config_mapping[node_id] = node_config + + root_node_ids = [node_config.get("id") for node_config in root_node_configs] + + # fetch root node + if not root_node_id: + # if no root node id, use the START type node as root node + root_node_id = next( + ( + node_config.get("id") + for node_config in root_node_configs + if node_config.get("data", {}).get("type", "") == NodeType.START.value + ), + None, + ) + + if not root_node_id or root_node_id not in root_node_ids: + raise ValueError(f"Root node id {root_node_id} not found in the graph") + + # Check whether it is connected to the previous node + cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping) + + # fetch all node ids from root node + node_ids = [root_node_id] + cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id) + + node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} + + # init parallel mapping + parallel_mapping: dict[str, GraphParallel] = {} + node_parallel_mapping: dict[str, str] = {} + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=root_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + ) + + # Check if it exceeds N layers of parallel + for parallel in parallel_mapping.values(): + if parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id + ) + + # init answer stream generate routes + answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping + ) + + # init end stream param + end_stream_param = EndStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + node_parallel_mapping=node_parallel_mapping, + ) + + # init graph + graph = cls( + root_node_id=root_node_id, + node_ids=node_ids, + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + answer_stream_generate_routes=answer_stream_generate_routes, + end_stream_param=end_stream_param, + ) + + return graph + + def add_extra_edge( + self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None + ) -> None: + """ + Add extra edge to the graph + + :param source_node_id: source node id + :param target_node_id: target node id + :param run_condition: run condition + """ + if source_node_id not in self.node_ids or target_node_id not in self.node_ids: + return + + if source_node_id not in self.edge_mapping: + self.edge_mapping[source_node_id] = [] + + if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: + return + + graph_edge = GraphEdge( + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition + ) + + self.edge_mapping[source_node_id].append(graph_edge) + + def get_leaf_node_ids(self) -> list[str]: + """ + Get leaf 
node ids of the graph + + :return: leaf node ids + """ + leaf_node_ids = [] + for node_id in self.node_ids: + if node_id not in self.edge_mapping: + leaf_node_ids.append(node_id) + elif ( + len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id + ): + leaf_node_ids.append(node_id) + + return leaf_node_ids + + @classmethod + def _recursively_add_node_ids( + cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str + ) -> None: + """ + Recursively add node ids + + :param node_ids: node ids + :param edge_mapping: edge mapping + :param node_id: node id + """ + for graph_edge in edge_mapping.get(node_id, []): + if graph_edge.target_node_id in node_ids: + continue + + node_ids.append(graph_edge.target_node_id) + cls._recursively_add_node_ids( + node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id + ) + + @classmethod + def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None: + """ + Check whether it is connected to the previous node + """ + last_node_id = route[-1] + + for graph_edge in edge_mapping.get(last_node_id, []): + if not graph_edge.target_node_id: + continue + + if graph_edge.target_node_id in route: + raise ValueError( + f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." + ) + + new_route = route[:] + new_route.append(graph_edge.target_node_id) + cls._check_connected_to_previous_node( + route=new_route, + edge_mapping=edge_mapping, + ) + + @classmethod + def _recursively_add_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + parallel_mapping: dict[str, GraphParallel], + node_parallel_mapping: dict[str, str], + parent_parallel: Optional[GraphParallel] = None, + ) -> None: + """ + Recursively add parallel ids + + :param edge_mapping: edge mapping + :param start_node_id: start from node id + :param parallel_mapping: parallel mapping + :param node_parallel_mapping: node parallel mapping + :param parent_parallel: parent parallel + """ + target_node_edges = edge_mapping.get(start_node_id, []) + parallel = None + if len(target_node_edges) > 1: + # fetch all node ids in current parallels + parallel_branch_node_ids = {} + condition_edge_mappings = {} + for graph_edge in target_node_edges: + if graph_edge.run_condition is None: + if "default" not in parallel_branch_node_ids: + parallel_branch_node_ids["default"] = [] + + parallel_branch_node_ids["default"].append(graph_edge.target_node_id) + else: + condition_hash = graph_edge.run_condition.hash + if condition_hash not in condition_edge_mappings: + condition_edge_mappings[condition_hash] = [] + + condition_edge_mappings[condition_hash].append(graph_edge) + + for condition_hash, graph_edges in condition_edge_mappings.items(): + if len(graph_edges) > 1: + if condition_hash not in parallel_branch_node_ids: + parallel_branch_node_ids[condition_hash] = [] + + for graph_edge in graph_edges: + parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) + + condition_parallels = {} + for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items(): + # any target node id in node_parallel_mapping + parallel = None + if condition_parallel_branch_node_ids: + parent_parallel_id = parent_parallel.id if parent_parallel else None + + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id 
if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + ) + parallel_mapping[parallel.id] = parallel + condition_parallels[condition_hash] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=condition_parallel_branch_node_ids, + ) + + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id + + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue + + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue + + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue + + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id + ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + if condition_edge_mappings: + for condition_hash, graph_edges in condition_edge_mappings.items(): + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + 
parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + + @classmethod + def _get_current_parallel( + cls, + parallel_mapping: dict[str, GraphParallel], + graph_edge: GraphEdge, + parallel: Optional[GraphParallel] = None, + parent_parallel: Optional[GraphParallel] = None, + ) -> Optional[GraphParallel]: + """ + Get current parallel + """ + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id + or ( + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id + ) + ): + current_parallel = parent_parallel_parent_parallel + + return current_parallel + + @classmethod + def _check_exceed_parallel_limit( + cls, + parallel_mapping: dict[str, GraphParallel], + level_limit: int, + parent_parallel_id: str, + current_level: int = 1, + ) -> None: + """ + Check if it exceeds N layers of parallel + """ + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + return + + current_level += 1 + if current_level > level_limit: + raise ValueError(f"Exceeds {level_limit} layers of parallel") + + if parent_parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=level_limit, + parent_parallel_id=parent_parallel.parent_parallel_id, + current_level=current_level, + ) + + @classmethod + def _recursively_add_parallel_node_ids( + cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str, + ) -> None: + """ + Recursively add node ids + + :param branch_node_ids: in branch node ids + :param edge_mapping: edge mapping + :param merge_node_id: merge node id + :param start_node_id: start node id + """ + for graph_edge in edge_mapping.get(start_node_id, []): + if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids: + branch_node_ids.append(graph_edge.target_node_id) + cls._recursively_add_parallel_node_ids( + branch_node_ids=branch_node_ids, + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=graph_edge.target_node_id, + ) + + @classmethod + def _fetch_all_node_ids_in_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str], + ) -> dict[str, list[str]]: + """ + Fetch all node ids in parallels + """ + routes_node_ids: dict[str, list[str]] = {} + for parallel_branch_node_id in parallel_branch_node_ids: + routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id] + + # fetch routes node ids + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=parallel_branch_node_id, + routes_node_ids=routes_node_ids[parallel_branch_node_id], + ) + + # fetch leaf node ids from routes node ids + leaf_node_ids: dict[str, list[str]] = {} + merge_branch_node_ids: dict[str, list[str]] = {} + for 
branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0: + if branch_node_id not in leaf_node_ids: + leaf_node_ids[branch_node_id] = [] + + leaf_node_ids[branch_node_id].append(node_id) + + for branch_node_id2, inner_route2 in routes_node_ids.items(): + if ( + branch_node_id != branch_node_id2 + and node_id in inner_route2 + and len(reverse_edge_mapping.get(node_id, [])) > 1 + and cls._is_node_in_routes( + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=node_id, + routes_node_ids=routes_node_ids, + ) + ): + if node_id not in merge_branch_node_ids: + merge_branch_node_ids[node_id] = [] + + if branch_node_id2 not in merge_branch_node_ids[node_id]: + merge_branch_node_ids[node_id].append(branch_node_id2) + + # sorted merge_branch_node_ids by branch_node_ids length desc + merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) + + duplicate_end_node_ids = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): + if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): + if (node_id, node_id2) not in duplicate_end_node_ids and ( + node_id2, + node_id, + ) not in duplicate_end_node_ids: + duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids + + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): + # check which node is after + if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): + if node_id in merge_branch_node_ids: + del merge_branch_node_ids[node_id2] + elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): + if node_id2 in merge_branch_node_ids: + del merge_branch_node_ids[node_id] + + branches_merge_node_ids: dict[str, str] = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + if len(branch_node_ids) <= 1: + continue + + for branch_node_id in branch_node_ids: + if branch_node_id in branches_merge_node_ids: + continue + + branches_merge_node_ids[branch_node_id] = node_id + + in_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + in_branch_node_ids[branch_node_id] = [] + if branch_node_id not in branches_merge_node_ids: + # all node ids in current branch is in this thread + in_branch_node_ids[branch_node_id].append(branch_node_id) + in_branch_node_ids[branch_node_id].extend(node_ids) + else: + merge_node_id = branches_merge_node_ids[branch_node_id] + if merge_node_id != branch_node_id: + in_branch_node_ids[branch_node_id].append(branch_node_id) + + # fetch all node ids from branch_node_id and merge_node_id + cls._recursively_add_parallel_node_ids( + branch_node_ids=in_branch_node_ids[branch_node_id], + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=branch_node_id, + ) + + return in_branch_node_ids + + @classmethod + def _recursively_fetch_routes( + cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str] + ) -> None: + """ + Recursively fetch route + """ + if start_node_id not in edge_mapping: + return + + for graph_edge in edge_mapping[start_node_id]: + # find next node ids + if graph_edge.target_node_id not in routes_node_ids: + routes_node_ids.append(graph_edge.target_node_id) + + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, 
routes_node_ids=routes_node_ids + ) + + @classmethod + def _is_node_in_routes( + cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]] + ) -> bool: + """ + Recursively check if the node is in the routes + """ + if start_node_id not in reverse_edge_mapping: + return False + + all_routes_node_ids = set() + parallel_start_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + all_routes_node_ids.add(node_id) + + if branch_node_id in reverse_edge_mapping: + for graph_edge in reverse_edge_mapping[branch_node_id]: + if graph_edge.source_node_id not in parallel_start_node_ids: + parallel_start_node_ids[graph_edge.source_node_id] = [] + + parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) + + parallel_start_node_id = None + for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + if set(branch_node_ids) == set(routes_node_ids.keys()): + parallel_start_node_id = p_start_node_id + return True + + if not parallel_start_node_id: + raise Exception("Parallel start node id not found") + + for graph_edge in reverse_edge_mapping[start_node_id]: + if ( + graph_edge.source_node_id not in all_routes_node_ids + or graph_edge.source_node_id != parallel_start_node_id + ): + return False + + return True + + @classmethod + def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: + """ + is node2 after node1 + """ + if node1_id not in edge_mapping: + return False + + for graph_edge in edge_mapping[node1_id]: + if graph_edge.target_node_id == node2_id: + return True + + if cls._is_node2_after_node1( + node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping + ): + return True + + return False diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py new file mode 100644 index 0000000000..1a403f3e49 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -0,0 +1,21 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel, Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom +from models.workflow import WorkflowType + + +class GraphInitParams(BaseModel): + # init params + tenant_id: str = Field(..., description="tenant / workspace id") + app_id: str = Field(..., description="app id") + workflow_type: WorkflowType = Field(..., description="workflow type") + workflow_id: str = Field(..., description="workflow id") + graph_config: Mapping[str, Any] = Field(..., description="graph config") + user_id: str = Field(..., description="user id") + user_from: UserFrom = Field(..., description="user from, account or end-user") + invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py new file mode 100644 index 0000000000..afc09bfac5 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -0,0 +1,27 @@ +from typing import Any + +from pydantic import BaseModel, Field + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool 
+from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState + + +class GraphRuntimeState(BaseModel): + variable_pool: VariablePool = Field(..., description="variable pool") + """variable pool""" + + start_at: float = Field(..., description="start time") + """start time""" + total_tokens: int = 0 + """total tokens""" + llm_usage: LLMUsage = LLMUsage.empty_usage() + """llm usage info""" + outputs: dict[str, Any] = {} + """outputs""" + + node_run_steps: int = 0 + """node run steps""" + + node_run_state: RuntimeRouteState = RuntimeRouteState() + """node run state""" diff --git a/api/core/workflow/graph_engine/entities/next_graph_node.py b/api/core/workflow/graph_engine/entities/next_graph_node.py new file mode 100644 index 0000000000..6aa4341ddf --- /dev/null +++ b/api/core/workflow/graph_engine/entities/next_graph_node.py @@ -0,0 +1,13 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.graph_engine.entities.graph import GraphParallel + + +class NextGraphNode(BaseModel): + node_id: str + """next node id""" + + parallel: Optional[GraphParallel] = None + """parallel""" diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py new file mode 100644 index 0000000000..eedce8842b --- /dev/null +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -0,0 +1,21 @@ +import hashlib +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.utils.condition.entities import Condition + + +class RunCondition(BaseModel): + type: Literal["branch_identify", "condition"] + """condition type""" + + branch_identify: Optional[str] = None + """branch identify like: sourceHandle, required when type is branch_identify""" + + conditions: Optional[list[Condition]] = None + """conditions to run the node, required when type is condition""" + + @property + def hash(self) -> str: + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py new file mode 100644 index 0000000000..8fc8047426 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -0,0 +1,109 @@ +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus + + +class RouteNodeState(BaseModel): + class Status(Enum): + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + PAUSED = "paused" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """node state id""" + + node_id: str + """node id""" + + node_run_result: Optional[NodeRunResult] = None + """node run result""" + + status: Status = Status.RUNNING + """node status""" + + start_at: datetime + """start time""" + + paused_at: Optional[datetime] = None + """paused time""" + + finished_at: Optional[datetime] = None + """finished time""" + + failed_reason: Optional[str] = None + """failed reason""" + + paused_by: Optional[str] = None + """paused by""" + + index: int = 1 + + def set_finished(self, run_result: NodeRunResult) -> None: + """ + Node finished + + :param run_result: run result + """ + if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + raise Exception(f"Route state {self.id} already finished") + 
+ if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + self.status = RouteNodeState.Status.SUCCESS + elif run_result.status == WorkflowNodeExecutionStatus.FAILED: + self.status = RouteNodeState.Status.FAILED + self.failed_reason = run_result.error + else: + raise Exception(f"Invalid route status {run_result.status}") + + self.node_run_result = run_result + self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + + +class RuntimeRouteState(BaseModel): + routes: dict[str, list[str]] = Field( + default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" + ) + + node_state_mapping: dict[str, RouteNodeState] = Field( + default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" + ) + + def create_node_state(self, node_id: str) -> RouteNodeState: + """ + Create node state + + :param node_id: node id + """ + state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + self.node_state_mapping[state.id] = state + return state + + def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: + """ + Add route to the graph state + + :param source_node_state_id: source node state id + :param target_node_state_id: target node state id + """ + if source_node_state_id not in self.routes: + self.routes[source_node_state_id] = [] + + self.routes[source_node_state_id].append(target_node_state_id) + + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]: + """ + Get routes with node state by source node id + + :param source_node_state_id: source node state id + :return: routes with node state + """ + return [ + self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) + ] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py new file mode 100644 index 0000000000..1db9b690ab --- /dev/null +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -0,0 +1,726 @@ +import logging +import queue +import time +import uuid +from collections.abc import Generator, Mapping +from concurrent.futures import ThreadPoolExecutor, wait +from typing import Any, Optional + +from flask import Flask, current_app + +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeType, + UserFrom, +) +from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.graph_engine.entities.event import ( + BaseIterationEvent, + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph, GraphEdge +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor 
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
+from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from core.workflow.nodes.node_mapping import node_classes
+from extensions.ext_database import db
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
+
+logger = logging.getLogger(__name__)
+
+
+class GraphEngineThreadPool(ThreadPoolExecutor):
+    def __init__(
+        self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100
+    ) -> None:
+        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
+        self.max_submit_count = max_submit_count
+        self.submit_count = 0
+
+    def submit(self, fn, *args, **kwargs):
+        self.submit_count += 1
+        self.check_is_full()
+
+        return super().submit(fn, *args, **kwargs)
+
+    def check_is_full(self) -> None:
+        if self.submit_count > self.max_submit_count:
+            raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
+
+
+class GraphEngine:
+    workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
+
+    def __init__(
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_type: WorkflowType,
+        workflow_id: str,
+        user_id: str,
+        user_from: UserFrom,
+        invoke_from: InvokeFrom,
+        call_depth: int,
+        graph: Graph,
+        graph_config: Mapping[str, Any],
+        variable_pool: VariablePool,
+        max_execution_steps: int,
+        max_execution_time: int,
+        thread_pool_id: Optional[str] = None,
+    ) -> None:
+        thread_pool_max_submit_count = 100
+        thread_pool_max_workers = 10
+
+        # init thread pool
+        if thread_pool_id:
+            if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping:
+                raise ValueError(f"Thread pool {thread_pool_id} not found.")
+
+            self.thread_pool_id = thread_pool_id
+            self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
+            self.is_main_thread_pool = False
+        else:
+            self.thread_pool = GraphEngineThreadPool(
+                max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count
+            )
+            self.thread_pool_id = str(uuid.uuid4())
+            self.is_main_thread_pool = True
+            GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
+
+        self.graph = graph
+        self.init_params = GraphInitParams(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            workflow_type=workflow_type,
+            workflow_id=workflow_id,
+            graph_config=graph_config,
+            user_id=user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+            call_depth=call_depth,
+        )
+
+        self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
+        self.max_execution_steps = max_execution_steps
+        self.max_execution_time = max_execution_time
+
+    def run(self) -> Generator[GraphEngineEvent, None, None]:
+        # trigger graph run start event
+        yield GraphRunStartedEvent()
+
+        try:
+            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
+            if self.init_params.workflow_type == WorkflowType.CHAT:
+                stream_processor_cls = AnswerStreamProcessor
+            else:
+                stream_processor_cls = EndStreamProcessor
+
+            stream_processor = stream_processor_cls(
+                graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
+            )
+
+            # run graph
+            generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
+
+            for item in generator:
+                try:
+                    yield item
+                    if isinstance(item,
NodeRunFailedEvent): + yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") + return + elif isinstance(item, NodeRunSucceededEvent): + if item.node_type == NodeType.END: + self.graph_runtime_state.outputs = ( + item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {} + ) + elif item.node_type == NodeType.ANSWER: + if "answer" not in self.graph_runtime_state.outputs: + self.graph_runtime_state.outputs["answer"] = "" + + self.graph_runtime_state.outputs["answer"] += "\n" + ( + item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "" + ) + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[ + "answer" + ].strip() + except Exception as e: + logger.exception(f"Graph run failed: {str(e)}") + yield GraphRunFailedEvent(error=str(e)) + return + + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + except GraphRunFailedError as e: + yield GraphRunFailedEvent(error=e.error) + return + except Exception as e: + logger.exception("Unknown Error when graph running") + yield GraphRunFailedEvent(error=str(e)) + raise e + finally: + if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] + + def _run( + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: + parallel_start_node_id = None + if in_parallel_id: + parallel_start_node_id = start_node_id + + next_node_id = start_node_id + previous_route_node_state: Optional[RouteNodeState] = None + while True: + # max steps reached + if self.graph_runtime_state.node_run_steps > self.max_execution_steps: + raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) + + # or max execution time reached + if self._is_timed_out( + start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time + ): + raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) + + # init route node state + route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) + + # get node config + node_id = route_node_state.node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f"Node {node_id} config not found.") + + # convert to specific node + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) + node_cls = node_classes.get(node_type) + if not node_cls: + raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.") + + previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None + + # init workflow run state + node_instance = node_cls( # type: ignore + id=route_node_state.id, + config=node_config, + graph_init_params=self.init_params, + graph=self.graph, + graph_runtime_state=self.graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=self.thread_pool_id, + ) + + try: + # run node + generator = self._run_node( + node_instance=node_instance, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + 
parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+
+                for item in generator:
+                    if isinstance(item, NodeRunStartedEvent):
+                        self.graph_runtime_state.node_run_steps += 1
+                        item.route_node_state.index = self.graph_runtime_state.node_run_steps
+
+                    yield item
+
+                self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
+
+                # append route
+                if previous_route_node_state:
+                    self.graph_runtime_state.node_run_state.add_route(
+                        source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id
+                    )
+            except Exception as e:
+                route_node_state.status = RouteNodeState.Status.FAILED
+                route_node_state.failed_reason = str(e)
+                yield NodeRunFailedEvent(
+                    error=str(e),
+                    id=node_instance.id,
+                    node_id=next_node_id,
+                    node_type=node_type,
+                    node_data=node_instance.node_data,
+                    route_node_state=route_node_state,
+                    parallel_id=in_parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+                raise e
+
+            # stop traversal once an END node has finished running
+            if (
+                self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
+                == NodeType.END.value
+            ):
+                break
+
+            previous_route_node_state = route_node_state
+
+            # get next node ids
+            edge_mappings = self.graph.edge_mapping.get(next_node_id)
+            if not edge_mappings:
+                break
+
+            if len(edge_mappings) == 1:
+                edge = edge_mappings[0]
+
+                if edge.run_condition:
+                    result = ConditionManager.get_condition_handler(
+                        init_params=self.init_params,
+                        graph=self.graph,
+                        run_condition=edge.run_condition,
+                    ).check(
+                        graph_runtime_state=self.graph_runtime_state,
+                        previous_route_node_state=previous_route_node_state,
+                    )
+
+                    if not result:
+                        break
+
+                next_node_id = edge.target_node_id
+            else:
+                final_node_id = None
+
+                if any(edge.run_condition for edge in edge_mappings):
+                    # if the edges carry run conditions, pick which branch to take based on the condition results
+                    condition_edge_mappings = {}
+                    for edge in edge_mappings:
+                        if edge.run_condition:
+                            run_condition_hash = edge.run_condition.hash
+                            if run_condition_hash not in condition_edge_mappings:
+                                condition_edge_mappings[run_condition_hash] = []
+
+                            condition_edge_mappings[run_condition_hash].append(edge)
+
+                    for _, sub_edge_mappings in condition_edge_mappings.items():
+                        if len(sub_edge_mappings) == 0:
+                            continue
+
+                        edge = sub_edge_mappings[0]
+
+                        result = ConditionManager.get_condition_handler(
+                            init_params=self.init_params,
+                            graph=self.graph,
+                            run_condition=edge.run_condition,
+                        ).check(
+                            graph_runtime_state=self.graph_runtime_state,
+                            previous_route_node_state=previous_route_node_state,
+                        )
+
+                        if not result:
+                            continue
+
+                        if len(sub_edge_mappings) == 1:
+                            final_node_id = edge.target_node_id
+                        else:
+                            parallel_generator = self._run_parallel_branches(
+                                edge_mappings=sub_edge_mappings,
+                                in_parallel_id=in_parallel_id,
+                                parallel_start_node_id=parallel_start_node_id,
+                            )
+
+                            for item in parallel_generator:
+                                if isinstance(item, str):
+                                    final_node_id = item
+                                else:
+                                    yield item
+
+                        break
+
+                    if not final_node_id:
+                        break
+
+                    next_node_id = final_node_id
+                else:
+                    parallel_generator = self._run_parallel_branches(
+                        edge_mappings=edge_mappings,
+                        in_parallel_id=in_parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                    )
+
+                    for item in parallel_generator:
+                        if isinstance(item, str):
+                            final_node_id
= item
+                        else:
+                            yield item
+
+                if not final_node_id:
+                    break
+
+                next_node_id = final_node_id
+
+            if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id:
+                break
+
+    def _run_parallel_branches(
+        self,
+        edge_mappings: list[GraphEdge],
+        in_parallel_id: Optional[str] = None,
+        parallel_start_node_id: Optional[str] = None,
+    ) -> Generator[GraphEngineEvent | str, None, None]:
+        # when the edges carry no run conditions, run all target branches in parallel
+        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
+        if not parallel_id:
+            node_id = edge_mappings[0].target_node_id
+            node_config = self.graph.node_id_config_mapping.get(node_id)
+            if not node_config:
+                raise GraphRunFailedError(
+                    f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches."
+                )
+
+            node_title = node_config.get("data", {}).get("title")
+            raise GraphRunFailedError(
+                f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches."
+            )
+
+        parallel = self.graph.parallel_mapping.get(parallel_id)
+        if not parallel:
+            raise GraphRunFailedError(f"Parallel {parallel_id} not found.")
+
+        # run the parallel branches in worker threads and collect their events through a queue
+        q: queue.Queue = queue.Queue()
+
+        # futures of the submitted branch runs
+        futures = []
+
+        # submit one branch run per qualifying edge
+        for edge in edge_mappings:
+            if (
+                edge.target_node_id not in self.graph.node_parallel_mapping
+                or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
+            ):
+                continue
+
+            futures.append(
+                self.thread_pool.submit(
+                    self._run_parallel_node,
+                    **{
+                        "flask_app": current_app._get_current_object(),  # type: ignore[attr-defined]
+                        "q": q,
+                        "parallel_id": parallel_id,
+                        "parallel_start_node_id": edge.target_node_id,
+                        "parent_parallel_id": in_parallel_id,
+                        "parent_parallel_start_node_id": parallel_start_node_id,
+                    },
+                )
+            )
+
+        succeeded_count = 0
+        while True:
+            try:
+                event = q.get(timeout=1)
+                if event is None:
+                    break
+
+                yield event
+                if event.parallel_id == parallel_id:
+                    if isinstance(event, ParallelBranchRunSucceededEvent):
+                        succeeded_count += 1
+                        if succeeded_count == len(futures):
+                            q.put(None)
+
+                        continue
+                    elif isinstance(event, ParallelBranchRunFailedEvent):
+                        raise GraphRunFailedError(event.error)
+            except queue.Empty:
+                continue
+
+        # wait for all branch futures to complete
+        wait(futures)
+
+        # get final node id
+        final_node_id = parallel.end_to_node_id
+        if final_node_id:
+            yield final_node_id
+
+    def _run_parallel_node(
+        self,
+        flask_app: Flask,
+        q: queue.Queue,
+        parallel_id: str,
+        parallel_start_node_id: str,
+        parent_parallel_id: Optional[str] = None,
+        parent_parallel_start_node_id: Optional[str] = None,
+    ) -> None:
+        """
+        Run a single parallel branch in its own thread, pushing its events onto the queue
+        """
+        with flask_app.app_context():
+            try:
+                q.put(
+                    ParallelBranchRunStartedEvent(
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                    )
+                )
+
+                # run the branch subgraph
+                generator = self._run(
+                    start_node_id=parallel_start_node_id,
+                    in_parallel_id=parallel_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+
+                for item in generator:
+                    q.put(item)
+
+                # trigger parallel branch run success event
+                q.put(
+                    ParallelBranchRunSucceededEvent(
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+
parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) + except GraphRunFailedError as e: + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error, + ) + ) + except Exception as e: + logger.exception("Unknown Error when generating in parallel") + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e), + ) + ) + finally: + db.session.remove() + + def _run_node( + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: + """ + Run node + """ + # trigger node run start event + yield NodeRunStartedEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + predecessor_node_id=node_instance.previous_node_id, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + db.session.close() + + try: + # run node + generator = node_instance.run() + for item in generator: + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id + + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + route_node_state.set_finished(run_result=run_result) + + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or "Unknown error.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) + + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + + # add parallel info to run result metadata + if parallel_id and parallel_start_node_id: + if not run_result.metadata: + run_result.metadata = {} + + 
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + except GenerateTaskStoppedError: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." 
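+            # surface the stop as a failed-node event so stream consumers can finalize this node's state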
+ yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") + raise e + finally: + db.session.close() + + def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + node_id=node_id, variable_key_list=new_key_list, variable_value=value + ) + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + return time.perf_counter() - start_at > max_execution_time + + +class GraphRunFailedError(Exception): + def __init__(self, error: str): + self.error = error diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 5bae27092f..deacbbbbb0 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,9 +1,8 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, GenerateRouteChunk, @@ -19,102 +18,40 @@ class AnswerNode(BaseNode): _node_data_cls = AnswerNodeData _node_type: NodeType = NodeType.ANSWER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data node_data = cast(AnswerNodeData, node_data) # generate routes - generate_routes = self.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) - answer = '' + answer = "" for part in generate_routes: - if part.type == "var": + if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = variable_pool.get(value_selector) + value = self.graph_runtime_state.variable_pool.get(value_selector) + if value: answer += value.markdown else: part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "answer": answer - } - ) + 
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer}) @classmethod - def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: - """ - Extract generate route selectors - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(AnswerNodeData, node_data) - - return cls.extract_generate_route_from_node_data(node_data) - - @classmethod - def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: - """ - Extract generate route from node data - :param node_data: node data object - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors - } - - variable_keys = list(value_selector_mapping.keys()) - - # format answer template - template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) - template_variable_keys = template_parser.variable_keys - - # Take the intersection of variable_keys and template_variable_keys - variable_keys = list(set(variable_keys) & set(template_variable_keys)) - - template = node_data.answer - for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - - generate_routes = [] - for part in template.split('Ω'): - if part: - if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') - value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) - else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) - - return generate_routes - - @classmethod - def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -126,6 +63,6 @@ class AnswerNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." 
+ variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py new file mode 100644 index 0000000000..06050e1549 --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -0,0 +1,164 @@ +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + AnswerStreamGenerateRoute, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class AnswerStreamGeneratorRouter: + @classmethod + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selectors of answer nodes + answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} + for answer_node_id, node_config in node_id_config_mapping.items(): + if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value: + continue + + # get generate route for stream output + generate_route = cls._extract_generate_route_selectors(node_config) + answer_generate_route[answer_node_id] = generate_route + + # fetch answer dependencies + answer_node_ids = list(answer_generate_route.keys()) + answer_dependencies = cls._fetch_answers_dependencies( + answer_node_ids=answer_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping, + ) + + return AnswerStreamGenerateRoute( + answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies + ) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω") + + generate_routes: list[GenerateRouteChunk] = [] + for part in template.split("Ω"): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "") + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector)) + else: + generate_routes.append(TextGenerateRouteChunk(text=part)) + + return generate_routes + + @classmethod + def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node 
config + :return: + """ + node_data = AnswerNodeData(**config.get("data", {})) + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def _is_variable(cls, part, variable_keys): + cleaned_part = part.replace("{{", "").replace("}}", "") + return part.startswith("{{") and cleaned_part in variable_keys + + @classmethod + def _fetch_answers_dependencies( + cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: + """ + Fetch answer dependencies + :param answer_node_ids: answer node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + answer_dependencies: dict[str, list[str]] = {} + for answer_node_id in answer_node_ids: + if answer_dependencies.get(answer_node_id) is None: + answer_dependencies[answer_node_id] = [] + + cls._recursive_fetch_answer_dependencies( + current_node_id=answer_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies, + ) + + return answer_dependencies + + @classmethod + def _recursive_fetch_answer_dependencies( + cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]], + ) -> None: + """ + Recursive fetch answer dependencies + :param current_node_id: current node id + :param answer_node_id: answer node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param answer_dependencies: answer dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") + if source_node_type in ( + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER.value, + ): + answer_dependencies[answer_node_id].append(source_node_id) + else: + cls._recursive_fetch_answer_dependencies( + current_node_id=source_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies, + ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py new file mode 100644 index 0000000000..32dbf436ec --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -0,0 +1,221 @@ +import logging +from collections.abc import Generator +from typing import Optional, cast + +from core.file.file_obj import FileVar +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor +from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk + +logger = logging.getLogger(__name__) + + +class AnswerStreamProcessor(StreamProcessor): + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + 
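        """
+        Track, per answer node, how many chunks of its precomputed generate
+        route have already been streamed out (route_position).
+        """
+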
super().__init__(graph, variable_pool)
+        self.generate_routes = graph.answer_stream_generate_routes
+        self.route_position = {}
+        for answer_node_id in self.generate_routes.answer_generate_route:
+            self.route_position[answer_node_id] = 0
+        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+
+    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
+        for event in generator:
+            if isinstance(event, NodeRunStartedEvent):
+                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
+                    self.reset()
+
+                yield event
+            elif isinstance(event, NodeRunStreamChunkEvent):
+                if event.in_iteration_id:
+                    yield event
+                    continue
+
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ]
+                else:
+                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
+                    self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
+                        stream_out_answer_node_ids
+                    )
+
+                for _ in stream_out_answer_node_ids:
+                    yield event
+            elif isinstance(event, NodeRunSucceededEvent):
+                yield event
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # advance the route positions once all stream events for this node have finished
+                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[answer_node_id] += 1
+
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+
+                # remove unreachable nodes
+                self._remove_unreachable_nodes(event)
+
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
+            else:
+                yield event
+
+    def reset(self) -> None:
+        self.route_position = {}
+        for answer_node_id in self.generate_routes.answer_generate_route:
+            self.route_position[answer_node_id] = 0
+        self.rest_node_ids = self.graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids = {}
+
+    def _generate_stream_outputs_when_node_finished(
+        self, event: NodeRunSucceededEvent
+    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
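+        Flushes the remaining route chunks of every answer node whose dependencies have all finished.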
+        :param event: node run succeeded event
+        :return:
+        """
+        for answer_node_id in self.route_position:
+            # skip unless this is the finished answer node itself, or a pending answer
+            # node whose dependencies have all finished
+            if event.route_node_state.node_id != answer_node_id and (
+                answer_node_id not in self.rest_node_ids
+                or not all(
+                    dep_id not in self.rest_node_ids
+                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
+                )
+            ):
+                continue
+
+            route_position = self.route_position[answer_node_id]
+            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
+
+            for route_chunk in route_chunks:
+                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
+                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                    yield NodeRunStreamChunkEvent(
+                        id=event.id,
+                        node_id=event.node_id,
+                        node_type=event.node_type,
+                        node_data=event.node_data,
+                        chunk_content=route_chunk.text,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
+                        from_variable_selector=[answer_node_id, "answer"],
+                    )
+                else:
+                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                    value_selector = route_chunk.value_selector
+                    if not value_selector:
+                        break
+
+                    value = self.variable_pool.get(value_selector)
+
+                    if value is None:
+                        break
+
+                    text = value.markdown
+
+                    if text:
+                        yield NodeRunStreamChunkEvent(
+                            id=event.id,
+                            node_id=event.node_id,
+                            node_type=event.node_type,
+                            node_data=event.node_data,
+                            chunk_content=text,
+                            from_variable_selector=value_selector,
+                            route_node_state=event.route_node_state,
+                            parallel_id=event.parallel_id,
+                            parallel_start_node_id=event.parallel_start_node_id,
+                        )
+
+                self.route_position[answer_node_id] += 1
+
+    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+        """
+        Get the answer node ids that should stream out the given chunk
+        :param event: node run stream chunk event
+        :return: answer node ids
+        """
+        stream_output_value_selector = event.from_variable_selector
+        if not stream_output_value_selector:
+            return []
+
+        stream_out_answer_node_ids = []
+        for answer_node_id, route_position in self.route_position.items():
+            if answer_node_id not in self.rest_node_ids:
+                continue
+
+            # stream only when every dependency of this answer node has already finished
+            if all(
+                dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
+            ):
+                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
+                    continue
+
+                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
+
+                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
+                    continue
+
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+
+                # only stream when the chunk's variable selector matches the one this route position expects
+                if value_selector != stream_output_value_selector:
+                    continue
+
+                stream_out_answer_node_ids.append(answer_node_id)
+
+        return stream_out_answer_node_ids
+
+    @classmethod
+    def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
+        """
+        Fetch files from variable value
+        :param value: variable value
+        :return:
+        """
+        if not value:
+            return []
+
+        files = []
+        if isinstance(value, list):
+            for item in value:
+                file_var = cls._get_file_var_from_value(item)
+                if file_var:
+                    files.append(file_var)
+        elif isinstance(value, dict):
+            file_var = cls._get_file_var_from_value(value)
+            if file_var:
+                files.append(file_var)
+
+        return files
+
+    @classmethod
+    def
_get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: + """ + Get file var from value + :param value: variable value + :return: + """ + if not value: + return None + + if isinstance(value, dict): + if "__variant" in value and value["__variant"] == FileVar.__name__: + return value + elif isinstance(value, FileVar): + return value.to_dict() + + return None diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py new file mode 100644 index 0000000000..36c3fe180a --- /dev/null +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.graph import Graph + + +class StreamProcessor(ABC): + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + self.graph = graph + self.variable_pool = variable_pool + self.rest_node_ids = graph.node_ids.copy() + + @abstractmethod + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: + raise NotImplementedError + + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + finished_node_id = event.route_node_state.node_id + if finished_node_id not in self.rest_node_ids: + return + + # remove finished node id + self.rest_node_ids.remove(finished_node_id) + + run_result = event.route_node_state.node_run_result + if not run_result: + return + + if run_result.edge_source_handle: + reachable_node_ids = [] + unreachable_first_node_ids = [] + for edge in self.graph.edge_mapping[finished_node_id]: + if ( + edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify + ): + reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + continue + else: + unreachable_first_node_ids.append(edge.target_node_id) + + for node_id in unreachable_first_node_ids: + self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) + + def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: + node_ids = [] + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id == self.graph.root_node_id: + continue + + node_ids.append(edge.target_node_id) + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + return node_ids + + def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: + """ + remove target node ids until merge + """ + if node_id not in self.rest_node_ids: + return + + self.rest_node_ids.remove(node_id) + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id in reachable_node_ids: + continue + + self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids) diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 9effbbbe67..e356e7fd70 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,6 @@ +from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -8,27 +9,56 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. 
""" - answer: str + + answer: str = Field(..., description="answer template string") class GenerateRouteChunk(BaseModel): """ Generate Route Chunk. """ - type: str + + class ChunkType(Enum): + VAR = "var" + TEXT = "text" + + type: ChunkType = Field(..., description="generate route chunk type") class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ - type: str = "var" - value_selector: list[str] + + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR + """generate route chunk type""" + value_selector: list[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ - type: str = "text" - text: str + + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT + """generate route chunk type""" + text: str = Field(..., description="text") + + +class AnswerNodeDoubleLink(BaseModel): + node_id: str = Field(..., description="node id") + source_node_ids: list[str] = Field(..., description="source node ids") + target_node_ids: list[str] = Field(..., description="target node ids") + + +class AnswerStreamGenerateRoute(BaseModel): + """ + AnswerStreamGenerateRoute entity + """ + + answer_dependencies: dict[str, list[str]] = Field( + ..., description="answer dependencies (answer node id -> dependent answer node ids)" + ) + answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( + ..., description="answer generate route (answer node id -> generate route chunks)" + ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3d9cf52771..7bfe45a13c 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,142 +1,99 @@ from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from enum import Enum +from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from models import WorkflowNodeExecutionStatus - - -class UserFrom(Enum): - """ - User from - """ - ACCOUNT = "account" - END_USER = "end-user" - - @classmethod - def value_of(cls, value: str) -> "UserFrom": - """ - Value of - :param value: value - :return: - """ - for item in cls: - if item.value == value: - return item - raise ValueError(f"Invalid value: {value}") +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent, RunEvent class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - tenant_id: str - app_id: str - workflow_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + 
thread_pool_id: Optional[str] = None, + ) -> None: + self.id = id + self.tenant_id = graph_init_params.tenant_id + self.app_id = graph_init_params.app_id + self.workflow_type = graph_init_params.workflow_type + self.workflow_id = graph_init_params.workflow_id + self.graph_config = graph_init_params.graph_config + self.user_id = graph_init_params.user_id + self.user_from = graph_init_params.user_from + self.invoke_from = graph_init_params.invoke_from + self.workflow_call_depth = graph_init_params.call_depth + self.graph = graph + self.graph_runtime_state = graph_runtime_state + self.previous_node_id = previous_node_id + self.thread_pool_id = thread_pool_id - node_id: str - node_data: BaseNodeData - node_run_result: Optional[NodeRunResult] = None - - callbacks: Sequence[WorkflowCallback] - - is_answer_previous_node: bool = False - - def __init__(self, tenant_id: str, - app_id: str, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - config: Mapping[str, Any], - callbacks: Sequence[WorkflowCallback] | None = None, - workflow_call_depth: int = 0) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - self.workflow_id = workflow_id - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth - - # TODO: May need to check if key exists. - self.node_id = config["id"] - if not self.node_id: + node_id = config.get("id") + if not node_id: raise ValueError("Node ID is required.") + self.node_id = node_id self.node_data = self._node_data_cls(**config.get("data", {})) - self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ raise NotImplementedError - def run(self, variable_pool: VariablePool) -> NodeRunResult: + def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run node entry - :param variable_pool: variable pool :return: """ - try: - result = self._run( - variable_pool=variable_pool - ) - self.node_run_result = result - return result - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + result = self._run() - def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: - """ - Publish text chunk - :param text: chunk text - :param value_selector: value selector - :return: - """ - if self.callbacks: - for callback in self.callbacks: - callback.on_node_text_chunk( - node_id=self.node_id, - text=text, - metadata={ - "node_type": self.node_type, - "is_answer_previous_node": self.is_answer_previous_node, - "value_selector": value_selector - } - ) + if isinstance(result, NodeRunResult): + yield RunCompletedEvent(run_result=result) + else: + yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict): + def extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], config: dict + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config :param config: node config :return: """ + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping(node_data) + return 
cls._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, node_id=node_id, node_data=node_data + ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -158,38 +115,3 @@ class BaseNode(ABC): :return: """ return self._node_type - -class BaseIterationNode(BaseNode): - @abstractmethod - def _run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node - :param variable_pool: variable pool - :return: - """ - raise NotImplementedError - - def run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node entry - :param variable_pool: variable pool - :return: - """ - return self._run(variable_pool=variable_pool) - - def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - return self._get_next_iteration(variable_pool, state) - - @abstractmethod - def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - raise NotImplementedError diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 335991ae87..a07ba2f740 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,12 +1,12 @@ -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -25,21 +25,20 @@ class CodeNode(BaseNode): """ code_language = CodeLanguage.PYTHON3 if filters: - code_language = (filters.get("code_language", CodeLanguage.PYTHON3)) + code_language = filters.get("code_language", CodeLanguage.PYTHON3) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers - if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) return code_provider.get_default_config() - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run code - :param variable_pool: variable pool :return: 
""" - node_data = cast(CodeNodeData, self.node_data) + node_data = self.node_data + node_data = cast(CodeNodeData, node_data) # Get code language code_language = node_data.code_language @@ -49,7 +48,7 @@ class CodeNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable] = value # Run code @@ -62,18 +61,10 @@ class CodeNode(BaseNode): # Transform result result = self._transform_result(result, node_data.outputs) - except (CodeExecutionException, ValueError) as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + except (CodeExecutionError, ValueError) as e: + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=result - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) def _check_string(self, value: str, variable: str) -> str: """ @@ -87,12 +78,14 @@ class CodeNode(BaseNode): return None else: raise ValueError(f"Output variable `{variable}` must be a string") - - if len(value) > dify_config.CODE_MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be' - f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters') - return value.replace('\x00', '') + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + raise ValueError( + f"The length of output variable `{variable}` must be" + f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + ) + + return value.replace("\x00", "") def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ @@ -108,20 +101,24 @@ class CodeNode(BaseNode): raise ValueError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range,' - f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') + raise ValueError( + f"Output variable `{variable}` is out of range," + f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + ) if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision,' - f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') + if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError( + f"Output variable `{variable}` has too high precision," + f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." 
+ ) return value - def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], - prefix: str = '', - depth: int = 1) -> dict: + def _transform_result( + self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 + ) -> dict: """ Transform result :param result: result @@ -139,185 +136,190 @@ class CodeNode(BaseNode): self._transform_result( result=output_value, output_schema=None, - prefix=f'{prefix}.{output_name}' if prefix else output_name, - depth=depth + 1 + prefix=f"{prefix}.{output_name}" if prefix else output_name, + depth=depth + 1, ) elif isinstance(output_value, int | float): self._check_number( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, str): self._check_string( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, list): first_element = output_value[0] if len(output_value) > 0 else None if first_element is not None: - if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value): + if isinstance(first_element, int | float) and all( + value is None or isinstance(value, int | float) for value in output_value + ): for i, value in enumerate(output_value): self._check_number( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value): + elif isinstance(first_element, str) and all( + value is None or isinstance(value, str) for value in output_value + ): for i, value in enumerate(output_value): self._check_string( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value): + elif isinstance(first_element, dict) and all( + value is None or isinstance(value, dict) for value in output_value + ): for i, value in enumerate(output_value): if value is not None: self._transform_result( result=value, output_schema=None, - prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", + depth=depth + 1, ) else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') + raise ValueError( + f"Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type." + ) elif isinstance(output_value, type(None)): pass else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') - + raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") + return result parameters_validated = {} for output_name, output_config in output_schema.items(): - dot = '.' if prefix else '' + dot = "." 
if prefix else "" if output_name not in result: - raise ValueError(f'Output {prefix}{dot}{output_name} is missing.') - - if output_config.type == 'object': + raise ValueError(f"Output {prefix}{dot}{output_name} is missing.") + + if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead." ) else: transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, - prefix=f'{prefix}.{output_name}', - depth=depth + 1 + prefix=f"{prefix}.{output_name}", + depth=depth + 1, ) - elif output_config.type == 'number': + elif output_config.type == "number": # check if number available transformed_result[output_name] = self._check_number( - value=result[output_name], - variable=f'{prefix}{dot}{output_name}' + value=result[output_name], variable=f"{prefix}{dot}{output_name}" ) - elif output_config.type == 'string': + elif output_config.type == "string": # check if string available transformed_result[output_name] = self._check_string( value=result[output_name], - variable=f'{prefix}{dot}{output_name}', + variable=f"{prefix}{dot}{output_name}", ) - elif output_config.type == 'array[number]': + elif output_config.type == "array[number]": # check if array of number available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_number( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[string]': + elif output_config.type == "array[string]": # check if array of string available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." 
                        )

                    transformed_result[output_name] = [
-                        self._check_string(
-                            value=value,
-                            variable=f'{prefix}{dot}{output_name}[{i}]'
-                        )
+                        self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
                         for i, value in enumerate(result[output_name])
                     ]
-            elif output_config.type == 'array[object]':
+            elif output_config.type == "array[object]":
                 # check if array of object available
                 if not isinstance(result[output_name], list):
                     if isinstance(result[output_name], type(None)):
                         transformed_result[output_name] = None
                     else:
                         raise ValueError(
-                            f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
+                            f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead."
                         )
                 else:
                     if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
                         raise ValueError(
-                            f'The length of output variable `{prefix}{dot}{output_name}` must be'
-                            f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.'
+                            f"The length of output variable `{prefix}{dot}{output_name}` must be"
+                            f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
                         )
-
+
                     for i, value in enumerate(result[output_name]):
                         if not isinstance(value, dict):
                             if isinstance(value, type(None)):
                                 pass
                             else:
                                 raise ValueError(
-                                    f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.'
+                                    f"Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}."
                                 )
 
                     transformed_result[output_name] = [
-                        None if value is None else self._transform_result(
+                        None
+                        if value is None
+                        else self._transform_result(
                             result=value,
                             output_schema=output_config.children,
-                            prefix=f'{prefix}{dot}{output_name}[{i}]',
-                            depth=depth + 1
+                            prefix=f"{prefix}{dot}{output_name}[{i}]",
+                            depth=depth + 1,
                         )
                         for i, value in enumerate(result[output_name])
                     ]
             else:
-                raise ValueError(f'Output type {output_config.type} is not supported.')
-
+                raise ValueError(f"Output type {output_config.type} is not supported.")
+
             parameters_validated[output_name] = True
 
         # check if all output parameters are validated
         if len(parameters_validated) != len(result):
-            raise ValueError('Not all output parameters are validated.')
+            raise ValueError("Not all output parameters are validated.")
 
         return transformed_result
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
         return {
-            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+            node_id + "." + variable_selector.variable: variable_selector.value_selector
+            for variable_selector in node_data.variables
         }
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index c0701ecccd..5eb0e0f63f 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -11,9 +11,10 @@ class CodeNodeData(BaseNodeData):
     """
     Code Node Data.
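+
+    Example (illustrative only): an `outputs` mapping such as
+    `{"answer": Output(type="string"), "scores": Output(type="array[number]")}`
+    declares one string output and one array-of-number output; nested object
+    outputs are described via `children`.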
""" + class Output(BaseModel): - type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] - children: Optional[dict[str, 'Output']] = None + type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + children: Optional[dict[str, "Output"]] = None class Dependency(BaseModel): name: str @@ -23,4 +24,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 440dfa2f27..7b78d67be8 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,8 +1,7 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.entities import EndNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,10 +11,9 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData _node_type = NodeType.END - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data @@ -24,62 +22,19 @@ class EndNode(BaseNode): outputs = {} for variable_selector in output_variables: - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) outputs[variable_selector.variable] = value - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs) @classmethod - def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]: - """ - Extract generate nodes - :param graph: graph - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(EndNodeData, node_data) - - return cls.extract_generate_nodes_from_node_data(graph, node_data) - - @classmethod - def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]: - """ - Extract generate nodes from node data - :param graph: graph - :param node_data: node data object - :return: - """ - nodes = graph.get('nodes', []) - node_mapping = {node.get('id'): node for node in nodes} - - variable_selectors = node_data.outputs - - generate_nodes = [] - for variable_selector in variable_selectors: - if not variable_selector.value_selector: - continue - - node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_mapping: - node = node_mapping[node_id] - node_type = node.get('data', {}).get('type') - if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text': - generate_nodes.append(node_id) - - # remove duplicates - generate_nodes = list(set(generate_nodes)) - - return generate_nodes - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def 
_extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py new file mode 100644 index 0000000000..30ce8fe018 --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -0,0 +1,151 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam + + +class EndStreamGeneratorRouter: + @classmethod + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str], + ) -> EndStreamParam: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selector of end nodes + end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} + for end_node_id, node_config in node_id_config_mapping.items(): + if not node_config.get("data", {}).get("type") == NodeType.END.value: + continue + + # skip end node in parallel + if end_node_id in node_parallel_mapping: + continue + + # get generate route for stream output + stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config) + end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors + + # fetch end dependencies + end_node_ids = list(end_stream_variable_selectors_mapping.keys()) + end_dependencies = cls._fetch_ends_dependencies( + end_node_ids=end_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping, + ) + + return EndStreamParam( + end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, + end_dependencies=end_dependencies, + ) + + @classmethod + def extract_stream_variable_selector_from_node_data( + cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData + ) -> list[list[str]]: + """ + Extract stream variable selector from node data + :param node_id_config_mapping: node id config mapping + :param node_data: node data object + :return: + """ + variable_selectors = node_data.outputs + + value_selectors = [] + for variable_selector in variable_selectors: + if not variable_selector.value_selector: + continue + + node_id = variable_selector.value_selector[0] + if node_id != "sys" and node_id in node_id_config_mapping: + node = node_id_config_mapping[node_id] + node_type = node.get("data", {}).get("type") + if ( + variable_selector.value_selector not in value_selectors + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == "text" + ): + value_selectors.append(variable_selector.value_selector) + + return value_selectors + + @classmethod + def _extract_stream_variable_selector( + cls, node_id_config_mapping: dict[str, dict], config: dict + ) -> list[list[str]]: + """ + Extract stream variable selector from node config + :param node_id_config_mapping: node id config mapping + :param config: node config + :return: + """ + node_data = EndNodeData(**config.get("data", {})) + return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) + + @classmethod + def _fetch_ends_dependencies( + cls, + end_node_ids: list[str], + 
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: + """ + Fetch end dependencies + :param end_node_ids: end node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + end_dependencies: dict[str, list[str]] = {} + for end_node_id in end_node_ids: + if end_dependencies.get(end_node_id) is None: + end_dependencies[end_node_id] = [] + + cls._recursive_fetch_end_dependencies( + current_node_id=end_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies, + ) + + return end_dependencies + + @classmethod + def _recursive_fetch_end_dependencies( + cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], + # type: ignore[name-defined] + end_dependencies: dict[str, list[str]], + ) -> None: + """ + Recursive fetch end dependencies + :param current_node_id: current node id + :param end_node_id: end node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param end_dependencies: end dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") + if source_node_type in ( + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, + ): + end_dependencies[end_node_id].append(source_node_id) + else: + cls._recursive_fetch_end_dependencies( + current_node_id=source_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies, + ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py new file mode 100644 index 0000000000..0366d7965d --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -0,0 +1,187 @@ +import logging +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor + +logger = logging.getLogger(__name__) + + +class EndStreamProcessor(StreamProcessor): + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.end_stream_param = graph.end_stream_param + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + self.has_outputed = False + self.outputed_node_ids = set() + + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if 
event.in_iteration_id: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = "\n" + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_end_node_ids + ) + + if stream_out_end_node_ids: + if self.has_outputed and event.node_id not in self.outputed_node_ids: + event.chunk_content = "\n" + event.chunk_content + + self.outputed_node_ids.add(event.node_id) + self.has_outputed = True + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[end_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. 
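+
+        In short (a reading of the code below, not a spec): when a node
+        finishes, every END node whose dependencies have all completed flushes
+        its not-yet-streamed variable selectors, starting at route_position,
+        as NodeRunStreamChunkEvent items.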
+ :param event: node run succeeded event + :return: + """ + for end_node_id, position in self.route_position.items(): + # all depends on end node id not in rest node ids + if event.route_node_state.node_id != end_node_id and ( + end_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id] + ) + ): + continue + + route_position = self.route_position[end_node_id] + + position = 0 + value_selectors = [] + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position >= route_position: + value_selectors.append(current_value_selectors) + + position += 1 + + for value_selector in value_selectors: + if not value_selector: + continue + + value = self.variable_pool.get(value_selector) + + if value is None: + break + + text = value.markdown + + if text: + current_node_id = value_selector[0] + if self.has_outputed and current_node_id not in self.outputed_node_ids: + text = "\n" + text + + self.outputed_node_ids.add(current_node_id) + self.has_outputed = True + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[end_node_id] += 1 + + def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + stream_out_end_node_ids = [] + for end_node_id, route_position in self.route_position.items(): + if end_node_id not in self.rest_node_ids: + continue + + # all depends on end node id not in rest node ids + if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): + continue + + position = 0 + value_selector = None + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position == route_position: + value_selector = current_value_selectors + break + + position += 1 + + if not value_selector: + continue + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + continue + + stream_out_end_node_ids.append(end_node_id) + + return stream_out_end_node_ids diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index ad4fc8f04f..c3270ac22a 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,3 +1,5 @@ +from pydantic import BaseModel, Field + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -6,4 +8,18 @@ class EndNodeData(BaseNodeData): """ END Node Data. 
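+
+    `outputs` lists the variable selectors this END node exposes; e.g. a
+    selector like ["llm_node_id", "text"] (identifier illustrative) surfaces
+    that LLM node's streamed text.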
""" + outputs: list[VariableSelector] + + +class EndStreamParam(BaseModel): + """ + EndStreamParam entity + """ + + end_dependencies: dict[str, list[str]] = Field( + ..., description="end dependencies (end node id -> dependent node ids)" + ) + end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( + ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" + ) diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event.py new file mode 100644 index 0000000000..276c13a6d4 --- /dev/null +++ b/api/core/workflow/nodes/event.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult + + +class RunCompletedEvent(BaseModel): + run_result: NodeRunResult = Field(..., description="run result") + + +class RunStreamChunkEvent(BaseModel): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") + + +class RunRetrieverResourceEvent(BaseModel): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 90d644e0e2..66dd1f2dc6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -5,45 +5,41 @@ from pydantic import BaseModel, ValidationInfo, field_validator from configs import dify_config from core.workflow.entities.base_node_data_entities import BaseNodeData -MAX_CONNECT_TIMEOUT = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT -MAX_READ_TIMEOUT = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT -MAX_WRITE_TIMEOUT = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT - class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, 'basic', 'bearer', 'custom'] + type: Literal[None, "basic", "bearer", "custom"] api_key: Union[None, str] = None header: Union[None, str] = None class HttpRequestNodeAuthorization(BaseModel): - type: Literal['no-auth', 'api-key'] + type: Literal["no-auth", "api-key"] config: Optional[HttpRequestNodeAuthorizationConfig] = None - @field_validator('config', mode='before') + @field_validator("config", mode="before") @classmethod def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): """ Check config, if type is no-auth, config should be None, otherwise it should be a dict. """ - if values.data['type'] == 'no-auth': + if values.data["type"] == "no-auth": return None else: if not v or not isinstance(v, dict): - raise ValueError('config should be a dict') + raise ValueError("config should be a dict") return v class HttpRequestNodeBody(BaseModel): - type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"] data: Union[None, str] = None class HttpRequestNodeTimeout(BaseModel): - connect: int = MAX_CONNECT_TIMEOUT - read: int = MAX_READ_TIMEOUT - write: int = MAX_WRITE_TIMEOUT + connect: int = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT + read: int = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT + write: int = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT class HttpRequestNodeData(BaseNodeData): @@ -51,7 +47,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. 
""" - method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] + method: Literal["get", "post", "put", "patch", "delete", "head"] url: str authorization: HttpRequestNodeAuthorization headers: str diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index d16bff58bd..49102dc3ab 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -33,12 +33,12 @@ class HttpExecutorResponse: check if response is file """ content_type = self.get_content_type() - file_content_types = ['image', 'audio', 'video'] + file_content_types = ["image", "audio", "video"] return any(v in content_type for v in file_content_types) def get_content_type(self) -> str: - return self.headers.get('content-type', '') + return self.headers.get("content-type", "") def extract_file(self) -> tuple[str, bytes]: """ @@ -47,28 +47,28 @@ class HttpExecutorResponse: if self.is_file: return self.get_content_type(), self.body - return '', b'' + return "", b"" @property def content(self) -> str: if isinstance(self.response, httpx.Response): return self.response.text else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def body(self) -> bytes: if isinstance(self.response, httpx.Response): return self.response.content else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def status_code(self) -> int: if isinstance(self.response, httpx.Response): return self.response.status_code else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def size(self) -> int: @@ -77,11 +77,11 @@ class HttpExecutorResponse: @property def readable_size(self) -> str: if self.size < 1024: - return f'{self.size} bytes' + return f"{self.size} bytes" elif self.size < 1024 * 1024: - return f'{(self.size / 1024):.2f} KB' + return f"{(self.size / 1024):.2f} KB" else: - return f'{(self.size / 1024 / 1024):.2f} MB' + return f"{(self.size / 1024 / 1024):.2f} MB" class HttpExecutor: @@ -120,7 +120,7 @@ class HttpExecutor: """ check if body is json """ - if body and body.type == 'json' and body.data: + if body and body.type == "json" and body.data: try: json.loads(body.data) return True @@ -134,15 +134,15 @@ class HttpExecutor: """ Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` """ - kv_paris = convert_text.split('\n') + kv_paris = convert_text.split("\n") result = {} for kv in kv_paris: if not kv.strip(): continue - kv = kv.split(':', maxsplit=1) + kv = kv.split(":", maxsplit=1) if len(kv) == 1: - k, v = kv[0], '' + k, v = kv[0], "" else: k, v = kv result[k.strip()] = v @@ -166,31 +166,31 @@ class HttpExecutor: # check if it's a valid JSON is_valid_json = self._is_json_body(node_data.body) - body_data = node_data.body.data or '' + body_data = node_data.body.data or "" if body_data: body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - content_type_is_set = any(key.lower() == 'content-type' for key in self.headers) - if node_data.body.type == 'json' and not content_type_is_set: - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set: - self.headers['Content-Type'] = 
'application/x-www-form-urlencoded' + content_type_is_set = any(key.lower() == "content-type" for key in self.headers) + if node_data.body.type == "json" and not content_type_is_set: + self.headers["Content-Type"] = "application/json" + elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: + self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: body = self._to_dict(body_data) - if node_data.body.type == 'form-data': - self.files = {k: ('', v) for k, v in body.items()} - random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f'----WebKitFormBoundary{random_str(16)}' + if node_data.body.type == "form-data": + self.files = {k: ("", v) for k, v in body.items()} + random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) + self.boundary = f"----WebKitFormBoundary{random_str(16)}" - self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}' + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ['json', 'raw-text']: + elif node_data.body.type in ["json", "raw-text"]: self.body = body_data - elif node_data.body.type == 'none': - self.body = '' + elif node_data.body.type == "none": + self.body = "" self.variable_selectors = ( server_url_variable_selectors @@ -202,23 +202,23 @@ class HttpExecutor: def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) headers = deepcopy(self.headers) or {} - if self.authorization.type == 'api-key': + if self.authorization.type == "api-key": if self.authorization.config is None: - raise ValueError('self.authorization config is required') + raise ValueError("self.authorization config is required") if authorization.config is None: - raise ValueError('authorization config is required') + raise ValueError("authorization config is required") if self.authorization.config.api_key is None: - raise ValueError('api_key is required') + raise ValueError("api_key is required") if not authorization.config.header: - authorization.config.header = 'Authorization' + authorization.config.header = "Authorization" - if self.authorization.config.type == 'bearer': - headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' - elif self.authorization.config.type == 'basic': - headers[authorization.config.header] = f'Basic {authorization.config.api_key}' - elif self.authorization.config.type == 'custom': + if self.authorization.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.authorization.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.authorization.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key return headers @@ -230,10 +230,13 @@ class HttpExecutor: if isinstance(response, httpx.Response): executor_response = HttpExecutorResponse(response) else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") - threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) 
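+        # Which limit applies depends on the response kind (see is_file above):
+        # image/audio/video content types are checked against
+        # HTTP_REQUEST_NODE_MAX_BINARY_SIZE, all other payloads against
+        # HTTP_REQUEST_NODE_MAX_TEXT_SIZE.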
if executor_response.size > threshold_size: raise ValueError( f'{"File" if executor_response.is_file else "Text"} size is too large,' @@ -248,17 +251,17 @@ class HttpExecutor: do http request depending on api bundle """ kwargs = { - 'url': self.server_url, - 'headers': headers, - 'params': self.params, - 'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write), - 'follow_redirects': True, + "url": self.server_url, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, } - if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'): + if self.method in ("get", "head", "post", "put", "delete", "patch"): response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") return response def invoke(self) -> HttpExecutorResponse: @@ -280,15 +283,15 @@ class HttpExecutor: """ server_url = self.server_url if self.params: - server_url += f'?{urlencode(self.params)}' + server_url += f"?{urlencode(self.params)}" - raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" headers = self._assembling_headers() for k, v in headers.items(): # get authorization header - if self.authorization.type == 'api-key': - authorization_header = 'Authorization' + if self.authorization.type == "api-key": + authorization_header = "Authorization" if self.authorization.config and self.authorization.config.header: authorization_header = self.authorization.config.header @@ -296,21 +299,21 @@ class HttpExecutor: raw_request += f'{k}: {"*" * len(v)}\n' continue - raw_request += f'{k}: {v}\n' + raw_request += f"{k}: {v}\n" - raw_request += '\n' + raw_request += "\n" # if files, use multipart/form-data with boundary if self.files: boundary = self.boundary - raw_request += f'--{boundary}' + raw_request += f"--{boundary}" for k, v in self.files.items(): raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f'{v[1]}\n' - raw_request += f'--{boundary}' - raw_request += '--' + raw_request += f"{v[1]}\n" + raw_request += f"--{boundary}" + raw_request += "--" else: - raw_request += self.body or '' + raw_request += self.body or "" return raw_request @@ -328,9 +331,9 @@ class HttpExecutor: for variable_selector in variable_selectors: variable = variable_pool.get_any(variable_selector.value_selector) if variable is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace('\n', '\\n') + value = variable.replace('"', '\\"').replace("\n", "\\n") else: value = variable variable_value_mapping[variable_selector.variable] = value diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 1facf8a4f4..cd40819126 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,19 +1,16 @@ import logging +from collections.abc import Mapping, Sequence from mimetypes import guess_extension from os import path -from typing import cast +from typing import Any, cast +from configs import dify_config from core.app.segments import parser from core.file.file_obj import 
FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import ( - MAX_CONNECT_TIMEOUT, - MAX_READ_TIMEOUT, - MAX_WRITE_TIMEOUT, HttpRequestNodeData, HttpRequestNodeTimeout, ) @@ -21,9 +18,9 @@ from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExe from models.workflow import WorkflowNodeExecutionStatus HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - connect=min(10, MAX_CONNECT_TIMEOUT), - read=min(60, MAX_READ_TIMEOUT), - write=min(20, MAX_WRITE_TIMEOUT), + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, ) @@ -34,33 +31,37 @@ class HttpRequestNode(BaseNode): @classmethod def get_default_config(cls, filters: dict | None = None) -> dict: return { - 'type': 'http-request', - 'config': { - 'method': 'get', - 'authorization': { - 'type': 'no-auth', + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", }, - 'body': {'type': 'none'}, - 'timeout': { + "body": {"type": "none"}, + "timeout": { **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': MAX_CONNECT_TIMEOUT, - 'max_read_timeout': MAX_READ_TIMEOUT, - 'max_write_timeout': MAX_WRITE_TIMEOUT, + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) # TODO: Switch to use segment directly if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text + node_data.authorization.config.api_key = parser.convert_template( + template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool + ).text # init http executor http_executor = None try: http_executor = HttpExecutor( - node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool + node_data=node_data, + timeout=self._get_request_timeout(node_data), + variable_pool=self.graph_runtime_state.variable_pool, ) # invoke http executor @@ -69,7 +70,7 @@ class HttpRequestNode(BaseNode): process_data = {} if http_executor: process_data = { - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), } return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -82,37 +83,38 @@ class HttpRequestNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ - 'status_code': response.status_code, - 'body': response.content if not files else '', - 'headers': response.headers, - 'files': files, + "status_code": response.status_code, + "body": response.content if not files else "", + "headers": response.headers, + "files": files, }, process_data={ - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), }, ) - def 
_get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: timeout = node_data.timeout if timeout is None: return HTTP_REQUEST_DEFAULT_TIMEOUT timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT) timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.read = min(timeout.read, MAX_READ_TIMEOUT) timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT) return timeout @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = cast(HttpRequestNodeData, node_data) try: http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) @@ -120,11 +122,11 @@ class HttpRequestNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: - logging.exception(f'Failed to extract variable selector to variable mapping: {e}') + logging.exception(f"Failed to extract variable selector to variable mapping: {e}") return {} def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: @@ -138,7 +140,7 @@ class HttpRequestNode(BaseNode): # extract filename from url filename = path.basename(url) # extract extension if possible - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index bc6dce0d3b..54c1081fd3 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -3,20 +3,7 @@ from typing import Literal, Optional from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", - # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" - ] - value: Optional[str] = None +from core.workflow.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): @@ -28,6 +15,7 @@ class IfElseNodeData(BaseNodeData): """ Case entity representing a single logical condition group """ + case_id: str logical_operator: Literal["and", "or"] conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index c6d235627f..5b4737c6e5 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,12 +1,10 @@ -from collections.abc import Sequence -from typing 
import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData -from core.workflow.utils.variable_template_parser import VariableTemplateParser +from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus @@ -14,31 +12,30 @@ class IfElseNode(BaseNode): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data node_data = cast(IfElseNodeData, node_data) - node_inputs = { - "conditions": [] - } + node_inputs: dict[str, list] = {"conditions": []} - process_datas = { - "condition_results": [] - } + process_datas: dict[str, list] = {"condition_results": []} input_conditions = [] final_result = False selected_case_id = None + condition_processor = ConditionProcessor() try: # Check if the new cases structure is used if node_data.cases: for case in node_data.cases: - input_conditions, group_result = self.process_conditions(variable_pool, case.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions + ) + # Apply the logical operator for the current case final_result = all(group_result) if case.logical_operator == "and" else any(group_result) @@ -57,28 +54,23 @@ class IfElseNode(BaseNode): else: # Fallback to old structure if cases are not defined - input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions + ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) selected_case_id = "true" if final_result else "false" process_datas["condition_results"].append( - { - "group": "default", - "results": group_result, - "final_result": final_result - } + {"group": "default", "results": group_result, "final_result": final_result} ) node_inputs["conditions"] = input_conditions except Exception as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=node_inputs, - process_data=process_datas, - error=str(e) + status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e) ) outputs = {"result": final_result, "selected_case_id": selected_case_id} @@ -88,369 +80,19 @@ class IfElseNode(BaseNode): inputs=node_inputs, process_data=process_datas, edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default' - outputs=outputs + outputs=outputs, ) return data - def evaluate_condition( - self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str - ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator - - :return: bool - """ - if comparison_operator == 
"contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") - - def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): - input_conditions = [] - group_result = [] - - for condition in conditions: - actual_variable = variable_pool.get_any(condition.variable_selector) - - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - expected_value = variable_template_parser.extract_variable_selectors() - variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - else: - expected_value = condition.value - else: - expected_value = None - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_variable, - "expected_value": expected_value, - "comparison_operator": comparison_operator - } - ) - - result = self.evaluate_condition(actual_variable, expected_value, comparison_operator) - group_result.append(result) - - return input_conditions, group_result - - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value not in actual_value: - return False - return True - - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True - - if not isinstance(actual_value, str | list): - raise 
ValueError('Invalid actual value type: string or array') - - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True - return False - - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True - return False - - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value != expected_value: - return False - return True - - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - 
- if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True - return False - - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True - return False - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 177b47b951..3c2c189159 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,20 +1,31 @@ from typing import Any, Optional -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. 
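+
+    `iterator_selector` points at the list variable to iterate over;
+    `output_selector` names the value collected from each pass.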
""" - parent_loop_id: Optional[str] = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector + + parent_loop_id: Optional[str] = None # redundant field, not used currently + iterator_selector: list[str] # variable selector + output_selector: list[str] # output selector + + +class IterationStartNodeData(BaseNodeData): + """ + Iteration Start Node Data. + """ + + pass + class IterationState(BaseIterationState): """ Iteration State. """ + outputs: list[Any] = None current_output: Optional[Any] = None @@ -22,6 +33,7 @@ class IterationState(BaseIterationState): """ Data. """ + iterator_length: int def get_last_output(self) -> Optional[Any]: @@ -31,9 +43,9 @@ class IterationState(BaseIterationState): if self.outputs: return self.outputs[-1] return None - + def get_current_output(self) -> Optional[Any]: """ Get current output. """ - return self.current_output \ No newline at end of file + return self.current_output diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 54dfe8b7f4..77b14e36a1 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,124 +1,335 @@ -from typing import cast +import logging +from collections.abc import Generator, Mapping, Sequence +from datetime import datetime, timezone +from typing import Any, cast +from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode -from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.graph_engine.entities.event import ( + BaseGraphEvent, + BaseNodeEvent, + BaseParallelBranchEvent, + GraphRunFailedEvent, + InNodeEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) -class IterationNode(BaseIterationNode): + +class IterationNode(BaseNode): """ Iteration Node. """ + _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self, variable_pool: VariablePool) -> BaseIterationState: + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run the node. 
""" self.node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(self.node_data.iterator_selector) + iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - if not isinstance(iterator, list): - raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") + if not iterator_list_segment: + raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") - state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ - 'iterator_selector': iterator - }, outputs=[], metadata=IterationState.MetaData( - iterator_length=len(iterator) if iterator is not None else 0 - )) - - self._set_current_iteration_variable(variable_pool, state) - return state + iterator_list_value = iterator_list_segment.to_object() - def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - # resolve current output - self._resolve_current_output(variable_pool, state) - # move to next iteration - self._next_iteration(variable_pool, state) + if not isinstance(iterator_list_value, list): + raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - node_data = cast(IterationNodeData, self.node_data) - if self._reached_iteration_limit(variable_pool, state): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'output': jsonable_encoder(state.outputs) - } + inputs = {"iterator_selector": iterator_list_value} + + graph_config = self.graph_config + + if not self.node_data.start_node_id: + raise ValueError(f"field start_node_id in iteration {self.node_id} not found") + + root_node_id = self.node_data.start_node_id + + # init graph + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) + + if not iteration_graph: + raise ValueError("iteration graph not found") + + leaf_node_ids = iteration_graph.get_leaf_node_ids() + iteration_leaf_node_ids = [] + for leaf_node_id in leaf_node_ids: + node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) + if not node_config: + continue + + leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") + if not leaf_node_iteration_id: + continue + + if leaf_node_iteration_id != self.node_id: + continue + + iteration_leaf_node_ids.append(leaf_node_id) + + # add condition of end nodes to root node + iteration_graph.add_extra_edge( + source_node_id=leaf_node_id, + target_node_id=root_node_id, + run_condition=RunCondition( + type="condition", + conditions=[ + Condition( + variable_selector=[self.node_id, "index"], + comparison_operator="<", + value=str(len(iterator_list_value)), + ) + ], + ), ) - - return node_data.start_node_id - - def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): - """ - Set current iteration variable. 
- :variable_pool: variable pool - """ - node_data = cast(IterationNodeData, self.node_data) - variable_pool.add((self.node_id, 'index'), state.index) - # get the iterator value - iterator = variable_pool.get_any(node_data.iterator_selector) + variable_pool = self.graph_runtime_state.variable_pool - if iterator is None or not isinstance(iterator, list): - return - - if state.index < len(iterator): - variable_pool.add((self.node_id, 'item'), iterator[state.index]) + # append iteration variable (item, index) to variable pool + variable_pool.add([self.node_id, "index"], 0) + variable_pool.add([self.node_id, "item"], iterator_list_value[0]) - def _next_iteration(self, variable_pool: VariablePool, state: IterationState): - """ - Move to next iteration. - :param variable_pool: variable pool - """ - state.index += 1 - self._set_current_iteration_variable(variable_pool, state) + # init graph engine + from core.workflow.graph_engine.graph_engine import GraphEngine - def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): - """ - Check if iteration limit is reached. - :return: True if iteration limit is reached, False otherwise - """ - node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(node_data.iterator_selector) + graph_engine = GraphEngine( + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_type=self.workflow_type, + workflow_id=self.workflow_id, + user_id=self.user_id, + user_from=self.user_from, + invoke_from=self.invoke_from, + call_depth=self.workflow_call_depth, + graph=iteration_graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) - if iterator is None or not isinstance(iterator, list): - return True + start_at = datetime.now(timezone.utc).replace(tzinfo=None) - return state.index >= len(iterator) - - def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): - """ - Resolve current output. - :param variable_pool: variable pool - """ - output_selector = cast(IterationNodeData, self.node_data).output_selector - output = variable_pool.get_any(output_selector) - # clear the output for this iteration - variable_pool.remove([self.node_id] + output_selector[1:]) - state.current_output = output - if output is not None: - # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration). 
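Taken together, the new wiring replaces the old callback-driven iteration with a purely graph-driven loop: each leaf of the sub-graph gets a conditional edge back to the iteration's start node, and the loop counter lives in the variable pool under the iteration node's id. A rough plain-Python sketch of that control flow (assumed semantics, heavily simplified; not the engine's real API):

```python
# The while-guard mirrors the RunCondition "index < len(items)" that the hunk
# attaches to every leaf -> start edge of the iteration sub-graph.
def run_iteration(node_id: str, items: list, pool: dict, run_subgraph) -> list:
    pool[(node_id, "index")] = 0
    pool[(node_id, "item")] = items[0] if items else None
    outputs = []
    while pool[(node_id, "index")] < len(items):   # the loop-back condition
        outputs.append(run_subgraph(pool))         # one pass over the sub-graph
        next_index = pool[(node_id, "index")] + 1  # advance the counter...
        pool[(node_id, "index")] = next_index
        if next_index < len(items):                # ...and the current item
            pool[(node_id, "item")] = items[next_index]
    return outputs
```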
- if isinstance(output, list): - state.outputs.extend(output) - else: - state.outputs.append(output) + yield IterationRunStartedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + metadata={"iterator_length": len(iterator_list_value)}, + predecessor_node_id=self.previous_node_id, + ) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=0, + pre_iteration_output=None, + ) + + outputs: list[Any] = [] + try: + # run workflow + rst = graph_engine.run() + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START: + continue + + if isinstance(event, NodeRunSucceededEvent): + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) + event.route_node_state.node_run_result.metadata = metadata + + yield event + + # handle iteration run result + if event.route_node_state.node_id in iteration_leaf_node_ids: + # append to iteration output variable list + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) + + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) + + # move to next iteration + current_index = variable_pool.get([self.node_id, "index"]) + if current_index is None: + raise ValueError(f"iteration {self.node_id} current index not found") + + next_index = int(current_index.to_object()) + 1 + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, + ) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + break + else: + event = cast(InNodeEvent, event) + yield event + + yield IterationRunSucceededEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), 
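For a two-item iterator, the stream produced by this generator would look roughly as follows (event names are the ones yielded in this hunk; payloads abbreviated and ordering assumed):

```python
# Assumed event ordering for items == ["a", "b"]:
#
#   IterationRunStartedEvent(inputs={"iterator_selector": ["a", "b"]})
#   IterationRunNextEvent(index=0, pre_iteration_output=None)
#   ... sub-graph node events for item "a", stamped with in_iteration_id ...
#   IterationRunNextEvent(index=1, pre_iteration_output=<output of pass 0>)
#   ... sub-graph node events for item "b" ...
#   IterationRunSucceededEvent(outputs={"output": [<pass 0>, <pass 1>]})
#   RunCompletedEvent(run_result=<SUCCEEDED NodeRunResult>)
```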
+ metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} + ) + ) + except Exception as e: + # iteration run failed + logger.exception("Iteration run failed") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + finally: + # remove iteration variable (item, index) from variable pool after iteration run completed + variable_pool.remove([self.node_id, "index"]) + variable_pool.remove([self.node_id, "item"]) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - 'input_selector': node_data.iterator_selector, - } \ No newline at end of file + variable_mapping = { + f"{node_id}.input_selector": node_data.iterator_selector, + } + + # init graph + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + + if not iteration_graph: + raise ValueError("iteration graph not found") + + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): + if sub_node_config.get("data", {}).get("iteration_id") != node_id: + continue + + # variable selector to variable mapping + try: + # Get node class + from core.workflow.nodes.node_mapping import node_classes + + node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type")) + node_cls = node_classes.get(node_type) + if not node_cls: + continue + + node_cls = cast(BaseNode, node_cls) + + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_config, config=sub_node_config + ) + sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + except NotImplementedError: + sub_node_variable_mapping = {} + + # remove iteration variables + sub_node_variable_mapping = { + sub_node_id + "." 
+ key: value + for key, value in sub_node_variable_mapping.items() + if value[0] != node_id + } + + variable_mapping.update(sub_node_variable_mapping) + + # remove variable out from iteration + variable_mapping = { + key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids + } + + return variable_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py new file mode 100644 index 0000000000..88b9665ac6 --- /dev/null +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -0,0 +1,35 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class IterationStartNode(BaseNode): + """ + Iteration Start Node. + """ + + _node_data_cls = IterationStartNodeData + _node_type = NodeType.ITERATION_START + + def _run(self) -> NodeRunResult: + """ + Run the node. + """ + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 7cf392277c..1cd88039b1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -9,6 +9,7 @@ class RerankingModelConfig(BaseModel): """ Reranking Model Config. """ + provider: str model: str @@ -17,6 +18,7 @@ class VectorSetting(BaseModel): """ Vector Setting. """ + vector_weight: float embedding_provider_name: str embedding_model_name: str @@ -26,6 +28,7 @@ class KeywordSetting(BaseModel): """ Keyword Setting. """ + keyword_weight: float @@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel): """ Weighted score Config. """ + vector_setting: VectorSetting keyword_setting: KeywordSetting @@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. """ + top_k: int score_threshold: Optional[float] = None - reranking_mode: str = 'reranking_model' + reranking_mode: str = "reranking_model" reranking_enable: bool = True reranking_model: Optional[RerankingModelConfig] = None weights: Optional[WeightedScoreConfig] = None + class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. """ + model: ModelConfig @@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. 
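The mapping filter above namespaces keys as `<node_id>.<variable>` and discards any selector that resolves to a node inside the iteration's own sub-graph, so only variables that must be supplied from outside the loop survive. Invented data to illustrate the intended result:

```python
# All ids and selectors here are made up for illustration.
iteration_node_ids = {"iteration-start", "llm-in-iter"}
variable_mapping = {
    "iter-1.input_selector": ["start", "items"],       # external -> kept
    "llm-in-iter.#context#": ["llm-in-iter", "text"],  # internal -> dropped
}
external_only = {
    key: value
    for key, value in variable_mapping.items()
    if value[0] not in iteration_node_ids
}
assert external_only == {"iter-1.input_selector": ["start", "items"]}
```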
""" - type: str = 'knowledge-retrieval' + + type: str = "knowledge-retrieval" query_variable_selector: list[str] dataset_ids: list[str] - retrieval_mode: Literal['single', 'multiple'] + retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 01bf6e16e6..53e8be6415 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,3 +1,5 @@ +import logging +from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import func @@ -11,25 +13,22 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.rag.retrieval.retrival_methods import RetrievalMethod -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) + default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -37,68 +36,53 @@ class KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData node_type = NodeType.KNOWLEDGE_RETRIEVAL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + def _run(self) -> NodeRunResult: + node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = variable_pool.get_any(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) query = variable - variables = { - 'query': query - } + variables = {"query": query} if not query: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Query is required." + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." 
) # retrieve knowledge try: - results = self._fetch_dataset_retriever( - node_data=node_data, query=query - ) - outputs = { - 'result': results - } + results = self._fetch_dataset_retriever(node_data=node_data, query=query) + outputs = {"result": results} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=None, - outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) except Exception as e: + logger.exception("Error when running knowledge retrieval node") + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) - - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ - dict[str, Any]]: + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] dataset_ids = node_data.dataset_ids # Subquery: Count the number of available documents for each dataset - subquery = db.session.query( - Document.dataset_id, - func.count(Document.id).label('available_document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.dataset_id.in_(dataset_ids) - ).group_by(Document.dataset_id).having( - func.count(Document.id) > 0 - ).subquery() + subquery = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids), + ) + .group_by(Document.dataset_id) + .having(func.count(Document.id) > 0) + .subquery() + ) - results = db.session.query(Dataset).join( - subquery, Dataset.id == subquery.c.dataset_id - ).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id.in_(dataset_ids) - ).all() + results = ( + db.session.query(Dataset) + .join(subquery, Dataset.id == subquery.c.dataset_id) + .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .all() + ) for dataset in results: # pass if dataset is not available @@ -115,16 +99,14 @@ class KnowledgeRetrievalNode(BaseNode): model_type_instance = cast(LargeLanguageModel, model_type_instance) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if model_schema: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER all_documents = dataset_retrieval.single_retrieve( available_datasets=available_datasets, @@ -135,110 +117,122 @@ class KnowledgeRetrievalNode(BaseNode): query=query, model_config=model_config, model_instance=model_instance, - planning_strategy=planning_strategy + planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': + if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": reranking_model = { - 'reranking_provider_name': 
node_data.multiple_retrieval_config.reranking_model.provider, - 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, } weights = None - elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': + elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None weights = { - 'vector_setting': { + "vector_setting": { "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight, "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name, "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name, }, - 'keyword_setting': { + "keyword_setting": { "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - } + }, } else: reranking_model = None weights = None - all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, - self.user_from.value, - available_datasets, query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, - ) + all_documents = dataset_retrieval.multiple_retrieve( + self.app_id, + self.tenant_id, + self.user_id, + self.user_from.value, + available_datasets, + query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_mode, + reranking_model, + weights, + node_data.multiple_retrieval_config.reranking_enable, + ) context_list = [] if all_documents: document_score_list = {} + page_number_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + resource_number = 1 if dataset and 
document: - source = { - 'metadata': { - '_source': 'knowledge', - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'document_data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': 'workflow', - 'score': document_score_list.get(segment.index_node_id, None), - 'segment_hit_count': segment.hit_count, - 'segment_word_count': segment.word_count, - 'segment_position': segment.position, - 'segment_index_node_hash': segment.index_node_hash, + "metadata": { + "_source": "knowledge", + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "document_data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": "workflow", + "score": document_score_list.get(segment.index_node_id, None), + "segment_hit_count": segment.hit_count, + "segment_word_count": segment.word_count, + "segment_position": segment.position, + "segment_index_node_hash": segment.index_node_hash, }, - 'title': document.name + "title": document.name, } if segment.answer: - source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: - source['content'] = segment.get_sign_content() + source["content"] = segment.get_sign_content() context_list.append(source) resource_number += 1 return context_list @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {} - variable_mapping['query'] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = node_data.query_variable_selector return variable_mapping - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data: KnowledgeRetrievalNodeData + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -249,10 +243,7 @@ class KnowledgeRetrievalNode(BaseNode): model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -263,8 +254,7 @@ class KnowledgeRetrievalNode(BaseNode): # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -280,19 +270,16 @@ class KnowledgeRetrievalNode(BaseNode): # model config completion_params = node_data.single_retrieval_config.model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if 
"stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 1e48a10bc7..93ee0ac250 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -11,6 +11,7 @@ class ModelConfig(BaseModel): """ Model Config. """ + provider: str name: str mode: str @@ -21,6 +22,7 @@ class ContextConfig(BaseModel): """ Context Config. """ + enabled: bool variable_selector: Optional[list[str]] = None @@ -29,37 +31,47 @@ class VisionConfig(BaseModel): """ Vision Config. """ + class Configs(BaseModel): """ Configs. """ - detail: Literal['low', 'high'] + + detail: Literal["low", "high"] enabled: bool configs: Optional[Configs] = None + class PromptConfig(BaseModel): """ Prompt Config. """ + jinja2_variables: Optional[list[VariableSelector]] = None + class LLMNodeChatModelMessage(ChatModelMessage): """ LLM Node Chat Model Message. """ + jinja2_text: Optional[str] = None + class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): """ LLM Node Chat Model Prompt Template. """ + jinja2_text: Optional[str] = None + class LLMNodeData(BaseNodeData): """ LLM Node Data. """ + model: ModelConfig prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] prompt_config: Optional[PromptConfig] = None diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index eb8921b526..049c211488 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,16 +1,17 @@ import json -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast + +from pydantic import BaseModel from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, @@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, 
RunStreamChunkEvent from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, @@ -42,18 +45,27 @@ if TYPE_CHECKING: from core.file.file_obj import FileVar +class ModelInvokeCompleted(BaseModel): + """ + Model invoke completed + """ + + text: str + usage: LLMUsage + finish_reason: Optional[str] = None + class LLMNode(BaseNode): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ node_data = cast(LLMNodeData, deepcopy(self.node_data)) + variable_pool = self.graph_runtime_state.variable_pool node_inputs = None process_data = None @@ -77,13 +89,18 @@ class LLMNode(BaseNode): files = self._fetch_files(node_data, variable_pool) if files: - node_inputs['#files#'] = [file.to_dict() for file in files] + node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - context = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data, variable_pool) + context = None + for event in generator: + if isinstance(event, RunRetrieverResourceEvent): + context = event.context + yield event if context: - node_inputs['#context#'] = context + node_inputs["#context#"] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -94,61 +111,78 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) - if node_data.memory else None, + query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages - ) + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages + ), + "model_provider": model_config.provider, + "model_name": model_config.model, } # handle invoke result - result_text, usage, finish_reason = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, - stop=stop + stop=stop, ) + + result_text = "" + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, RunStreamChunkEvent): + yield event + elif isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) ) + return - outputs = { - 'text': result_text, - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), 
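With `_run` now a generator, callers consume a stream of events instead of a single `NodeRunResult`. A minimal consumer sketch (assumed handling; the event classes are the ones imported in this hunk):

```python
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent


def drain(events):
    chunks, result = [], None
    for event in events:
        if isinstance(event, RunStreamChunkEvent):
            chunks.append(event.chunk_content)  # incremental LLM text
        elif isinstance(event, RunCompletedEvent):
            result = event.run_result           # terminal NodeRunResult
    return "".join(chunks), result
```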
"finish_reason": finish_reason} - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, + ) ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str]) -> tuple[str, LLMUsage]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -168,31 +202,38 @@ class LLMNode(BaseNode): ) # handle invoke result - text, usage, finish_reason = self._handle_invoke_result( - invoke_result=invoke_result - ) + generator = self._handle_invoke_result(invoke_result=invoke_result) + + usage = LLMUsage.empty_usage() + for event in generator: + yield event + if isinstance(event, ModelInvokeCompleted): + usage = event.usage # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage, finish_reason - - def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Handle invoke result :param invoke_result: invoke result :return: """ + if isinstance(invoke_result, LLMResult): + return + model = None - prompt_messages = [] - full_text = '' + prompt_messages: list[PromptMessage] = [] + full_text = "" usage = None finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text - self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) + yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) if not model: model = result.model @@ -209,10 +250,10 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - return full_text, usage, finish_reason + yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) - def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + def _transform_chat_messages( + self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ Transform chat messages @@ -222,13 +263,13 @@ class LLMNode(BaseNode): """ if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2': + if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: - if message.edition_type == 'jinja2': + if message.edition_type == "jinja2" and message.jinja2_text: message.text = message.jinja2_text return messages @@ -247,17 +288,15 @@ class 
LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = variable_selector.variable - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) def parse_dict(d: dict) -> str: """ Parse dict into string """ # check if it's a context structure - if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: - return d['content'] + if "metadata" in d and "_source" in d["metadata"] and "content" in d: + return d["content"] # else, parse the dict try: @@ -268,7 +307,7 @@ class LLMNode(BaseNode): if isinstance(value, str): value = value elif isinstance(value, list): - result = '' + result = "" for item in value: if isinstance(item, dict): result += parse_dict(item) @@ -278,7 +317,7 @@ class LLMNode(BaseNode): result += str(item) else: result += str(item) - result += '\n' + result += "\n" value = result.strip() elif isinstance(value, dict): value = parse_dict(value) @@ -313,18 +352,19 @@ class LLMNode(BaseNode): for variable_selector in variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value @@ -340,13 +380,13 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) + files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) if not files: return [] return files - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: """ Fetch context :param node_data: node data @@ -354,42 +394,34 @@ class LLMNode(BaseNode): :return: """ if not node_data.context.enabled: - return None + return if not node_data.context.variable_selector: - return None + return context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): - return context_value + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) elif isinstance(context_value, list): - context_str = '' + context_str = "" original_retriever_resource = [] for item in context_value: if isinstance(item, str): - context_str += item + '\n' + context_str += item + "\n" else: - if 'content' not in item: - raise ValueError(f'Invalid context structure: {item}') + if "content" not in item: + raise ValueError(f"Invalid context structure: {item}") - context_str += item['content'] + '\n' + context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) if 
retriever_resource: original_retriever_resource.append(retriever_resource) - if self.callbacks and original_retriever_resource: - for callback in self.callbacks: - callback.on_event( - event=QueueRetrieverResourcesEvent( - retriever_resources=original_retriever_resource - ) - ) - - return context_str.strip() - - return None + yield RunRetrieverResourceEvent( + retriever_resources=original_retriever_resource, context=context_str.strip() + ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: """ @@ -397,32 +429,37 @@ class LLMNode(BaseNode): :param context_dict: context dict :return: """ - if ('metadata' in context_dict and '_source' in context_dict['metadata'] - and context_dict['metadata']['_source'] == 'knowledge'): - metadata = context_dict.get('metadata', {}) + if ( + "metadata" in context_dict + and "_source" in context_dict["metadata"] + and context_dict["metadata"]["_source"] == "knowledge" + ): + metadata = context_dict.get("metadata", {}) + source = { - 'position': metadata.get('position'), - 'dataset_id': metadata.get('dataset_id'), - 'dataset_name': metadata.get('dataset_name'), - 'document_id': metadata.get('document_id'), - 'document_name': metadata.get('document_name'), - 'data_source_type': metadata.get('document_data_source_type'), - 'segment_id': metadata.get('segment_id'), - 'retriever_from': metadata.get('retriever_from'), - 'score': metadata.get('score'), - 'hit_count': metadata.get('segment_hit_count'), - 'word_count': metadata.get('segment_word_count'), - 'segment_position': metadata.get('segment_position'), - 'index_node_hash': metadata.get('segment_index_node_hash'), - 'content': context_dict.get('content'), + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("document_data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), } return source return None - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data_model: node data model @@ -433,10 +470,7 @@ class LLMNode(BaseNode): model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -447,8 +481,7 @@ class LLMNode(BaseNode): # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -464,19 +497,16 @@ class LLMNode(BaseNode): # model config completion_params = node_data_model.completion_params stop = [] - if 'stop' in completion_params: - stop = 
completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data_model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") @@ -492,9 +522,9 @@ class LLMNode(BaseNode): stop=stop, ) - def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], - variable_pool: VariablePool, - model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory( + self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + ) -> Optional[TokenBufferMemory]: """ Fetch memory :param node_data_memory: node data memory @@ -505,35 +535,35 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None # get conversation - conversation = db.session.query(Conversation).filter( - Conversation.app_id == self.app_id, - Conversation.id == conversation_id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + .first() + ) if not conversation: return None - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) return memory - def _fetch_prompt_messages(self, node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt_messages( + self, + node_data: LLMNodeData, + query: Optional[str], + query_prompt_template: Optional[str], + inputs: dict[str, str], + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt messages :param node_data: node data @@ -550,7 +580,7 @@ class LLMNode(BaseNode): prompt_messages = prompt_transform.get_prompt( prompt_template=node_data.prompt_template, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory_config=node_data.memory, @@ -570,7 +600,11 @@ class LLMNode(BaseNode): if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent): + if ( + vision_enabled + and content_item.type == PromptMessageContentType.IMAGE + and isinstance(content_item, ImagePromptMessageContent) + ): # Override vision config if LLM node has vision config if vision_detail: content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) @@ -580,15 +614,18 @@ class LLMNode(BaseNode): if len(prompt_message_content) > 1: prompt_message.content = 
prompt_message_content - elif (len(prompt_message_content) == 1 - and prompt_message_content[0].type == PromptMessageContentType.TEXT): + elif ( + len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT + ): prompt_message.content = prompt_message_content[0].data filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError("No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding.") + raise ValueError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) return filtered_prompt_messages, stop @@ -626,7 +663,7 @@ class LLMNode(BaseNode): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_instance.model: + if "gpt-4" in model_instance.model: used_quota = 20 else: used_quota = 1 @@ -637,28 +674,31 @@ class LLMNode(BaseNode): Provider.provider_name == model_instance.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - prompt_template = node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type != 'jinja2': + if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) else: - if prompt_template.edition_type != 'jinja2': + if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() @@ -668,36 +708,39 @@ class LLMNode(BaseNode): memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector if node_data.context.enabled: - variable_mapping['#context#'] = node_data.context.variable_selector + variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] + variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): for prompt in 
prompt_template: - if prompt.edition_type == 'jinja2': + if prompt.edition_type == "jinja2": enable_jinja = True break else: - if prompt_template.edition_type == 'jinja2': + if prompt_template.edition_type == "jinja2": enable_jinja = True if enable_jinja: for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping @classmethod @@ -713,26 +756,19 @@ class LLMNode(BaseNode): "prompt_templates": { "chat_model": { "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant.", - "edition_type": "basic" - } + {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} ] }, "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt": { "text": "Here is the chat histories between human and assistant, inside " - "<histories></histories> XML tags.\n\n<histories>\n{{" - "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic" + "<histories></histories> XML tags.\n\n<histories>\n{{" + "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic", }, - "stop": ["Human:"] - } + "stop": ["Human:"], + }, } - } + }, } diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 8a5684551e..a8a0debe64 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,3 @@ - from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState @@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData): Loop Node Data. """ + class LoopState(BaseIterationState): """ Loop State. - """ \ No newline at end of file + """ diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 7d53c6f5f2..fbc68b79cb 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,20 +1,37 @@ -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode +from typing import Any + +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.loop.entities import LoopNodeData, LoopState +from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseIterationNode): +class LoopNode(BaseNode): """ Loop Node. """ + _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self, variable_pool: VariablePool) -> LoopState: - return super()._run(variable_pool) + def _run(self) -> LoopState: - return super()._run() - def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: + @classmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: """ - Get next iteration start node id based on the graph. + Get conditions. 
""" + node_id = node_config.get("id") + if not node_id: + return [] + + # TODO waiting for implementation + return [ + Condition( + variable_selector=[node_id, "index"], + comparison_operator="≤", + value_type="value_selector", + value_selector=[], + ) + ] diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py new file mode 100644 index 0000000000..b98525e86e --- /dev/null +++ b/api/core/workflow/nodes/node_mapping.py @@ -0,0 +1,37 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode + +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.ANSWER: AnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: IterationNode, + NodeType.ITERATION_START: IterationStartNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, +} diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 7bb123b126..802ed31e27 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -8,47 +8,52 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str completion_params: dict[str, Any] = {} + class ParameterConfig(BaseModel): """ Parameter Config. 
""" + name: str - type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] options: Optional[list[str]] = None description: str required: bool - @field_validator('name', mode='before') + @field_validator("name", mode="before") @classmethod def validate_name(cls, value) -> str: if not value: - raise ValueError('Parameter name is required') - if value in ['__reason', '__is_success']: - raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + raise ValueError("Parameter name is required") + if value in ["__reason", "__is_success"]: + raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value + class ParameterExtractorNodeData(BaseNodeData): """ Parameter Extractor Node Data. """ + model: ModelConfig query: list[str] parameters: list[ParameterConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None - reasoning_mode: Literal['function_call', 'prompt'] + reasoning_mode: Literal["function_call", "prompt"] - @field_validator('reasoning_mode', mode='before') + @field_validator("reasoning_mode", mode="before") @classmethod def set_reasoning_mode(cls, v) -> str: - return v or 'function_call' + return v or "function_call" def get_parameter_json_schema(self) -> dict: """ @@ -56,32 +61,26 @@ class ParameterExtractorNodeData(BaseNodeData): :return: parameter json schema """ - parameters = { - 'type': 'object', - 'properties': {}, - 'required': [] - } + parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = { - 'description': parameter.description - } + parameter_schema = {"description": parameter.description} - if parameter.type in ['string', 'select']: - parameter_schema['type'] = 'string' - elif parameter.type.startswith('array'): - parameter_schema['type'] = 'array' + if parameter.type in ["string", "select"]: + parameter_schema["type"] = "string" + elif parameter.type.startswith("array"): + parameter_schema["type"] = "array" nested_type = parameter.type[6:-1] - parameter_schema['items'] = {'type': nested_type} + parameter_schema["items"] = {"type": nested_type} else: - parameter_schema['type'] = parameter.type + parameter_schema["type"] = parameter.type - if parameter.type == 'select': - parameter_schema['enum'] = parameter.options + if parameter.type == "select": + parameter_schema["enum"] = parameter.options + + parameters["properties"][parameter.name] = parameter_schema - parameters['properties'][parameter.name] = parameter_schema - if parameter.required: - parameters['required'].append(parameter.name) + parameters["required"].append(parameter.name) - return parameters \ No newline at end of file + return parameters diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2876695a82..131d26b19e 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,6 +1,7 @@ import json import uuid -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -44,6 +45,7 @@ class 
ParameterExtractorNode(LLMNode): """ Parameter Extractor Node. """ + _node_data_cls = ParameterExtractorNodeData _node_type = NodeType.PARAMETER_EXTRACTOR @@ -56,30 +58,27 @@ class ParameterExtractorNode(LLMNode): "model": { "prompt_templates": { "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "stop": ["Human:"] + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "stop": ["Human:"], } } } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the node. """ node_data = cast(ParameterExtractorNodeData, self.node_data) - variable = variable_pool.get_any(node_data.query) + variable = self.graph_runtime_state.variable_pool.get_any(node_data.query) if not variable: raise ValueError("Input variable content not found or is empty") query = variable inputs = { - 'query': query, - 'parameters': jsonable_encoder(node_data.parameters), - 'instruction': jsonable_encoder(node_data.instruction), + "query": query, + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), } model_instance, model_config = self._fetch_model_config(node_data.model) @@ -92,29 +91,31 @@ class ParameterExtractorNode(LLMNode): raise ValueError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ - and node_data.reasoning_mode == 'function_call': - # use function call + if ( + set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} + and node_data.reasoning_mode == "function_call" + ): + # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, variable_pool, model_config, memory + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, - memory) + prompt_messages = self._generate_prompt_engineering_prompt( + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + ) prompt_message_tools = [] process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': None, - 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - 'tool_call': None, + "usage": None, + "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), + "tool_call": None, } try: @@ -125,20 +126,17 @@ class ParameterExtractorNode(LLMNode): tools=prompt_message_tools, stop=model_config.stop, ) - process_data['usage'] = jsonable_encoder(usage) - process_data['tool_call'] = jsonable_encoder(tool_call) - process_data['llm_text'] = text + process_data["usage"] = jsonable_encoder(usage) + process_data["tool_call"] = jsonable_encoder(tool_call) + process_data["llm_text"] = text except Exception as e: return NodeRunResult( 
status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 0, - '__reason': str(e) - }, + outputs={"__is_success": 0, "__reason": str(e)}, error=str(e), - metadata={} + metadata={}, ) error = None @@ -163,23 +161,23 @@ class ParameterExtractorNode(LLMNode): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 1 if not error else 0, - '__reason': error, - **result - }, + outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str], + ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: """ Invoke large language model :param node_data_model: node data model @@ -212,32 +210,35 @@ class ParameterExtractorNode(LLMNode): return text, usage, tool_call - def _generate_function_call_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + def _generate_function_call_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. 
""" - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps( - node_data.get_parameter_json_schema())) + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( + content=query, structure=json.dumps(node_data.get_parameter_json_schema()) + ) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_function_calling_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -250,124 +251,125 @@ class ParameterExtractorNode(LLMNode): example_messages = [] for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: id = uuid.uuid4().hex - example_messages.extend([ - UserPromptMessage(content=example['user']['query']), - AssistantPromptMessage( - content=example['assistant']['text'], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example['assistant']['function_call']['name'], - arguments=json.dumps(example['assistant']['function_call']['parameters'] - ) - )) - ] - ), - ToolPromptMessage( - content='Great! You have called the function with the correct parameters.', - tool_call_id=id - ), - AssistantPromptMessage( - content='I have extracted the parameters, let\'s move on.', - ) - ]) + example_messages.extend( + [ + UserPromptMessage(content=example["user"]["query"]), + AssistantPromptMessage( + content=example["assistant"]["text"], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example["assistant"]["function_call"]["name"], + arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), + ), + ) + ], + ), + ToolPromptMessage( + content="Great! 
You have called the function with the correct parameters.", tool_call_id=id + ), + AssistantPromptMessage( + content="I have extracted the parameters, let's move on.", + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) # generate tool tool = PromptMessageTool( name=FUNCTION_CALLING_EXTRACTOR_NAME, - description='Extract parameters from the natural language text', + description="Extract parameters from the natural language text", parameters=node_data.get_parameter_json_schema(), ) return prompt_messages, [tool] - def _generate_prompt_engineering_prompt(self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_prompt( + self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate prompt engineering prompt. """ model_mode = ModelMode.value_of(data.model.mode) if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory) elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory) else: raise ValueError(f"Invalid model mode: {model_mode}") - def _generate_prompt_engineering_completion_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_completion_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate completion prompt. 
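Both prompt builders end with the same splice: locate the last user message, then insert the few-shot example messages immediately before it so the examples always precede the real query. A small sketch of that splice with a stand-in message type:

    from dataclasses import dataclass

    @dataclass
    class Msg:  # stand-in for the PromptMessage hierarchy
        role: str
        content: str

    def splice_examples(messages: list[Msg], examples: list[Msg]) -> list[Msg]:
        # Find the last user message, then put the few-shot examples before it.
        last_user_idx = max(i for i, m in enumerate(messages) if m.role == "user")
        return messages[:last_user_idx] + examples + messages[last_user_idx:]

    msgs = [Msg("system", "Extract parameters."), Msg("user", "real query")]
    shots = [Msg("user", "example query"), Msg("assistant", '{"location": "SF"}')]
    print([m.role for m in splice_examples(msgs, shots)])
    # ['system', 'user', 'assistant', 'user']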
""" prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_prompt_engineering_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, - inputs={ - 'structure': json.dumps(node_data.get_parameter_json_schema()) - }, - query='', + inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _generate_prompt_engineering_chat_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_chat_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate chat prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") prompt_template = self._get_prompt_engineering_prompt_template( node_data, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), - text=query + structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, memory, rest_token + variable_pool, + memory, + rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -379,18 +381,23 @@ class ParameterExtractorNode(LLMNode): # add example messages before last user message example_messages = [] for example in CHAT_EXAMPLE: - example_messages.extend([ - UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example['user']['json']), - text=example['user']['query'], - )), - AssistantPromptMessage( - content=json.dumps(example['assistant']['json']), - ) - ]) + example_messages.extend( + [ + UserPromptMessage( + content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example["user"]["json"]), + text=example["user"]["query"], + ) + ), + AssistantPromptMessage( + content=json.dumps(example["assistant"]["json"]), + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) return prompt_messages @@ -405,28 +412,28 @@ class ParameterExtractorNode(LLMNode): if parameter.required and parameter.name not in result: raise ValueError(f"Parameter {parameter.name} is required") - 
if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: + if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: raise ValueError(f"Invalid `select` value for parameter {parameter.name}") - if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): + if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): raise ValueError(f"Invalid `number` value for parameter {parameter.name}") - if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): + if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") - if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): + if parameter.type == "string" and not isinstance(result.get(parameter.name), str): raise ValueError(f"Invalid `string` value for parameter {parameter.name}") - if parameter.type.startswith('array'): + if parameter.type.startswith("array"): if not isinstance(result.get(parameter.name), list): raise ValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in result.get(parameter.name): - if nested_type == 'number' and not isinstance(item, int | float): + if nested_type == "number" and not isinstance(item, int | float): raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == 'string' and not isinstance(item, str): + if nested_type == "string" and not isinstance(item, str): raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == 'object' and not isinstance(item, dict): + if nested_type == "object" and not isinstance(item, dict): raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result @@ -438,12 +445,12 @@ class ParameterExtractorNode(LLMNode): for parameter in data.parameters: if parameter.name in result: # transform value - if parameter.type == 'number': + if parameter.type == "number": if isinstance(result[parameter.name], int | float): transformed_result[parameter.name] = result[parameter.name] elif isinstance(result[parameter.name], str): try: - if '.' in result[parameter.name]: + if "." in result[parameter.name]: result[parameter.name] = float(result[parameter.name]) else: result[parameter.name] = int(result[parameter.name]) @@ -460,40 +467,40 @@ class ParameterExtractorNode(LLMNode): # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ['string', 'select']: + elif parameter.type in ["string", "select"]: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith('array'): + elif parameter.type.startswith("array"): if isinstance(result[parameter.name], list): nested_type = parameter.type[6:-1] transformed_result[parameter.name] = [] for item in result[parameter.name]: - if nested_type == 'number': + if nested_type == "number": if isinstance(item, int | float): transformed_result[parameter.name].append(item) elif isinstance(item, str): try: - if '.' in item: + if "." 
in item: transformed_result[parameter.name].append(float(item)) else: transformed_result[parameter.name].append(int(item)) except ValueError: pass - elif nested_type == 'string': + elif nested_type == "string": if isinstance(item, str): transformed_result[parameter.name].append(item) - elif nested_type == 'object': + elif nested_type == "object": if isinstance(item, dict): transformed_result[parameter.name].append(item) if parameter.name not in transformed_result: - if parameter.type == 'number': + if parameter.type == "number": transformed_result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ['string', 'select']: - transformed_result[parameter.name] = '' - elif parameter.type.startswith('array'): + elif parameter.type in ["string", "select"]: + transformed_result[parameter.name] = "" + elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] return transformed_result @@ -509,24 +516,24 @@ class ParameterExtractorNode(LLMNode): """ stack = [] for i, c in enumerate(text): - if c == '{' or c == '[': + if c == "{" or c == "[": stack.append(c) - elif c == '}' or c == ']': + elif c == "}" or c == "]": # check if stack is empty if not stack: return text[:i] # check if the last element in stack is matching - if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): stack.pop() if not stack: - return text[:i + 1] + return text[: i + 1] else: return text[:i] return None # extract json from the text for idx in range(len(result)): - if result[idx] == '{' or result[idx] == '[': + if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: try: @@ -549,12 +556,12 @@ class ParameterExtractorNode(LLMNode): """ result = {} for parameter in data.parameters: - if parameter.type == 'number': + if parameter.type == "number": result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ['string', 'select']: - result[parameter.name] = '' + elif parameter.type in ["string", "select"]: + result[parameter.name] = "" return result @@ -570,71 +577,76 @@ class ParameterExtractorNode(LLMNode): return variable_template_parser.format(inputs) - def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: + def _get_function_calling_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - 
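The `extract_json` helper reformatted above is a bracket-balancing scanner: it walks the model output, pushes every `{` or `[` onto a stack, and cuts the string at the index where the stack first empties, which tolerates trailing chatter after the JSON. A self-contained version of the same scan:

    import json

    def extract_json(text: str):
        # Stack-based scan, as in the node: track opening braces/brackets and
        # stop at the position where they all close.
        stack = []
        for i, c in enumerate(text):
            if c in "{[":
                stack.append(c)
            elif c in "}]":
                if not stack:
                    return text[:i]
                if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
                    stack.pop()
                    if not stack:
                        return text[: i + 1]
                else:
                    return text[:i]
        return None

    raw = 'Sure, here you go: {"food": "apple pie", "tags": ["dessert"]} hope that helps'
    start = raw.index("{")
    print(json.loads(extract_json(raw[start:])))  # {'food': 'apple pie', 'tags': ['dessert']}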
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: raise ValueError(f"Model mode {model_mode} not support.") - def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: - + def _get_prompt_engineering_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, - text=input_text, - instruction=instruction) - .replace('{γγγ', '') - .replace('}γγγ', '') + text=COMPLETION_GENERATE_JSON_PROMPT.format( + histories=memory_str, text=input_text, instruction=instruction + ) + .replace("{γγγ", "") + .replace("}γγγ", "") ) else: raise ValueError(f"Model mode {model_mode} not support.") - def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) model_instance, model_config = self._fetch_model_config(node_data.model) @@ -654,12 +666,12 @@ class ParameterExtractorNode(LLMNode): prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 @@ -668,26 +680,28 @@ class ParameterExtractorNode(LLMNode): model_type_instance = 
model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, - prompt_messages - ) + 1000 # add 1000 to ensure tool call messages + curr_message_tokens = ( + model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + ) # add 1000 to ensure tool call messages max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config. """ @@ -697,22 +711,23 @@ class ParameterExtractorNode(LLMNode): return self._model_instance, self._model_config @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[ - str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = node_data - - variable_mapping = { - 'query': node_data.query - } + variable_mapping = {"query": node_data.query} if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index 499c58d505..c63fded4d0 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,4 +1,4 @@ -FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters' +FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. 
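`_calculate_rest_token` boils down to one budget equation: whatever the context window has left after reserving the reply (`max_tokens`) and the already-rendered prompt (padded with a flat 1000-token allowance for tool-call overhead) is what conversation history may still consume. A sketch of that arithmetic (the example numbers are invented):

    def calculate_rest_tokens(context_size: int, max_output_tokens: int,
                              prompt_tokens: int, tool_call_margin: int = 1000) -> int:
        # Context window minus the reply budget minus the padded prompt,
        # floored at zero as in the node.
        rest = context_size - max_output_tokens - (prompt_tokens + tool_call_margin)
        return max(rest, 0)

    # e.g. an 8192-token model, 512 reply tokens, 1200-token prompt:
    print(calculate_rest_tokens(8192, 512, 1200))  # 5480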
### Task @@ -35,61 +35,48 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr """ -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + }, }, + "required": ["location"], }, - 'required': ['location'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'location': 'San Francisco' - } - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'food': 'apple pie' - } - } - } -}] +] COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: Some extra information are provided below, I should always follow the instructions as possible as I can. 
@@ -161,46 +148,33 @@ Inside <text></text> XML tags, there is a text that you should convert to a JSON """ -CHAT_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'json': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True - } +CHAT_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "json": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + } + }, + "required": ["location"], }, - 'required': ['location'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'location': 'San Francisco' - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'json': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "json": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'result': 'apple pie' - } - } -}] \ No newline at end of file +] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index c0b0a8b696..40f7ce7582 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -8,8 +8,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -20,6 +21,7 @@ class ClassConfig(BaseModel): """ Class Config. """ + id: str name: str @@ -28,8 +30,9 @@ class QuestionClassifierNodeData(BaseNodeData): """ Knowledge retrieval Node Data.
""" + query_variable_selector: list[str] - type: str = 'question-classifier' + type: str = "question-classifier" model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index f4057d50f3..d860f848ec 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,10 +1,12 @@ import json import logging -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder @@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, @@ -36,46 +37,49 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData node_type = NodeType.QUESTION_CLASSIFIER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) node_data = cast(QuestionClassifierNodeData, node_data) + variable_pool = self.graph_runtime_state.variable_pool # extract variables variable = variable_pool.get(node_data.query_variable_selector) query = variable.value if variable else None - variables = { - 'query': query - } + variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else '' + instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else "" node_data.instruction = instruction # fetch prompt messages prompt_messages, stop = self._fetch_prompt( - node_data=node_data, - context='', - query=query, - memory=memory, - model_config=model_config + node_data=node_data, context="", query=query, memory=memory, model_config=model_config ) # handle invoke result - result_text, usage, finish_reason = self._invoke_llm( - node_data_model=node_data.model, - 
model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop + generator = self._invoke_llm( + node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) + + result_text = "" + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + category_name = node_data.classes[0].name category_id = node_data.classes[0].id try: result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) - if 'category_name' in result_text_json and 'category_id' in result_text_json: - category_id_result = result_text_json['category_id'] + if "category_name" in result_text_json and "category_id" in result_text_json: + category_id_result = result_text_json["category_id"] classes = node_data.classes classes_map = {class_.id: class_.name for class_ in classes} category_ids = [_class.id for _class in classes] @@ -87,17 +91,14 @@ class QuestionClassifierNode(LLMNode): logging.error(f"Failed to parse result text: {result_text}") try: process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } - outputs = { - 'class_name': category_name + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, } + outputs = {"class_name": category_name} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -108,8 +109,9 @@ class QuestionClassifierNode(LLMNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) except ValueError as e: @@ -120,21 +122,32 @@ class QuestionClassifierNode(LLMNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) - variable_mapping = {'query': node_data.query_variable_selector} + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = {node_id + "." 
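The classifier's output handling above is deliberately forgiving: it pre-selects the first configured class, then switches only when the model's JSON names a `category_id` it actually knows. A sketch of that fallback logic, with `json.loads` standing in for the real `parse_and_check_json_markdown` helper:

    import json

    def pick_category(result_text: str, classes: list[dict]) -> tuple[str, str]:
        # Pre-select the first class, then switch only when the model's JSON
        # names a category_id that is actually configured.
        category_id, category_name = classes[0]["id"], classes[0]["name"]
        try:
            parsed = json.loads(result_text)  # stand-in for parse_and_check_json_markdown
            classes_map = {c["id"]: c["name"] for c in classes}
            if parsed.get("category_id") in classes_map:
                category_id = parsed["category_id"]
                category_name = classes_map[category_id]
        except ValueError:
            pass  # unparsable output falls back to the first class
        return category_id, category_name

    classes = [{"id": "1", "name": "billing"}, {"id": "2", "name": "support"}]
    print(pick_category('{"category_id": "2", "category_name": "support"}', classes))  # ('2', 'support')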
+ key: value for key, value in variable_mapping.items()} + return variable_mapping @classmethod @@ -144,19 +157,16 @@ class QuestionClassifierNode(LLMNode): :param filters: filter by node config parameters. :return: """ - return { - "type": "question-classifier", - "config": { - "instructions": "" - } - } + return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt(self, node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt( + self, + node_data: QuestionClassifierNodeData, + query: str, + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt :param node_data: node data @@ -172,118 +182,122 @@ class QuestionClassifierNode(LLMNode): prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: QuestionClassifierNodeData, + query: str, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + def _get_prompt_template( + self, + node_data: 
QuestionClassifierNodeData, + query: str, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: - category = { - 'category_id': class_.id, - 'category_name': class_.name - } + category = {"category_id": class_.id, "category_name": class_.name} categories.append(category) - instruction = node_data.instruction if node_data.instruction else '' + instruction = node_data.instruction if node_data.instruction else "" input_text = query - memory_str = '' + memory_str = "" if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) prompt_messages = [] if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) + role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) user_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_1 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) assistant_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) user_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_2 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) assistant_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) user_prompt_message_3 = ChatModelMessage( role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction) + text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( + input_text=input_text, + categories=json.dumps(categories, ensure_ascii=False), + classification_instructions=instruction, + ), ) prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, - input_text=input_text, - categories=json.dumps(categories), - classification_instructions=instruction, - ensure_ascii=False) + text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( + histories=memory_str, + input_text=input_text, + categories=json.dumps(categories), + classification_instructions=instruction, + ensure_ascii=False, + ) ) else: @@ -299,14 +313,12 @@ class QuestionClassifierNode(LLMNode): variable = variable_pool.get(variable_selector.value_selector) variable_value = variable.value if variable else None if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise 
ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - instruction = prompt_template.format( - prompt_inputs - ) + instruction = prompt_template.format(prompt_inputs) return instruction diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index e0de148cc2..581f986922 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -1,5 +1,3 @@ - - QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Job Description', You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index b81ce15bd7..11d2ebe5dd 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -10,4 +10,5 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 54e66bd671..96c887c58d 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,8 @@ +from collections.abc import Mapping, Sequence +from typing import Any -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -11,28 +12,27 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData _node_type = NodeType.START - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ - node_inputs = dict(variable_pool.user_inputs) - system_inputs = variable_pool.system_variables + node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + system_inputs = self.graph_runtime_state.variable_pool.system_variables for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] + node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - outputs=node_inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index d9099a8118..e934d69fa3 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,5 +1,3 @@ - - from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -8,5 +6,6 @@ class TemplateTransformNodeData(BaseNodeData): """ Code Node Data. """ + variables: list[VariableSelector] - template: str \ No newline at end of file + template: str diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 21f71db6c5..32c99e0d1c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,14 +1,15 @@ import os -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) + class TemplateTransformNode(BaseNode): _node_data_cls = TemplateTransformNodeData @@ -23,18 +24,10 @@ class TemplateTransformNode(BaseNode): """ return { "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - } - ], - "template": "{{ arg1 }}" - } + "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node """ @@ -45,44 +38,39 @@ class TemplateTransformNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable_name = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variables[variable_name] = value # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - 
code=node_data.template, - inputs=variables + language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) - except CodeExecutionException as e: + except CodeExecutionError as e: + return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) + + if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters", ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs={ - 'output': result['result'] - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} ) - + @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } \ No newline at end of file + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 5da5cd0727..28fbf789fd 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -8,46 +8,47 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api', 'workflow'] - provider_name: str # redundancy + provider_type: Literal["builtin", "api", "workflow"] + provider_name: str # redundancy tool_name: str - tool_label: str # redundancy + tool_label: str # redundancy tool_configurations: dict[str, Any] - @field_validator('tool_configurations', mode='before') + @field_validator("tool_configurations", mode="before") @classmethod def validate_tool_configurations(cls, value, values: ValidationInfo): if not isinstance(value, dict): - raise ValueError('tool_configurations must be a dictionary') - - for key in values.data.get('tool_configurations', {}).keys(): - value = values.data.get('tool_configurations', {}).get(key) + raise ValueError("tool_configurations must be a dictionary") + + for key in values.data.get("tool_configurations", {}).keys(): + value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): - raise ValueError(f'{key} must be a string') - + raise ValueError(f"{key} must be a string") + return value + class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Literal['mixed', 'variable', 'constant'] + type: Literal["mixed", "variable", "constant"] - @field_validator('type', mode='before') + @field_validator("type", mode="before") @classmethod def check_type(cls, 
value, validation_info: ValidationInfo): typ = value - value = validation_info.data.get('value') - if typ == 'mixed' and not isinstance(value, str): - raise ValueError('value must be a string') - elif typ == 'variable': + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": if not isinstance(value, list): - raise ValueError('value must be a list') + raise ValueError("value must be a list") for val in value: if not isinstance(val, str): - raise ValueError('value must be a list of strings') - elif typ == 'constant' and not isinstance(value, str | int | float | bool): - raise ValueError('value must be a string, int, float, or bool') + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") return typ """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ccce9ef360..e55adfc1f4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -26,7 +26,7 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the tool node """ @@ -34,10 +34,7 @@ class ToolNode(BaseNode): node_data = cast(ToolNodeData, self.node_data) # fetch tool icon - tool_info = { - 'provider_type': node_data.provider_type, - 'provider_id': node_data.provider_id - } + tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id} # get tool runtime try: @@ -48,16 +45,21 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to get tool runtime: {str(e)}' + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to get tool runtime: {str(e)}", ) # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True) + parameters = self._generate_parameters( + tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=node_data, + for_log=True, + ) try: messages = ToolEngine.workflow_invoke( @@ -66,15 +68,14 @@ class ToolNode(BaseNode): user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, + thread_pool_id=self.thread_pool_id, ) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to invoke tool: {str(e)}', + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool: {str(e)}", ) # convert tool messages @@ -82,15 +83,9 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'text': plain_text, - 'files': files, - 'json': json - }, 
- metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - inputs=parameters_for_log + outputs={"text": plain_text, "files": files, "json": json}, + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + inputs=parameters_for_log, ) def _generate_parameters( @@ -122,12 +117,10 @@ class ToolNode(BaseNode): result[parameter_name] = None continue if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [ - v.to_dict() for v in self._fetch_files(variable_pool) - ] + result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)] else: tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == 'variable': + if tool_input.type == "variable": # TODO: check if the variable exists in the variable pool parameter_value = variable_pool.get(tool_input.value).value else: @@ -141,11 +134,11 @@ class ToolNode(BaseNode): return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -170,38 +163,44 @@ class ToolNode(BaseNode): result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): url = response.message ext = path.splitext(url)[1] - mimetype = response.meta.get('mime_type', 'image/jpeg') - filename = response.save_as or url.split('/')[-1] - transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) + mimetype = response.meta.get("mime_type", "image/jpeg") + filename = response.save_as or url.split("/")[-1] + transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) # get tool file id - tool_file_id = url.split('/')[-1].split('.')[0] - result.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=url, - related_id=tool_file_id, - filename=filename, - extension=ext, - mime_type=mimetype, - )) + tool_file_id = url.split("/")[-1].split(".")[0] + result.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=url, + related_id=tool_file_id, + filename=filename, + extension=ext, + mime_type=mimetype, + ) + ) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split('/')[-1].split('.')[0] - result.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=response.save_as, - extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get('mime_type', 'application/octet-stream'), - )) + tool_file_id = response.message.split("/")[-1].split(".")[0] + result.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file_id, + filename=response.save_as, + extension=path.splitext(response.save_as)[1], + 
mime_type=response.meta.get("mime_type", "application/octet-stream"), + ) + ) elif response.type == ToolInvokeMessage.MessageType.LINK: pass # TODO: @@ -211,32 +210,43 @@ class ToolNode(BaseNode): """ Extract tool response text """ - return '\n'.join([ - f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else - f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' - for message in tool_response - ]) + return "\n".join( + [ + f"{message.message}" + if message.type == ToolInvokeMessage.MessageType.TEXT + else f"Link: {message.message}" + if message.type == ToolInvokeMessage.MessageType.LINK + else "" + for message in tool_response + ] + ) def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ result = {} for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': + if input.type == "mixed": selectors = VariableTemplateParser(input.value).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector - elif input.type == 'variable': + elif input.type == "variable": result[parameter_name] = input.value - elif input.type == 'constant': + elif input.type == "constant": pass + result = {node_id + "." + key: value for key, value in result.items()} + return result diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index cea88334b9..eb893a04e3 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ - - from typing import Literal, Optional from pydantic import BaseModel @@ -11,23 +9,27 @@ class AdvancedSettings(BaseModel): """ Advanced setting. """ + group_enabled: bool class Group(BaseModel): """ Group. """ - output_type: Literal['string', 'number', 'array', 'object'] + + output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] variables: list[list[str]] group_name: str groups: list[Group] + class VariableAssignerNodeData(BaseNodeData): """ Knowledge retrieval Node Data. 
""" - type: str = 'variable-assigner' + + type: str = "variable-assigner" output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None \ No newline at end of file + advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 885f7d7617..f03eae257a 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,8 +1,7 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data = cast(VariableAssignerNodeData, self.node_data) # Get variables outputs = {} @@ -20,34 +19,33 @@ class VariableAggregatorNode(BaseNode): if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: for selector in node_data.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs = { - "output": variable - } + outputs = {"output": variable} - inputs = { - '.'.join(selector[1:]): variable - } + inputs = {".".join(selector[1:]): variable} break else: for group in node_data.advanced_settings.groups: for selector in group.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs[group.group_name] = { - 'output': variable - } - inputs['.'.join(selector[1:])] = variable + outputs[group.group_name] = {"output": variable} + inputs[".".join(selector[1:])] = variable break - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ return {} diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index d791d51523..83da4bdc79 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -2,7 +2,7 @@ from .node import VariableAssignerNode from .node_data import VariableAssignerData, WriteMode __all__ = [ - 'VariableAssignerNode', - 'VariableAssignerData', - 'WriteMode', + 
"VariableAssignerNode", + "VariableAssignerData", + "WriteMode", ] diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index 8c2adcabb9..3969299795 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session from core.app.segments import SegmentType, Variable, factory from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from extensions.ext_database import db from models import ConversationVariable, WorkflowNodeExecutionStatus @@ -19,49 +18,49 @@ class VariableAssignerNode(BaseNode): _node_data_cls: type[BaseNodeData] = VariableAssignerData _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: data = cast(VariableAssignerData, self.node_data) # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') + raise VariableAssignerNodeError("assigned variable not found") match data.write_mode: case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) + raise VariableAssignerNodeError("input value not found") + updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') + raise VariableAssignerNodeError("input value not found") updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) + updated_variable = original_variable.model_copy(update={"value": updated_value}) case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}") # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. 
- conversation_id = variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') + raise VariableAssignerNodeError("conversation_id not found") update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ - 'value': income_value.to_object(), + "value": income_value.to_object(), }, ) @@ -73,7 +72,7 @@ def update_conversation_variable(conversation_id: str, variable: Variable): with Session(db.engine) as session: row = session.scalar(stmt) if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') + raise VariableAssignerNodeError("conversation variable not found in the database") row.data = variable.model_dump_json() session.commit() @@ -85,8 +84,8 @@ def get_zero_value(t: SegmentType): case SegmentType.OBJECT: return factory.build_segment({}) case SegmentType.STRING: - return factory.build_segment('') + return factory.build_segment("") case SegmentType.NUMBER: return factory.build_segment(0) case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') + raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index b3652b6802..8ac8eadf7c 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -6,14 +6,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' + OVER_WRITE = "over-write" + APPEND = "append" + CLEAR = "clear" class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' + title: str = "Variable Assigner" + desc: Optional[str] = "Assign a value to a variable" assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/core/workflow/utils/condition/__init__.py b/api/core/workflow/utils/condition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py new file mode 100644 index 0000000000..b8e8b881a5 --- /dev/null +++ b/api/core/workflow/utils/condition/entities.py @@ -0,0 +1,32 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + + +class Condition(BaseModel): + """ + Condition entity + """ + + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", + ] + value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py new file mode 100644 index 0000000000..395ee82478 --- /dev/null +++ b/api/core/workflow/utils/condition/processor.py @@ -0,0 +1,381 @@ +from collections.abc import Sequence +from typing import Any, Optional + +from core.file.file_obj import FileVar +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.condition.entities import Condition +from 
core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class ConditionProcessor: + def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): + input_conditions = [] + group_result = [] + + index = 0 + for condition in conditions: + index += 1 + actual_value = variable_pool.get_any(condition.variable_selector) + + expected_value = None + if condition.value is not None: + variable_template_parser = VariableTemplateParser(template=condition.value) + variable_selectors = variable_template_parser.extract_variable_selectors() + if variable_selectors: + for variable_selector in variable_selectors: + value = variable_pool.get_any(variable_selector.value_selector) + expected_value = variable_template_parser.format({variable_selector.variable: value}) + + if expected_value is None: + expected_value = condition.value + else: + expected_value = condition.value + + comparison_operator = condition.comparison_operator + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": comparison_operator, + } + ) + + result = self.evaluate_condition(actual_value, comparison_operator, expected_value) + group_result.append(result) + + return input_conditions, group_result + + def evaluate_condition( + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None, + ) -> bool: + """ + Evaluate condition + :param actual_value: actual value + :param expected_value: expected value + :param comparison_operator: comparison operator + + :return: bool + """ + if comparison_operator == "contains": + return self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + return self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + return self._assert_start_with(actual_value, expected_value) + elif comparison_operator == "end with": + return self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + return self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + return self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + return self._assert_empty(actual_value) + elif comparison_operator == "not empty": + return self._assert_not_empty(actual_value) + elif comparison_operator == "=": + return self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + return self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + return self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + return self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + return self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + return self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + return self._assert_null(actual_value) + elif comparison_operator == "not null": + return self._assert_not_null(actual_value) + else: + raise ValueError(f"Invalid comparison operator: {comparison_operator}") + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return 
False + + if not isinstance(actual_value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError("Invalid actual value type: string") + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError("Invalid actual value type: string") + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError("Invalid actual value type: string") + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError("Invalid actual value type: string") + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not 
isinstance(actual_value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + +class ConditionAssertionError(Exception): + def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: + self.message = message + self.conditions = conditions + self.sub_condition_compare_results = sub_condition_compare_results + super().__init__(self.message) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index c43fde172c..fd0e48b862 100644 --- 
a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -5,7 +5,7 @@ from typing import Any from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool -REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}') +REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: @@ -20,7 +20,7 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name']) key_selectors = filter( lambda t: len(t[1]) >= 2, - ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)), + ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)), ) inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} @@ -29,13 +29,13 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # return original matched string if key not found value = inputs.get(key, match.group(0)) if value is None: - value = '' + value = "" value = str(value) # remove template variables if required - return re.sub(REGEX, r'{\1}', value) + return re.sub(REGEX, r"{\1}", value) result = re.sub(REGEX, replacer, template) - result = re.sub(r'<\|.*?\|>', '', result) + result = re.sub(r"<\|.*?\|>", "", result) return result @@ -101,8 +101,8 @@ class VariableTemplateParser: """ variable_selectors = [] for variable_key in self.variable_keys: - remove_hash = variable_key.replace('#', '') - split_result = remove_hash.split('.') + remove_hash = variable_key.replace("#", "") + split_result = remove_hash.split(".") if len(split_result) < 2: continue @@ -127,7 +127,7 @@ class VariableTemplateParser: value = inputs.get(key, match.group(0)) # return original matched string if key not found if value is None: - value = '' + value = "" # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) @@ -136,7 +136,7 @@ class VariableTemplateParser: return VariableTemplateParser.remove_template_variables(value) prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str): @@ -149,4 +149,4 @@ class VariableTemplateParser: Returns: The text with template variables removed. 
""" - return re.sub(REGEX, r'{\1}', text) + return re.sub(REGEX, r"{\1}", text) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3157eedfee..e69de29bb2 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,1005 +0,0 @@ -import logging -import time -from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast - -import contexts -from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.entities import IterationState -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode -from core.workflow.nodes.variable_assigner import VariableAssignerNode -from extensions.ext_database import db -from models.workflow import ( - Workflow, - WorkflowNodeExecutionStatus, -) - -node_classes: Mapping[NodeType, type[BaseNode]] = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, - NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, - NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, -} - -logger = logging.getLogger(__name__) - - -class WorkflowEngineManager: - def get_default_configs(self) -> list[dict]: - """ - Get default block configs - """ - default_block_configs = [] - for node_type, node_class in 
node_classes.items(): - default_config = node_class.get_default_config() - if default_config: - default_block_configs.append(default_config) - - return default_block_configs - - def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: - """ - Get default config of node. - :param node_type: node type - :param filters: filter by node config parameters. - :return: - """ - node_class = node_classes.get(node_type) - if not node_class: - return None - - default_config = node_class.get_default_config(filters=filters) - if not default_config: - return None - - return default_config - - def run_workflow( - self, - *, - workflow: Workflow, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - callbacks: Sequence[WorkflowCallback], - call_depth: int = 0, - variable_pool: VariablePool | None = None, - ) -> None: - """ - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param invoke_from: invoke from - :param callbacks: workflow callbacks - :param call_depth: call depth - :param variable_pool: variable pool - """ - # fetch workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - if 'nodes' not in graph or 'edges' not in graph: - raise ValueError('nodes or edges not found in workflow graph') - - if not isinstance(graph.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') - - if not isinstance(graph.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') - - - workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH - if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) - - # init workflow run state - if not variable_pool: - variable_pool = contexts.workflow_variable_pool.get() - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - workflow_call_depth=call_depth - ) - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - ) - - def _run_workflow(self, workflow: Workflow, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> None: - """ - Run workflow - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks - :param call_depth: call depth - :param start_at: force specific start node - :param end_at: force specific end node - :return: - """ - graph = workflow.graph_dict - - try: - answer_prov_node_ids = [] - for node in graph.get('nodes', []): - if node.get('id', '') == 'answer': - try: - answer_prov_node_ids.append(node.get('data', {}) - .get('answer', '') - .replace('#', '') - .replace('.text', '') - .replace('{{', '') - .replace('}}', '').split('.')[0]) - except Exception as e: - logger.error(e) - - predecessor_node: BaseNode | None = None - current_iteration_node: BaseIterationNode | None = None - has_entry_node = False - max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS - max_execution_time = 
dify_config.WORKFLOW_MAX_EXECUTION_TIME - while True: - # get next node, multiple target nodes in the future - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=predecessor_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - - if not next_node: - # reached loop/iteration end or overall end - if current_iteration_node and workflow_run_state.current_iteration_state: - # reached loop/iteration end - # get next iteration - next_iteration = current_iteration_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_iteration, NodeRunResult): - if next_iteration.outputs: - for variable_key, variable_value in next_iteration.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=current_iteration_node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - # iteration has ended - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=current_iteration_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - # continue overall process - elif isinstance(next_iteration, str): - # move to next iteration - next_node_id = next_iteration - # get next id - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if not next_node: - break - - # check is already ran - if self._check_node_has_ran(workflow_run_state, next_node.node_id): - predecessor_node = next_node - continue - - has_entry_node = True - - # max steps reached - if workflow_run_state.workflow_node_steps > max_execution_steps: - raise ValueError('Max steps {} reached.'.format(max_execution_steps)) - - # or max execution time reached - if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): - raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) - - # handle iteration nodes - if isinstance(next_node, BaseIterationNode): - current_iteration_node = next_node - workflow_run_state.current_iteration_state = next_node.run( - variable_pool=workflow_run_state.variable_pool - ) - self._workflow_iteration_started( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - callbacks=callbacks - ) - predecessor_node = next_node - # move to start node of iteration - next_node_id = next_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_node_id, NodeRunResult): - # iteration has ended - current_iteration_node.set_output( - 
variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - continue - else: - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if next_node and next_node.node_id in answer_prov_node_ids: - next_node.is_answer_previous_node = True - - # run workflow, run multiple target nodes in the future - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=next_node, - predecessor_node=predecessor_node, - callbacks=callbacks - ) - - if next_node.node_type in [NodeType.END]: - break - - predecessor_node = next_node - - if not has_entry_node: - self._workflow_run_failed( - error='Start node not found in workflow graph.', - callbacks=callbacks - ) - return - except GenerateTaskStoppedException as e: - return - except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks - ) - return - - # workflow run success - self._workflow_run_success( - callbacks=callbacks - ) - - def single_step_run_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: - """ - Single step run workflow node - :param workflow: Workflow instance - :param node_id: node id - :param user_id: user id - :param user_inputs: user inputs - :return: - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - # fetch node config from node id - node_config = None - for node in nodes: - if node.get('id') == node_id: - node_config = node - break - - if not node_config: - raise ValueError('node id not found in workflow graph') - - # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes.get(node_type) - - # init workflow run state - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - workflow_call_depth=0 - ) - - try: - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - if node_cls is None: - raise ValueError('Node class not found') - # variable selector to variable mapping - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # run node - node_run_result = node_instance.run( - variable_pool=variable_pool - ) - - # sign output files - node_run_result.outputs = self.handle_special_values(node_run_result.outputs) - except Exception as e: - raise WorkflowNodeRunFailedError( - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_title=node_instance.node_data.title, - error=str(e) - ) - - return node_instance, node_run_result - - def 
single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: Sequence[WorkflowCallback], - ) -> None: - """ - Single iteration run workflow node - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - for node in nodes: - if node.get('id') == node_id: - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]: - node_config = node - else: - raise ValueError('node id is not an iteration node') - - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - # variable selector to variable mapping - iteration_nested_nodes = [ - node for node in nodes - if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id - ] - iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - for node_config in iteration_nested_nodes: - # mapping user inputs to variable pool - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - if node_cls is None: - raise ValueError('Node class not found') - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - # remove iteration variables - variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() - if value[0] != node_id - } - - # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_nested_node_ids - } - - # append variables to variable pool - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - callbacks=callbacks, - workflow_call_depth=0 - ) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # fetch end node of iteration - end_node_id = None - for edge in graph.get('edges'): - if edge.get('source') == node_id: - end_node_id = edge.get('target') - break - - if not end_node_id: - raise ValueError('end node of iteration not found') - - # init workflow run state - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - workflow_call_depth=0 - ) - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - start_at=node_id, - end_at=end_node_id - ) - - def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run success - :param callbacks: workflow callbacks - :return: - """ - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_succeeded() - - def _workflow_run_failed(self, error: str, 
- callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run failed - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - if callbacks: - for callback in callbacks: - callback.on_workflow_run_failed( - error=error - ) - - def _workflow_iteration_started(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - predecessor_node_id: Optional[str] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration started - :param current_iteration_node: current iteration node - :param workflow_run_state: workflow run state - :param callbacks: workflow callbacks - :return: - """ - # get nested nodes - iteration_nested_nodes = [ - node for node in graph.get('nodes') - if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id - ] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_started( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - node_data=current_iteration_node.node_data, - inputs=workflow_run_state.current_iteration_state.inputs, - predecessor_node_id=predecessor_node_id, - metadata=workflow_run_state.current_iteration_state.metadata.model_dump() - ) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - def _workflow_iteration_next(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration next - :param workflow_run_state: workflow run state - :return: - """ - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_next( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - index=workflow_run_state.current_iteration_state.index, - node_run_index=workflow_run_state.workflow_node_steps, - output=workflow_run_state.current_iteration_state.get_current_output() - ) - # clear ran nodes - workflow_run_state.workflow_node_runs = [ - node_run for node_run in workflow_run_state.workflow_node_runs - if node_run.iteration_node_id != current_iteration_node.node_id - ] - - # clear variables in current iteration - nodes = graph.get('nodes') - nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] - - for node in nodes: - workflow_run_state.variable_pool.remove((node.get('id'),)) - - def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_completed( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - outputs={ - 'output': workflow_run_state.current_iteration_state.outputs - } - ) - - def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: 
Optional[str] = None) -> Optional[BaseNode]: - """ - Get next node - multiple target nodes in the future. - :param graph: workflow graph - :param predecessor_node: predecessor node - :param callbacks: workflow callbacks - :return: - """ - nodes = graph.get('nodes') - if not nodes: - return None - - if not predecessor_node: - for node_config in nodes: - node_cls = None - if start_at: - if node_config.get('id') == start_at: - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - else: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - node_cls = StartNode - if node_cls: - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - else: - edges = graph.get('edges') - source_node_id = predecessor_node.node_id - - # fetch all outgoing edges from source node - outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] - if not outgoing_edges: - return None - - # fetch target node id from outgoing edges - outgoing_edge = None - source_handle = predecessor_node.node_run_result.edge_source_handle \ - if predecessor_node.node_run_result else None - if source_handle: - for edge in outgoing_edges: - if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: - outgoing_edge = edge - break - else: - outgoing_edge = outgoing_edges[0] - - if not outgoing_edge: - return None - - target_node_id = outgoing_edge.get('target') - - if end_at and target_node_id == end_at: - return None - - # fetch target node from target node id - target_node_config = None - for node in nodes: - if node.get('id') == target_node_id: - target_node_config = node - break - - if not target_node_config: - return None - - # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) - - return target_node( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - node_id: str, - callbacks: Sequence[WorkflowCallback]): - """ - Get node from graph by node id - """ - nodes = graph.get('nodes') - if not nodes: - return None - - for node_config in nodes: - if node_config.get('id') == node_id: - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes[node_type] - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: - """ - Check timeout - :param start_at: start time - :param max_execution_time: max execution time - :return: - """ - return 
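
The branch-selection step in `_get_next_overall_node` (match an outgoing edge on `sourceHandle` when the predecessor chose a branch, otherwise take the first edge) can be sketched in isolation; the edge data here is a hypothetical example:

```python
from typing import Optional

# Hypothetical outgoing edges of an if/else node with two branches.
outgoing_edges = [
    {"source": "if_1", "sourceHandle": "true", "target": "llm_true"},
    {"source": "if_1", "sourceHandle": "false", "target": "llm_false"},
]

def pick_outgoing_edge(edges: list[dict], source_handle: Optional[str]) -> Optional[dict]:
    """Mirror of the branch selection: match on sourceHandle when the
    predecessor produced one, otherwise take the first outgoing edge."""
    if source_handle:
        for edge in edges:
            if edge.get("sourceHandle") == source_handle:
                return edge
        return None  # chosen branch has no edge -> the run ends here
    return edges[0] if edges else None

assert pick_outgoing_edge(outgoing_edges, "false")["target"] == "llm_false"
assert pick_outgoing_edge(outgoing_edges, None)["target"] == "llm_true"
```
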
time.perf_counter() - start_at > max_execution_time - - def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: - """ - Check node has ran - """ - return bool([ - node_and_result for node_and_result in workflow_run_state.workflow_node_runs - if node_and_result.node_id == node_id - ]) - - def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - node_run_index=workflow_run_state.workflow_node_steps, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None - ) - - db.session.close() - - workflow_nodes_and_result = WorkflowNodeAndResult( - node=node, - result=None - ) - - # add to workflow_nodes_and_results - workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - # mark node as running - if workflow_run_state.current_iteration_state: - workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( - node_id=node.node_id, - iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id - )) - - try: - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) - except GenerateTaskStoppedException as e: - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error='Workflow stopped.' - ) - except Exception as e: - logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: - # node run failed - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_failed( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - error=node_run_result.error, - inputs=node_run_result.inputs, - outputs=node_run_result.outputs, - process_data=node_run_result.process_data, - ) - - raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - - if node.is_answer_previous_node and not isinstance(node, LLMNode): - if not node_run_result.metadata: - node_run_result.metadata = {} - node_run_result.metadata["is_answer_previous_node"]=True - workflow_nodes_and_result.result = node_run_result - - # node run success - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_succeeded( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - inputs=node_run_result.inputs, - process_data=node_run_result.process_data, - outputs=node_run_result.outputs, - execution_metadata=node_run_result.metadata - ) - - if node_run_result.outputs: - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - - if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - 
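
The try/except in the removed `_run_workflow_node` normalizes every failure into a `FAILED` result object instead of letting exceptions escape the engine loop. A minimal sketch of that pattern, using a stand-in dataclass rather than the real `NodeRunResult`:

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class FakeNodeRunResult:  # stand-in for NodeRunResult
    status: str
    error: Optional[str] = None
    outputs: Optional[dict] = None

def run_node_safely(run: Callable[[], FakeNodeRunResult]) -> FakeNodeRunResult:
    """Run a node callable and convert any exception into a failed result.
    The real code additionally special-cases the task-stopped exception."""
    try:
        return run()
    except Exception as e:
        return FakeNodeRunResult(status="failed", error=str(e))

def failing_node() -> FakeNodeRunResult:
    raise ValueError("boom")

ok = run_node_safely(lambda: FakeNodeRunResult(status="succeeded", outputs={"text": "hi"}))
bad = run_node_safely(failing_node)
assert ok.status == "succeeded"
assert bad.status == "failed" and bad.error == "boom"
```
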
db.session.close() - - def _append_variables_recursively(self, variable_pool: VariablePool, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): - """ - Append variables recursively - :param variable_pool: variable pool - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_pool.add( - [node_id] + variable_key_list, variable_value - ) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - variable_pool=variable_pool, - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value - ) - - @classmethod - def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]: - """ - Handle special values - :param value: value - :return: - """ - if not value: - return None - - new_value = value.copy() - if isinstance(new_value, dict): - for key, val in new_value.items(): - if isinstance(val, FileVar): - new_value[key] = val.to_dict() - elif isinstance(val, list): - new_val = [] - for v in val: - if isinstance(v, FileVar): - new_val.append(v.to_dict()) - else: - new_val.append(v) - - new_value[key] = new_val - - return new_value - - def _mapping_user_inputs_to_variable_pool(self, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, - variable_pool: VariablePool, - tenant_id: str, - node_instance: BaseNode): - for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {variable_key} not found in user inputs.') - - # fetch variable node id from variable selector - variable_node_id = variable_selector[0] - variable_key_list = variable_selector[1:] - - # get value - value = user_inputs.get(variable_key) - - # FIXME: temp fix for image type - if node_instance.node_type == NodeType.LLM: - new_value = [] - if isinstance(value, list): - node_data = node_instance.node_data - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) - file = FileVar( - tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), - ) - new_value.append(file) - - if new_value: - value = new_value - - # append variable and value to variable pool - variable_pool.add([variable_node_id]+variable_key_list, value) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py new file mode 100644 index 0000000000..74a598ada5 --- /dev/null +++ b/api/core/workflow/workflow_entry.py @@ -0,0 +1,295 @@ +import logging +import time +import uuid +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Optional, cast + +from configs import dify_config +from core.app.app_config.entities import FileExtraConfig +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from 
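
The removed `_append_variables_recursively` flattens nested dict outputs into per-selector entries. The same recursion over a plain dict, as a runnable sketch:

```python
def append_recursively(pool: dict, node_id: str, key_list: list[str], value) -> None:
    """Flatten nested dict outputs into selector -> value entries
    (sketch of _append_variables_recursively; pool is a plain dict here)."""
    pool[tuple([node_id] + key_list)] = value
    if isinstance(value, dict):
        for key, val in value.items():
            append_recursively(pool, node_id, key_list + [key], val)

pool: dict = {}
append_recursively(pool, "llm_2", ["output"], {"text": "hi", "usage": {"tokens": 3}})
assert pool[("llm_2", "output", "text")] == "hi"
assert pool[("llm_2", "output", "usage", "tokens")] == 3
```
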
core.app.entities.app_invoke_entities import InvokeFrom +from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunEvent +from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.node_mapping import node_classes +from models.workflow import ( + Workflow, + WorkflowType, +) + +logger = logging.getLogger(__name__) + + +class WorkflowEntry: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None, + ) -> None: + """ + Init workflow entry + :param tenant_id: tenant id + :param app_id: app id + :param workflow_id: workflow id + :param workflow_type: workflow type + :param graph_config: workflow graph config + :param graph: workflow graph + :param user_id: user id + :param user_from: user from + :param invoke_from: invoke from + :param call_depth: call depth + :param variable_pool: variable pool + :param thread_pool_id: thread pool id + """ + # check call depth + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH + if call_depth > workflow_call_max_depth: + raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) + + # init workflow run state + self.graph_engine = GraphEngine( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + graph=graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + thread_pool_id=thread_pool_id, + ) + + def run( + self, + *, + callbacks: Sequence[WorkflowCallback], + ) -> Generator[GraphEngineEvent, None, None]: + """ + :param callbacks: workflow callbacks + """ + graph_engine = self.graph_engine + + try: + # run workflow + generator = graph_engine.run() + for event in generator: + if callbacks: + for callback in callbacks: + callback.on_event(event=event) + yield event + except GenerateTaskStoppedError: + pass + except Exception as e: + logger.exception("Unknown Error when workflow entry running") + if callbacks: + for callback in callbacks: + callback.on_event(event=GraphRunFailedEvent(error=str(e))) + return + + @classmethod + def single_step_run( + cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict + ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + """ + Single step run workflow node + :param 
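
`WorkflowEntry.run` in the new `workflow_entry.py` is essentially a pass-through generator: each graph-engine event is fanned out to the callbacks, then yielded to the caller. A minimal sketch of that shape, with strings standing in for the event objects:

```python
from collections.abc import Callable, Generator, Iterable, Sequence

def run_with_callbacks(
    events: Iterable[str],
    callbacks: Sequence[Callable[[str], None]],
) -> Generator[str, None, None]:
    """Fan each event out to every callback, then yield it downstream."""
    for event in events:
        for callback in callbacks:
            callback(event)
        yield event

seen: list[str] = []
for event in run_with_callbacks(["node_started", "node_succeeded"], [seen.append]):
    pass
assert seen == ["node_started", "node_succeeded"]
```
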
workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError("workflow graph not found") + + nodes = graph.get("nodes") + if not nodes: + raise ValueError("nodes not found in workflow graph") + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get("id") == node_id: + node_config = node + break + + if not node_config: + raise ValueError("node id not found in workflow graph") + + # Get node class + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + if not node_cls: + raise ValueError(f"Node class not found for node type {node_type}") + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + # init graph + graph = Graph.init(graph_config=workflow.graph_dict) + + # init workflow run state + node_instance: BaseNode = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_type=WorkflowType.value_of(workflow.type), + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + ) + + try: + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=node_instance.node_data, + ) + + # run node + generator = node_instance.run() + + return node_instance, generator + except Exception as e: + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + + @classmethod + def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: + """ + Handle special values + :param value: value + :return: + """ + if not value: + return None + + new_value = dict(value) if value else {} + if isinstance(new_value, dict): + for key, val in new_value.items(): + if isinstance(val, FileVar): + new_value[key] = val.to_dict() + elif isinstance(val, list): + new_val = [] + for v in val: + if isinstance(v, FileVar): + new_val.append(v.to_dict()) + else: + new_val.append(v) + + new_value[key] = new_val + + return new_value + + @classmethod + def mapping_user_inputs_to_variable_pool( + cls, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: dict, + variable_pool: VariablePool, + tenant_id: str, + node_type: NodeType, + node_data: BaseNodeData, + ) -> None: + for node_variable, variable_selector in variable_mapping.items(): + # fetch node id and variable key from node_variable + node_variable_list = node_variable.split(".") + if len(node_variable_list) < 1: + raise ValueError(f"Invalid node variable {node_variable}") + + node_variable_key = ".".join(node_variable_list[1:]) + + if (node_variable_key not in user_inputs and node_variable not in user_inputs) and 
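
`mapping_user_inputs_to_variable_pool` assumes mapping keys of the form `{node_id}.{variable_key}`, so a user input may be addressed by the full key or by the bare variable key. A small sketch of that lookup with hypothetical keys:

```python
# Hypothetical mapping key and user inputs.
node_variable = "llm_2.query"
user_inputs = {"query": "hello"}

node_variable_list = node_variable.split(".")
node_variable_key = ".".join(node_variable_list[1:])  # -> "query"

# Prefer the fully-qualified key, fall back to the bare variable key.
input_value = user_inputs.get(node_variable)
if not input_value:
    input_value = user_inputs.get(node_variable_key)

assert input_value == "hello"
```
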
not variable_pool.get( + variable_selector + ): + raise ValueError(f"Variable key {node_variable} not found in user inputs.") + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + variable_key_list = cast(list[str], variable_key_list) + + # get input value + input_value = user_inputs.get(node_variable) + if not input_value: + input_value = user_inputs.get(node_variable_key) + + # FIXME: temp fix for image type + if node_type == NodeType.LLM: + new_value = [] + if isinstance(input_value, list): + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in input_value: + if isinstance(item, dict) and "type" in item and item["type"] == "image": + transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) + file = FileVar( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get("upload_file_id") + if transfer_method == FileTransferMethod.LOCAL_FILE + else None, + extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None), + ) + new_value.append(file) + + if new_value: + input_value = new_value + + # append variable and value to variable pool + variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 7ee7146d09..1d6ad35333 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -3,8 +3,8 @@ from .clean_when_document_deleted import handle from .create_document_index import handle from .create_installed_app_when_app_created import handle from .create_site_record_when_app_created import handle -from .deduct_quota_when_messaeg_created import handle +from .deduct_quota_when_message_created import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle -from .update_provider_last_used_at_when_messaeg_created import handle +from .update_provider_last_used_at_when_message_created import handle diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 72a135e73d..54f6a76e16 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import time import click from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document @@ -43,7 +43,7 @@ def handle(sender, **kwargs): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py similarity index
100% rename from api/events/event_handlers/deduct_quota_when_messaeg_created.py rename to api/events/event_handlers/deduct_quota_when_message_created.py diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py similarity index 100% rename from api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py rename to api/events/event_handlers/update_provider_last_used_at_when_message_created.py diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index f5ec7c1759..0ff9f90847 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -10,11 +10,21 @@ def init_app(app: Flask) -> Celery: with app.app_context(): return self.run(*args, **kwargs) + broker_transport_options = {} + + if app.config.get("CELERY_USE_SENTINEL"): + broker_transport_options = { + "master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"), + "sentinel_kwargs": { + "socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1), + }, + } + celery_app = Celery( app.name, task_cls=FlaskTask, - broker=app.config["CELERY_BROKER_URL"], - backend=app.config["CELERY_BACKEND"], + broker=app.config.get("CELERY_BROKER_URL"), + backend=app.config.get("CELERY_BACKEND"), task_ignore_result=True, ) @@ -27,11 +37,12 @@ def init_app(app: Flask) -> Celery: } celery_app.conf.update( - result_backend=app.config["CELERY_RESULT_BACKEND"], + result_backend=app.config.get("CELERY_RESULT_BACKEND"), + broker_transport_options=broker_transport_options, broker_connection_retry_on_startup=True, ) - if app.config["BROKER_USE_SSL"]: + if app.config.get("BROKER_USE_SSL"): celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) @@ -43,7 +54,7 @@ def init_app(app: Flask) -> Celery: "schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task", ] - day = app.config["CELERY_BEAT_SCHEDULER_TIME"] + day = app.config.get("CELERY_BEAT_SCHEDULER_TIME") beat_schedule = { "clean_embedding_cache_task": { "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index d5fb162fd8..054769e7ff 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,26 +1,83 @@ import redis from redis.connection import Connection, SSLConnection +from redis.sentinel import Sentinel -redis_client = redis.Redis() + +class RedisClientWrapper(redis.Redis): + """ + A wrapper class for the Redis client that addresses the issue where the global + `redis_client` variable cannot be updated when a new Redis instance is returned + by Sentinel. + + This class allows for deferred initialization of the Redis client, enabling the + client to be re-initialized with a new instance when necessary. This is particularly + useful in scenarios where the Redis instance may change dynamically, such as during + a failover in a Sentinel-managed Redis setup. + + Attributes: + _client (redis.Redis): The actual Redis client instance. It remains None until + initialized with the `initialize` method. + + Methods: + initialize(client): Initializes the Redis client if it hasn't been initialized already. + __getattr__(item): Delegates attribute access to the Redis client, raising an error + if the client is not initialized. 
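
For the Sentinel branch added to `ext_celery`, kombu's Redis Sentinel transport reads the master name from `broker_transport_options`, while the broker URL lists sentinel endpoints under the `sentinel://` scheme. A minimal standalone configuration sketch (endpoints and master name are placeholders, not values from this diff):

```python
from celery import Celery

# Placeholder sentinel endpoints; nothing is contacted at construction time.
celery_app = Celery(
    "app",
    broker="sentinel://sentinel-1:26379;sentinel://sentinel-2:26379",
)
celery_app.conf.update(
    broker_transport_options={
        "master_name": "mymaster",           # placeholder master name
        "sentinel_kwargs": {"socket_timeout": 0.1},
    },
)
```
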
+ """ + + def __init__(self): + self._client = None + + def initialize(self, client): + if self._client is None: + self._client = client + + def __getattr__(self, item): + if self._client is None: + raise RuntimeError("Redis client is not initialized. Call init_app first.") + return getattr(self._client, item) + + +redis_client = RedisClientWrapper() def init_app(app): + global redis_client connection_class = Connection if app.config.get("REDIS_USE_SSL"): connection_class = SSLConnection - redis_client.connection_pool = redis.ConnectionPool( - **{ - "host": app.config.get("REDIS_HOST"), - "port": app.config.get("REDIS_PORT"), - "username": app.config.get("REDIS_USERNAME"), - "password": app.config.get("REDIS_PASSWORD"), - "db": app.config.get("REDIS_DB"), - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - }, - connection_class=connection_class, - ) + redis_params = { + "username": app.config.get("REDIS_USERNAME"), + "password": app.config.get("REDIS_PASSWORD"), + "db": app.config.get("REDIS_DB"), + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + } + + if app.config.get("REDIS_USE_SENTINEL"): + sentinel_hosts = [ + (node.split(":")[0], int(node.split(":")[1])) for node in app.config.get("REDIS_SENTINELS").split(",") + ] + sentinel = Sentinel( + sentinel_hosts, + sentinel_kwargs={ + "socket_timeout": app.config.get("REDIS_SENTINEL_SOCKET_TIMEOUT", 0.1), + "username": app.config.get("REDIS_SENTINEL_USERNAME"), + "password": app.config.get("REDIS_SENTINEL_PASSWORD"), + }, + ) + master = sentinel.master_for(app.config.get("REDIS_SENTINEL_SERVICE_NAME"), **redis_params) + redis_client.initialize(master) + else: + redis_params.update( + { + "host": app.config.get("REDIS_HOST"), + "port": app.config.get("REDIS_PORT"), + "connection_class": connection_class, + } + ) + pool = redis.ConnectionPool(**redis_params) + redis_client.initialize(redis.Redis(connection_pool=pool)) app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 227c6635f0..3b7b0a37f4 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -1,3 +1,4 @@ +import openai import sentry_sdk from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration @@ -9,7 +10,7 @@ def init_app(app): sentry_sdk.init( dsn=app.config.get("SENTRY_DSN"), integrations=[FlaskIntegration(), CeleryIntegration()], - ignore_errors=[HTTPException, ValueError], + ignore_errors=[HTTPException, ValueError, openai.APIStatusError], traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), environment=app.config.get("DEPLOY_ENV"), diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index e6c4352577..5ce18b7292 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -6,10 +6,12 @@ from flask import Flask from extensions.storage.aliyun_storage import AliyunStorage from extensions.storage.azure_storage import AzureStorage from extensions.storage.google_storage import GoogleStorage +from extensions.storage.huawei_storage import HuaweiStorage from extensions.storage.local_storage import LocalStorage from extensions.storage.oci_storage import OCIStorage from extensions.storage.s3_storage import S3Storage from extensions.storage.tencent_storage import TencentStorage +from extensions.storage.volcengine_storage import VolcengineStorage 
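
The `RedisClientWrapper` above is a deferred-initialization proxy: module-level code can import `redis_client` immediately, and the real client (standalone or Sentinel master) is injected later by `init_app`. The pattern in isolation, with a plain object standing in for the Redis client:

```python
class LazyProxy:
    """Deferred-init proxy: attribute access is delegated to a client
    injected after import time (sketch of RedisClientWrapper)."""

    def __init__(self):
        self._client = None

    def initialize(self, client):
        if self._client is None:
            self._client = client

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, so _client
        # itself is resolved without recursion.
        if self._client is None:
            raise RuntimeError("client is not initialized; call initialize() first")
        return getattr(self._client, item)

class FakeRedis:  # stand-in for redis.Redis
    def ping(self):
        return True

proxy = LazyProxy()
try:
    proxy.ping()
except RuntimeError:
    pass  # access before initialize() fails loudly
proxy.initialize(FakeRedis())
assert proxy.ping() is True
```
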
class Storage: @@ -30,6 +32,10 @@ class Storage: self.storage_runner = TencentStorage(app=app) elif storage_type == "oci-storage": self.storage_runner = OCIStorage(app=app) + elif storage_type == "huawei-obs": + self.storage_runner = HuaweiStorage(app=app) + elif storage_type == "volcengine-tos": + self.storage_runner = VolcengineStorage(app=app) else: self.storage_runner = LocalStorage(app=app) diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index b962cedc55..bee237fc17 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -15,6 +15,7 @@ class AliyunStorage(BaseStorage): app_config = self.app.config self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") + self.folder = app.config.get("ALIYUN_OSS_PATH") oss_auth_method = aliyun_s3.Auth region = None if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": @@ -30,15 +31,29 @@ class AliyunStorage(BaseStorage): ) def save(self, filename, data): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename self.client.put_object(filename, data) def load_once(self, filename: str) -> bytes: + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + with closing(self.client.get_object(filename)) as obj: data = obj.read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + with closing(self.client.get_object(filename)) as obj: while chunk := obj.read(4096): yield chunk @@ -46,10 +61,24 @@ class AliyunStorage(BaseStorage): return generate() def download(self, filename, target_filepath): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + self.client.get_object_to_file(filename, target_filepath) def exists(self, filename): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + return self.client.object_exists(filename) def delete(self, filename): + if not self.folder or self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename self.client.delete_object(filename) diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index 9ed1fcf0b4..c42f946fa8 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -5,7 +5,7 @@ from collections.abc import Generator from contextlib import closing from flask import Flask -from google.cloud import storage as GoogleCloudStorage +from google.cloud import storage as google_cloud_storage from extensions.storage.base_storage import BaseStorage @@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage): service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) - self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) + self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) else: - self.client = GoogleCloudStorage.Client() + self.client = google_cloud_storage.Client() def save(self, filename, data): bucket = 
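
Each method in the `AliyunStorage` hunk above repeats the same folder-prefix branching. If that grows further, it could be factored into one helper along these lines (a sketch; `join_folder` is a hypothetical name, and unlike the inline version it also tolerates an unset folder instead of concatenating `None`):

```python
from typing import Optional

def join_folder(folder: Optional[str], filename: str) -> str:
    """Prefix filename with the configured folder, adding a slash only
    when the folder does not already end with one."""
    if not folder:
        return filename
    if folder.endswith("/"):
        return folder + filename
    return folder + "/" + filename

assert join_folder(None, "a.txt") == "a.txt"
assert join_folder("uploads/", "a.txt") == "uploads/a.txt"
assert join_folder("uploads", "a.txt") == "uploads/a.txt"
```
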
self.client.get_bucket(self.bucket_name) diff --git a/api/extensions/storage/huawei_storage.py b/api/extensions/storage/huawei_storage.py new file mode 100644 index 0000000000..269a008fba --- /dev/null +++ b/api/extensions/storage/huawei_storage.py @@ -0,0 +1,53 @@ +from collections.abc import Generator + +from flask import Flask +from obs import ObsClient + +from extensions.storage.base_storage import BaseStorage + + +class HuaweiStorage(BaseStorage): + """Implementation for Huawei OBS storage.""" + + def __init__(self, app: Flask): + super().__init__(app) + app_config = self.app.config + self.bucket_name = app_config.get("HUAWEI_OBS_BUCKET_NAME") + self.client = ObsClient( + access_key_id=app_config.get("HUAWEI_OBS_ACCESS_KEY"), + secret_access_key=app_config.get("HUAWEI_OBS_SECRET_KEY"), + server=app_config.get("HUAWEI_OBS_SERVER"), + ) + + def save(self, filename, data): + self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) + + def load_once(self, filename: str) -> bytes: + data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + return data + + def load_stream(self, filename: str) -> Generator: + def generate(filename: str = filename) -> Generator: + response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response + while chunk := response.read(4096): yield chunk + + return generate() + + def download(self, filename, target_filepath): + self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath) + + def exists(self, filename): + res = self._get_meta(filename) + if res is None: + return False + return True + + def delete(self, filename): + self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename) + + def _get_meta(self, filename): + res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename) + if res.status < 300: + return res + else: + return None diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py index 424d441cdc..0858be3af6 100644 --- a/api/extensions/storage/s3_storage.py +++ b/api/extensions/storage/s3_storage.py @@ -35,6 +35,9 @@ class S3Storage(BaseStorage): # if bucket not exists, create it if e.response["Error"]["Code"] == "404": self.client.create_bucket(Bucket=self.bucket_name) + # if the bucket exists but is not accessible, skip creation + elif e.response["Error"]["Code"] == "403": + pass else: # other error, raise exception raise diff --git a/api/extensions/storage/volcengine_storage.py b/api/extensions/storage/volcengine_storage.py new file mode 100644 index 0000000000..f74ad2ee6d --- /dev/null +++ b/api/extensions/storage/volcengine_storage.py @@ -0,0 +1,48 @@ +from collections.abc import Generator + +import tos +from flask import Flask + +from extensions.storage.base_storage import BaseStorage + + +class VolcengineStorage(BaseStorage): + """Implementation for Volcengine TOS storage.""" + + def __init__(self, app: Flask): + super().__init__(app) + app_config = self.app.config + self.bucket_name = app_config.get("VOLCENGINE_TOS_BUCKET_NAME") + self.client = tos.TosClientV2( + ak=app_config.get("VOLCENGINE_TOS_ACCESS_KEY"), + sk=app_config.get("VOLCENGINE_TOS_SECRET_KEY"), + endpoint=app_config.get("VOLCENGINE_TOS_ENDPOINT"), + region=app_config.get("VOLCENGINE_TOS_REGION"), + ) + + def save(self, filename, data): + self.client.put_object(bucket=self.bucket_name, key=filename, content=data) + + def load_once(self, filename: str) ->
bytes: + data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + return data + + def load_stream(self, filename: str) -> Generator: + def generate(filename: str = filename) -> Generator: + response = self.client.get_object(bucket=self.bucket_name, key=filename) + while chunk := response.read(4096): + yield chunk + + return generate() + + def download(self, filename, target_filepath): + self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath) + + def exists(self, filename): + res = self.client.head_object(bucket=self.bucket_name, key=filename) + if res.status_code != 200: + return False + return True + + def delete(self, filename): + self.client.delete_object(bucket=self.bucket_name, key=filename) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 45fcb128ce..aa353a3cc1 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -58,6 +58,7 @@ app_detail_fields = { "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), "workflow": fields.Nested(workflow_partial_fields, allow_null=True), "tracing": fields.Raw, + "use_icon_as_answer_icon": fields.Boolean, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, @@ -91,6 +92,7 @@ app_partial_fields = { "icon_url": AppIconUrlField, "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "use_icon_as_answer_icon": fields.Boolean, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, @@ -140,6 +142,7 @@ site_fields = { "prompt_public": fields.Boolean, "app_base_url": fields.String, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, @@ -161,6 +164,7 @@ app_detail_fields_with_site = { "workflow": fields.Nested(workflow_partial_fields, allow_null=True), "site": fields.Nested(site_fields), "api_base_url": fields.String, + "use_icon_as_answer_icon": fields.Boolean, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, @@ -184,4 +188,5 @@ app_site_fields = { "customize_token_strategy": fields.String, "prompt_public": fields.Boolean, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, } diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 9afc1b1a4a..e0b3e340f6 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -10,6 +10,7 @@ app_fields = { "icon": fields.String, "icon_background": fields.String, "icon_url": AppIconUrlField, + "use_icon_as_answer_icon": fields.Boolean, } installed_app_fields = { diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 2d306edb40..f89902c5e8 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord from Crypto.Util.strxor import strxor -class PKCS1OAEP_Cipher: +class PKCS1OAepCipher: """Cipher object for PKCS#1 v1.5 OAEP. 
Do not create directly: use :func:`new` instead.""" @@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): if randfunc is None: randfunc = Random.get_random_bytes - return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc) + return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc) diff --git a/api/libs/helper.py b/api/libs/helper.py index 7e3c269e3f..5adce452ef 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -84,7 +84,7 @@ def timestamp_value(timestamp): raise ValueError(error) -class str_len: +class StrLen: """Restrict input to an integer in a range (inclusive)""" def __init__(self, max_length, argument="argument"): @@ -102,7 +102,7 @@ class str_len: return value -class float_range: +class FloatRange: """Restrict input to an float in a range (inclusive)""" def __init__(self, low, high, argument="argument"): @@ -121,7 +121,7 @@ class float_range: return value -class datetime_string: +class DatetimeString: def __init__(self, format, argument="argument"): self.format = format self.argument = argument diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41d6905899..39c17534e7 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -1,6 +1,6 @@ import json -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError def parse_json_markdown(json_string: str) -> dict: @@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"Got invalid JSON object. Error: {e}") for key in expected_keys: if key not in json_obj: - raise OutputParserException( + raise OutputParserError( f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" ) return json_obj diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py index 0fba6a87eb..8cd4ec552b 100644 --- a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -24,6 +24,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.drop_column('label') diff --git a/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py new file mode 100644 index 0000000000..4406d51ed0 --- /dev/null +++ b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py @@ -0,0 +1,45 @@ +"""add use_icon_as_answer_icon fields for app and site + +Revision ID: 030f4915f36a +Revises: d0187d6a88dd +Create Date: 2024-09-01 12:55:45.129687 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "030f4915f36a" +down_revision = "d0187d6a88dd" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column( + sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column( + sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("use_icon_as_answer_icon") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("use_icon_as_answer_icon") + + # ### end Alembic commands ### diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py new file mode 100644 index 0000000000..55824945da --- /dev/null +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -0,0 +1,35 @@ +"""add node_execution_id into node_executions + +Revision ID: 675b5321501b +Revises: 030f4915f36a +Create Date: 2024-08-12 10:54:02.259331 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '675b5321501b' +down_revision = '030f4915f36a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) + batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_id_idx') + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py index bfda7d619c..92f41f0abd 100644 --- a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -21,6 +21,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.drop_column('version') diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index 2365766837..fcca705d21 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -99,7 +99,7 @@ def upgrade(): id=id, tenant_id=tenant_id, user_id=user_id, - provider='google', + provider='google', encrypted_credentials=encrypted_credentials, created_at=created_at, updated_at=updated_at diff --git a/api/models/__init__.py b/api/models/__init__.py index 4012611471..30ceef057e 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -4,7 +4,7 @@ from .model import App, AppMode, Message from .types import StringUUID from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus -__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] +__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"] class CreatedByRole(Enum): @@ -12,11 +12,11 @@ class CreatedByRole(Enum): Enum class for createdByRole """ - ACCOUNT = 'account' - END_USER = 'end_user' + ACCOUNT = "account" + END_USER = "end_user" @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': + def value_of(cls, value: str) -> "CreatedByRole": """ Get value of given mode. 
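
`CreatedByRole.value_of` above follows the repo-wide enum lookup pattern; in isolation it is just a scan over the members:

```python
import enum

class Color(enum.Enum):  # hypothetical enum using the same pattern
    RED = "red"
    BLUE = "blue"

    @classmethod
    def value_of(cls, value: str) -> "Color":
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"invalid Color value {value}")

assert Color.value_of("red") is Color.RED
```
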
@@ -26,4 +26,4 @@ class CreatedByRole(Enum): for role in cls: if role.value == value: return role - raise ValueError(f'invalid createdByRole value {value}') + raise ValueError(f"invalid createdByRole value {value}") diff --git a/api/models/account.py b/api/models/account.py index 67d940b7b7..60b4f11aad 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -9,21 +9,18 @@ from .types import StringUUID class AccountStatus(str, enum.Enum): - PENDING = 'pending' - UNINITIALIZED = 'uninitialized' - ACTIVE = 'active' - BANNED = 'banned' - CLOSED = 'closed' + PENDING = "pending" + UNINITIALIZED = "uninitialized" + ACTIVE = "active" + BANNED = "banned" + CLOSED = "closed" class Account(UserMixin, db.Model): - __tablename__ = 'accounts' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_pkey'), - db.Index('account_email_idx', 'email') - ) + __tablename__ = "accounts" + __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -34,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def is_password_set(self): @@ -65,11 +62,13 @@ class Account(UserMixin, db.Model): @current_tenant_id.setter def current_tenant_id(self, value: str): try: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == value) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.account_id == self.id) \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == value) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == self.id) .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join @@ -91,20 +90,18 @@ class Account(UserMixin, db.Model): @classmethod def get_by_openid(cls, provider: str, open_id: str) -> db.Model: - account_integrate = db.session.query(AccountIntegrate). \ - filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id). \ - one_or_none() + account_integrate = ( + db.session.query(AccountIntegrate) + .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .one_or_none() + ) if account_integrate: - return db.session.query(Account). \ - filter(Account.id == account_integrate.account_id). 
\ - one_or_none() + return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return None def get_integrates(self) -> list[db.Model]: ai = db.Model - return db.session.query(ai).filter( - ai.account_id == self.id - ).all() + return db.session.query(ai).filter(ai.account_id == self.id).all() # check current_user.current_tenant.current_role in ['admin', 'owner'] @property @@ -123,61 +120,75 @@ class Account(UserMixin, db.Model): def is_dataset_operator(self): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR + class TenantStatus(str, enum.Enum): - NORMAL = 'normal' - ARCHIVE = 'archive' + NORMAL = "normal" + ARCHIVE = "archive" class TenantAccountRole(str, enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - EDITOR = 'editor' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + EDITOR = "editor" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" @staticmethod def is_valid_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.NORMAL, TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } @staticmethod def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} - + @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, - TenantAccountRole.DATASET_OPERATOR} - + return role and role in { + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + @staticmethod def is_editing_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.DATASET_OPERATOR, + } + class Tenant(db.Model): - __tablename__ = 'tenants' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_pkey'), - ) + __tablename__ = "tenants" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def get_accounts(self) -> list[Account]: - return 
db.session.query(Account).filter( - Account.id == TenantAccountJoin.account_id, - TenantAccountJoin.tenant_id == self.id - ).all() + return ( + db.session.query(Account) + .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .all() + ) @property def custom_config_dict(self) -> dict: @@ -189,54 +200,54 @@ class Tenant(db.Model): class TenantAccountJoinRole(enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" class TenantAccountJoin(db.Model): - __tablename__ = 'tenant_account_joins' + __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), - db.Index('tenant_account_join_account_id_idx', 'account_id'), - db.Index('tenant_account_join_tenant_id_idx', 'tenant_id'), - db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') + db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + db.Index("tenant_account_join_account_id_idx", "account_id"), + db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) - current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - role = db.Column(db.String(16), nullable=False, server_default='normal') + current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AccountIntegrate(db.Model): - __tablename__ = 'account_integrates' + __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_integrate_pkey'), - db.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), - db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') + db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class InvitationCode(db.Model): - __tablename__ = 'invitation_codes' + __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='invitation_code_pkey'), - db.Index('invitation_codes_batch_idx', 'batch'), - db.Index('invitation_codes_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + db.Index("invitation_codes_batch_idx", "batch"), + db.Index("invitation_codes_code_idx", "code", "status"), ) id = db.Column(db.Integer, nullable=False) @@ -247,4 +258,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 7f69323628..97173747af 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -6,22 +6,22 @@ from .types import StringUUID class APIBasedExtensionPoint(enum.Enum): - APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' - PING = 'ping' - APP_MODERATION_INPUT = 'app.moderation.input' - APP_MODERATION_OUTPUT = 'app.moderation.output' + APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" + PING = "ping" + APP_MODERATION_INPUT = "app.moderation.input" + APP_MODERATION_OUTPUT = "app.moderation.output" class APIBasedExtension(db.Model): - __tablename__ = 'api_based_extensions' + __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), - db.Index('api_based_extension_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/dataset.py b/api/models/dataset.py index 203031c7b9..55f6ed3180 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -14,7 +14,7 @@ from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from extensions.ext_storage import storage @@ -24,37 +24,34 @@ from .types import StringUUID class DatasetPermissionEnum(str, enum.Enum): - ONLY_ME = 'only_me' - ALL_TEAM = 'all_team_members' - PARTIAL_TEAM = 'partial_members' + ONLY_ME = "only_me" + ALL_TEAM = "all_team_members" + PARTIAL_TEAM = "partial_members" + class Dataset(db.Model): - __tablename__ = 'datasets' + __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_pkey'), - 
db.Index('dataset_tenant_idx', 'tenant_id'), - db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="dataset_pkey"), + db.Index("dataset_tenant_idx", "tenant_id"), + db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) - INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] + INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, - server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, - server_default=db.text("'only_me'::character varying")) + provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) + permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) data_source_type = db.Column(db.String(255)) indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -62,8 +59,9 @@ class Dataset(db.Model): @property def dataset_keyword_table(self): - dataset_keyword_table = db.session.query(DatasetKeywordTable).filter( - DatasetKeywordTable.dataset_id == self.id).first() + dataset_keyword_table = ( + db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + ) if dataset_keyword_table: return dataset_keyword_table @@ -79,13 +77,19 @@ class Dataset(db.Model): @property def latest_process_rule(self): - return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \ - .order_by(DatasetProcessRule.created_at.desc()).first() + return ( + DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + .order_by(DatasetProcessRule.created_at.desc()) + .first() + ) @property def app_count(self): - return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id, - App.id == AppDatasetJoin.app_id).scalar() + return ( + db.session.query(func.count(AppDatasetJoin.id)) + .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .scalar() + ) @property def document_count(self): @@ -93,30 +97,40 @@ class Dataset(db.Model): @property def available_document_count(self): - return db.session.query(func.count(Document.id)).filter( - Document.dataset_id == self.id, - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False - ).scalar() + return ( + db.session.query(func.count(Document.id)) + .filter( + 
Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) @property def available_segment_count(self): - return db.session.query(func.count(DocumentSegment.id)).filter( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).scalar() + return ( + db.session.query(func.count(DocumentSegment.id)) + .filter( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .scalar() + ) @property def word_count(self): - return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ - .filter(Document.dataset_id == self.id).scalar() + return ( + Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + .filter(Document.dataset_id == self.id) + .scalar() + ) @property def doc_form(self): - document = db.session.query(Document).filter( - Document.dataset_id == self.id).first() + document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -124,76 +138,68 @@ class Dataset(db.Model): @property def retrieval_model_dict(self): default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } return self.retrieval_model if self.retrieval_model else default_retrieval_model @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'knowledge' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "knowledge", + ) + .all() + ) return tags if tags else [] @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") - return f'Vector_index_{normalized_dataset_id}_Node' + return f"Vector_index_{normalized_dataset_id}_Node" class DatasetProcessRule(db.Model): - __tablename__ = 'dataset_process_rules' + __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'), - db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, - server_default=db.text("'automatic'::character varying")) + mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - 
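[Editor's note] Most of the dataset.py rewrites above replace backslash continuations with Black's parenthesized method-chain style: the whole query is wrapped in parentheses and each call sits on its own line. A self-contained sketch of the pattern, with a toy `Doc` model standing in for the real `Document` (SQLite in-memory, so it runs as-is):

from sqlalchemy import Column, Integer, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Doc(Base):  # hypothetical stand-in for the Document model above
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    indexing_status = Column(String, default="waiting")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(indexing_status="completed"), Doc()])
    session.commit()
    # Black's parenthesized chain, as adopted in the hunks above:
    completed = (
        session.query(func.count(Doc.id))
        .filter(Doc.indexing_status == "completed")
        .scalar()
    )
    print(completed)  # 1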
created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - MODES = ['automatic', 'custom'] - PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails'] + MODES = ["automatic", "custom"] + PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES = { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } def to_dict(self): return { - 'id': self.id, - 'dataset_id': self.dataset_id, - 'mode': self.mode, - 'rules': self.rules_dict, - 'created_by': self.created_by, - 'created_at': self.created_at, + "id": self.id, + "dataset_id": self.dataset_id, + "mode": self.mode, + "rules": self.rules_dict, + "created_by": self.created_by, + "created_at": self.created_at, } @property @@ -205,17 +211,16 @@ class DatasetProcessRule(db.Model): class Document(db.Model): - __tablename__ = 'documents' + __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_pkey'), - db.Index('document_dataset_id_idx', 'dataset_id'), - db.Index('document_is_paused_idx', 'is_paused'), - db.Index('document_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_pkey"), + db.Index("document_dataset_id_idx", "dataset_id"), + db.Index("document_is_paused_idx", "is_paused"), + db.Index("document_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) @@ -227,8 +232,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -250,7 +254,7 @@ class Document(db.Model): completed_at = db.Column(db.DateTime, nullable=True) # pause - is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) paused_by = db.Column(StringUUID, nullable=True) paused_at = db.Column(db.DateTime, nullable=True) @@ -259,44 +263,39 @@ class Document(db.Model): stopped_at = db.Column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String( - 255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled = db.Column(db.Boolean, 
nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - archived = db.Column(db.Boolean, nullable=False, - server_default=db.text('false')) + archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) - doc_form = db.Column(db.String( - 255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) - DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] + DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @property def display_status(self): status = None - if self.indexing_status == 'waiting': - status = 'queuing' - elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused: - status = 'paused' - elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']: - status = 'indexing' - elif self.indexing_status == 'error': - status = 'error' - elif self.indexing_status == 'completed' and not self.archived and self.enabled: - status = 'available' - elif self.indexing_status == 'completed' and not self.archived and not self.enabled: - status = 'disabled' - elif self.indexing_status == 'completed' and self.archived: - status = 'archived' + if self.indexing_status == "waiting": + status = "queuing" + elif self.indexing_status not in ["completed", "error", "waiting"] and self.is_paused: + status = "paused" + elif self.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + status = "indexing" + elif self.indexing_status == "error": + status = "error" + elif self.indexing_status == "completed" and not self.archived and self.enabled: + status = "available" + elif self.indexing_status == "completed" and not self.archived and not self.enabled: + status = "disabled" + elif self.indexing_status == "completed" and self.archived: + status = "archived" return status @property @@ -313,24 +312,26 @@ class Document(db.Model): @property def data_source_detail_dict(self): if self.data_source_info: - if self.data_source_type == 'upload_file': + if self.data_source_type == "upload_file": data_source_info_dict = json.loads(self.data_source_info) - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info_dict['upload_file_id']). 
\ - one_or_none() + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .one_or_none() + ) if file_detail: return { - 'upload_file': { - 'id': file_detail.id, - 'name': file_detail.name, - 'size': file_detail.size, - 'extension': file_detail.extension, - 'mime_type': file_detail.mime_type, - 'created_by': file_detail.created_by, - 'created_at': file_detail.created_at.timestamp() + "upload_file": { + "id": file_detail.id, + "name": file_detail.name, + "size": file_detail.size, + "extension": file_detail.extension, + "mime_type": file_detail.mime_type, + "created_by": file_detail.created_by, + "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': + elif self.data_source_type == "notion_import" or self.data_source_type == "website_crawl": return json.loads(self.data_source_info) return {} @@ -356,120 +357,123 @@ class Document(db.Model): @property def hit_count(self): - return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \ - .filter(DocumentSegment.document_id == self.id).scalar() + return ( + DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + .filter(DocumentSegment.document_id == self.id) + .scalar() + ) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'dataset_id': self.dataset_id, - 'position': self.position, - 'data_source_type': self.data_source_type, - 'data_source_info': self.data_source_info, - 'dataset_process_rule_id': self.dataset_process_rule_id, - 'batch': self.batch, - 'name': self.name, - 'created_from': self.created_from, - 'created_by': self.created_by, - 'created_api_request_id': self.created_api_request_id, - 'created_at': self.created_at, - 'processing_started_at': self.processing_started_at, - 'file_id': self.file_id, - 'word_count': self.word_count, - 'parsing_completed_at': self.parsing_completed_at, - 'cleaning_completed_at': self.cleaning_completed_at, - 'splitting_completed_at': self.splitting_completed_at, - 'tokens': self.tokens, - 'indexing_latency': self.indexing_latency, - 'completed_at': self.completed_at, - 'is_paused': self.is_paused, - 'paused_by': self.paused_by, - 'paused_at': self.paused_at, - 'error': self.error, - 'stopped_at': self.stopped_at, - 'indexing_status': self.indexing_status, - 'enabled': self.enabled, - 'disabled_at': self.disabled_at, - 'disabled_by': self.disabled_by, - 'archived': self.archived, - 'archived_reason': self.archived_reason, - 'archived_by': self.archived_by, - 'archived_at': self.archived_at, - 'updated_at': self.updated_at, - 'doc_type': self.doc_type, - 'doc_metadata': self.doc_metadata, - 'doc_form': self.doc_form, - 'doc_language': self.doc_language, - 'display_status': self.display_status, - 'data_source_info_dict': self.data_source_info_dict, - 'average_segment_length': self.average_segment_length, - 'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - 'dataset': self.dataset.to_dict() if self.dataset else None, - 'segment_count': self.segment_count, - 'hit_count': self.hit_count + "id": self.id, + "tenant_id": self.tenant_id, + "dataset_id": self.dataset_id, + "position": self.position, + "data_source_type": self.data_source_type, + "data_source_info": self.data_source_info, + "dataset_process_rule_id": self.dataset_process_rule_id, + "batch": self.batch, + "name": self.name, + "created_from": 
self.created_from, + "created_by": self.created_by, + "created_api_request_id": self.created_api_request_id, + "created_at": self.created_at, + "processing_started_at": self.processing_started_at, + "file_id": self.file_id, + "word_count": self.word_count, + "parsing_completed_at": self.parsing_completed_at, + "cleaning_completed_at": self.cleaning_completed_at, + "splitting_completed_at": self.splitting_completed_at, + "tokens": self.tokens, + "indexing_latency": self.indexing_latency, + "completed_at": self.completed_at, + "is_paused": self.is_paused, + "paused_by": self.paused_by, + "paused_at": self.paused_at, + "error": self.error, + "stopped_at": self.stopped_at, + "indexing_status": self.indexing_status, + "enabled": self.enabled, + "disabled_at": self.disabled_at, + "disabled_by": self.disabled_by, + "archived": self.archived, + "archived_reason": self.archived_reason, + "archived_by": self.archived_by, + "archived_at": self.archived_at, + "updated_at": self.updated_at, + "doc_type": self.doc_type, + "doc_metadata": self.doc_metadata, + "doc_form": self.doc_form, + "doc_language": self.doc_language, + "display_status": self.display_status, + "data_source_info_dict": self.data_source_info_dict, + "average_segment_length": self.average_segment_length, + "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, + "dataset": self.dataset.to_dict() if self.dataset else None, + "segment_count": self.segment_count, + "hit_count": self.hit_count, } @classmethod def from_dict(cls, data: dict): return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - dataset_id=data.get('dataset_id'), - position=data.get('position'), - data_source_type=data.get('data_source_type'), - data_source_info=data.get('data_source_info'), - dataset_process_rule_id=data.get('dataset_process_rule_id'), - batch=data.get('batch'), - name=data.get('name'), - created_from=data.get('created_from'), - created_by=data.get('created_by'), - created_api_request_id=data.get('created_api_request_id'), - created_at=data.get('created_at'), - processing_started_at=data.get('processing_started_at'), - file_id=data.get('file_id'), - word_count=data.get('word_count'), - parsing_completed_at=data.get('parsing_completed_at'), - cleaning_completed_at=data.get('cleaning_completed_at'), - splitting_completed_at=data.get('splitting_completed_at'), - tokens=data.get('tokens'), - indexing_latency=data.get('indexing_latency'), - completed_at=data.get('completed_at'), - is_paused=data.get('is_paused'), - paused_by=data.get('paused_by'), - paused_at=data.get('paused_at'), - error=data.get('error'), - stopped_at=data.get('stopped_at'), - indexing_status=data.get('indexing_status'), - enabled=data.get('enabled'), - disabled_at=data.get('disabled_at'), - disabled_by=data.get('disabled_by'), - archived=data.get('archived'), - archived_reason=data.get('archived_reason'), - archived_by=data.get('archived_by'), - archived_at=data.get('archived_at'), - updated_at=data.get('updated_at'), - doc_type=data.get('doc_type'), - doc_metadata=data.get('doc_metadata'), - doc_form=data.get('doc_form'), - doc_language=data.get('doc_language') + id=data.get("id"), + tenant_id=data.get("tenant_id"), + dataset_id=data.get("dataset_id"), + position=data.get("position"), + data_source_type=data.get("data_source_type"), + data_source_info=data.get("data_source_info"), + dataset_process_rule_id=data.get("dataset_process_rule_id"), + batch=data.get("batch"), + name=data.get("name"), + 
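[Editor's note] `to_dict` and `from_dict` above are a hand-written serialization pair; `from_dict` tolerates partial payloads because every field goes through `dict.get`. The round-trip contract on a toy class (not the real `Document`), as a sketch:

class Toy:
    def __init__(self, id=None, name=None, enabled=None):
        self.id = id
        self.name = name
        self.enabled = enabled

    def to_dict(self) -> dict:
        return {"id": self.id, "name": self.name, "enabled": self.enabled}

    @classmethod
    def from_dict(cls, data: dict) -> "Toy":
        # dict.get tolerates missing keys, as in Document.from_dict above.
        return cls(id=data.get("id"), name=data.get("name"), enabled=data.get("enabled"))


t = Toy.from_dict({"id": "1", "name": "spec"})
assert t.to_dict() == {"id": "1", "name": "spec", "enabled": None}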
created_from=data.get("created_from"), + created_by=data.get("created_by"), + created_api_request_id=data.get("created_api_request_id"), + created_at=data.get("created_at"), + processing_started_at=data.get("processing_started_at"), + file_id=data.get("file_id"), + word_count=data.get("word_count"), + parsing_completed_at=data.get("parsing_completed_at"), + cleaning_completed_at=data.get("cleaning_completed_at"), + splitting_completed_at=data.get("splitting_completed_at"), + tokens=data.get("tokens"), + indexing_latency=data.get("indexing_latency"), + completed_at=data.get("completed_at"), + is_paused=data.get("is_paused"), + paused_by=data.get("paused_by"), + paused_at=data.get("paused_at"), + error=data.get("error"), + stopped_at=data.get("stopped_at"), + indexing_status=data.get("indexing_status"), + enabled=data.get("enabled"), + disabled_at=data.get("disabled_at"), + disabled_by=data.get("disabled_by"), + archived=data.get("archived"), + archived_reason=data.get("archived_reason"), + archived_by=data.get("archived_by"), + archived_at=data.get("archived_at"), + updated_at=data.get("updated_at"), + doc_type=data.get("doc_type"), + doc_metadata=data.get("doc_metadata"), + doc_form=data.get("doc_form"), + doc_language=data.get("doc_language"), ) + class DocumentSegment(db.Model): - __tablename__ = 'document_segments' + __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_segment_pkey'), - db.Index('document_segment_dataset_id_idx', 'dataset_id'), - db.Index('document_segment_document_id_idx', 'document_id'), - db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), - db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), - db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), - db.Index('document_segment_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_segment_pkey"), + db.Index("document_segment_dataset_id_idx", "dataset_id"), + db.Index("document_segment_document_id_idx", "document_id"), + db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"), + db.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -486,18 +490,14 @@ class DocumentSegment(db.Model): # basic fields hit_count = db.Column(db.Integer, nullable=False, default=0) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, - server_default=db.text("'waiting'::character varying")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -513,17 +513,19 @@ class DocumentSegment(db.Model): @property def previous_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position - 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) + .first() + ) @property def next_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position + 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) + .first() + ) def get_sign_content(self): pattern = r"/files/([a-f0-9\-]+)/image-preview" @@ -535,7 +537,7 @@ class DocumentSegment(db.Model): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -546,21 +548,20 @@ class DocumentSegment(db.Model): # Reconstruct the text with signed URLs offset = 0 for start, end, signed_url in signed_urls: - text = text[:start + offset] + signed_url + text[end + offset:] + text = text[: start + offset] + signed_url + text[end + offset :] offset += len(signed_url) - (end - start) return text - class AppDatasetJoin(db.Model): - __tablename__ = 'app_dataset_joins' + __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'), - db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -571,13 +572,13 @@ class AppDatasetJoin(db.Model): class DatasetQuery(db.Model): - __tablename__ = 'dataset_queries' + __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_query_pkey'), - db.Index('dataset_query_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, 
server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) source = db.Column(db.String(255), nullable=False) @@ -588,17 +589,18 @@ class DatasetQuery(db.Model): class DatasetKeywordTable(db.Model): - __tablename__ = 'dataset_keyword_tables' + __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), - db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False, unique=True) keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.String(255), nullable=False, - server_default=db.text("'database'::character varying")) + data_source_type = db.Column( + db.String(255), nullable=False, server_default=db.text("'database'::character varying") + ) @property def keyword_table_dict(self): @@ -614,19 +616,17 @@ class DatasetKeywordTable(db.Model): return dct # get dataset - dataset = Dataset.query.filter_by( - id=self.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=self.dataset_id).first() if not dataset: return None - if self.data_source_type == 'database': + if self.data_source_type == "database": return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None else: - file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt' + file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt" try: keyword_table_text = storage.load_once(file_key) if keyword_table_text: - return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder) + return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) return None except Exception as e: logging.exception(str(e)) @@ -634,21 +634,21 @@ class DatasetKeywordTable(db.Model): class Embedding(db.Model): - __tablename__ = 'embeddings' + __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='embedding_pkey'), - db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx'), - db.Index('created_at_idx', 'created_at') + db.PrimaryKeyConstraint("id", name="embedding_pkey"), + db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - model_name = db.Column(db.String(255), nullable=False, - server_default=db.text("'text-embedding-ada-002'::character varying")) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = db.Column( + db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - provider_name = db.Column(db.String(255), nullable=False, - server_default=db.text("''::character varying")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + provider_name = 
db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -658,33 +658,32 @@ class Embedding(db.Model): class DatasetCollectionBinding(db.Model): - __tablename__ = 'dataset_collection_bindings' + __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'), - db.Index('provider_model_name_idx', 'provider_name', 'model_name') - + db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class DatasetPermission(db.Model): - __tablename__ = 'dataset_permissions' + __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'), - db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'), - db.Index('idx_dataset_permissions_account_id', 'account_id'), - db.Index('idx_dataset_permissions_tenant_id', 'tenant_id') + db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + db.Index("idx_dataset_permissions_account_id", "account_id"), + db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) dataset_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) - has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/model.py b/api/models/model.py index e2d1fcfc23..8ab3026522 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -20,25 +20,23 @@ from .types import StringUUID class DifySetup(db.Model): - __tablename__ = 'dify_setups' - __table_args__ = ( - db.PrimaryKeyConstraint('version', name='dify_setup_pkey'), - ) + __tablename__ = "dify_setups" + __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AppMode(Enum): - COMPLETION = 'completion' - WORKFLOW = 'workflow' - CHAT = 'chat' - 
ADVANCED_CHAT = 'advanced-chat' - AGENT_CHAT = 'agent-chat' - CHANNEL = 'channel' + COMPLETION = "completion" + WORKFLOW = "workflow" + CHAT = "chat" + ADVANCED_CHAT = "advanced-chat" + AGENT_CHAT = "agent-chat" + CHANNEL = "channel" @classmethod - def value_of(cls, value: str) -> 'AppMode': + def value_of(cls, value: str) -> "AppMode": """ Get value of given mode. @@ -48,21 +46,19 @@ class AppMode(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class IconType(Enum): IMAGE = "image" EMOJI = "emoji" -class App(db.Model): - __tablename__ = 'apps' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_pkey'), - db.Index('app_tenant_id_idx', 'tenant_id') - ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) +class App(db.Model): + __tablename__ = "apps" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) @@ -75,17 +71,18 @@ class App(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def desc_or_prompt(self): @@ -96,7 +93,7 @@ class App(db.Model): if app_model_config: return app_model_config.pre_prompt else: - return '' + return "" @property def site(self): @@ -104,24 +101,24 @@ class App(db.Model): return site @property - def app_model_config(self) -> Optional['AppModelConfig']: + def app_model_config(self) -> Optional["AppModelConfig"]: if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == 
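[Editor's note] `AppMode.value_of` above is a linear scan over the enum members that raises ValueError for unknown values. The same behavior on a tiny enum, as a sketch; note that plain call syntax (`Mode("chat")`) would also raise ValueError on a miss, so the helper's main contribution is the custom error message:

from enum import Enum


class Mode(Enum):
    CHAT = "chat"
    WORKFLOW = "workflow"

    @classmethod
    def value_of(cls, value: str) -> "Mode":
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid mode value {value}")


assert Mode.value_of("chat") is Mode.CHAT
assert Mode("chat") is Mode.CHAT  # built-in lookup, same ValueError contract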
self.app_model_config_id).first() return None @property - def workflow(self) -> Optional['Workflow']: + def workflow(self) -> Optional["Workflow"]: if self.workflow_id: from .workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return None @property def api_base_url(self): - return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1" @property def tenant(self): @@ -135,8 +132,9 @@ class App(db.Model): return False if not app_model_config.agent_mode: return False - if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( + "strategy", "" + ) in ["function_call", "react"]: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -158,16 +156,16 @@ class App(db.Model): if not app_model_config.agent_mode: return [] agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) + tools = agent_mode.get("tools", []) provider_ids = [] for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api': + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api": # check if provider id is a uuid string, if not, skip try: uuid.UUID(provider_id) @@ -179,8 +177,7 @@ class App(db.Model): return [] api_providers = db.session.execute( - text('SELECT id FROM tool_api_providers WHERE id IN :provider_ids'), - {'provider_ids': tuple(provider_ids)} + text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} ).fetchall() deleted_tools = [] @@ -189,44 +186,43 @@ class App(db.Model): for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api' and provider_id not in current_api_provider_ids: - deleted_tools.append(tool['tool_name']) + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api" and provider_id not in current_api_provider_ids: + deleted_tools.append(tool["tool_name"]) return deleted_tools @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'app' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "app", + ) + .all() + ) return tags if tags else [] class AppModelConfig(db.Model): - __tablename__ = 'app_model_configs' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), - db.Index('app_app_id_idx', 'app_id') - ) + __tablename__ = "app_model_configs" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id 
= db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -262,28 +258,29 @@ class AppModelConfig(db.Model): @property def suggested_questions_after_answer_dict(self) -> dict: - return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \ + return ( + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer else {"enabled": False} + ) @property def speech_to_text_dict(self) -> dict: - return json.loads(self.speech_to_text) if self.speech_to_text \ - else {"enabled": False} + return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} @property def text_to_speech_dict(self) -> dict: - return json.loads(self.text_to_speech) if self.text_to_speech \ - else {"enabled": False} + return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} @property def retriever_resource_dict(self) -> dict: - return json.loads(self.retriever_resource) if self.retriever_resource \ - else {"enabled": True} + return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} @property def annotation_reply_dict(self) -> dict: - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == self.app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -292,8 +289,8 @@ class AppModelConfig(db.Model): "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } else: @@ -305,13 +302,15 @@ class AppModelConfig(db.Model): @property def sensitive_word_avoidance_dict(self) -> dict: - return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ + return ( + json.loads(self.sensitive_word_avoidance) + if self.sensitive_word_avoidance else {"enabled": False, "type": "", "configs": []} + ) @property def external_data_tools_list(self) -> list[dict]: - return json.loads(self.external_data_tools) if self.external_data_tools \ - else [] + return json.loads(self.external_data_tools) if self.external_data_tools else [] @property def user_input_form_list(self) -> dict: @@ -319,8 +318,11 @@ class AppModelConfig(db.Model): @property def agent_mode_dict(self) -> dict: - return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], - 
"prompt": None} + return ( + json.loads(self.agent_mode) + if self.agent_mode + else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + ) @property def chat_prompt_config_dict(self) -> dict: @@ -334,19 +336,28 @@ class AppModelConfig(db.Model): def dataset_configs_dict(self) -> dict: if self.dataset_configs: dataset_configs = json.loads(self.dataset_configs) - if 'retrieval_model' not in dataset_configs: - return {'retrieval_model': 'single'} + if "retrieval_model" not in dataset_configs: + return {"retrieval_model": "single"} else: return dataset_configs return { - 'retrieval_model': 'multiple', - } + "retrieval_model": "multiple", + } @property def file_upload_dict(self) -> dict: - return json.loads(self.file_upload) if self.file_upload else { - "image": {"enabled": False, "number_limits": 3, "detail": "high", - "transfer_methods": ["remote_url", "local_file"]}} + return ( + json.loads(self.file_upload) + if self.file_upload + else { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + } + ) def to_dict(self) -> dict: return { @@ -369,44 +380,53 @@ class AppModelConfig(db.Model): "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, - "file_upload": self.file_upload_dict + "file_upload": self.file_upload_dict, } def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config.get('opening_statement') - self.suggested_questions = json.dumps(model_config['suggested_questions']) \ - if model_config.get('suggested_questions') else None - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ - if model_config.get('suggested_questions_after_answer') else None - self.speech_to_text = json.dumps(model_config['speech_to_text']) \ - if model_config.get('speech_to_text') else None - self.text_to_speech = json.dumps(model_config['text_to_speech']) \ - if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) \ - if model_config.get('more_like_this') else None - self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ - if model_config.get('sensitive_word_avoidance') else None - self.external_data_tools = json.dumps(model_config['external_data_tools']) \ - if model_config.get('external_data_tools') else None - self.model = json.dumps(model_config['model']) \ - if model_config.get('model') else None - self.user_input_form = json.dumps(model_config['user_input_form']) \ - if model_config.get('user_input_form') else None - self.dataset_query_variable = model_config.get('dataset_query_variable') - self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) \ - if model_config.get('agent_mode') else None - self.retriever_resource = json.dumps(model_config['retriever_resource']) \ - if model_config.get('retriever_resource') else None - self.prompt_type = model_config.get('prompt_type', 'simple') - self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \ - if model_config.get('chat_prompt_config') else None - self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \ - if model_config.get('completion_prompt_config') else None - self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ - if model_config.get('dataset_configs') else 
None - self.file_upload = json.dumps(model_config.get('file_upload')) \ - if model_config.get('file_upload') else None + self.opening_statement = model_config.get("opening_statement") + self.suggested_questions = ( + json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + ) + self.suggested_questions_after_answer = ( + json.dumps(model_config["suggested_questions_after_answer"]) + if model_config.get("suggested_questions_after_answer") + else None + ) + self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None + self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None + self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.sensitive_word_avoidance = ( + json.dumps(model_config["sensitive_word_avoidance"]) + if model_config.get("sensitive_word_avoidance") + else None + ) + self.external_data_tools = ( + json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + ) + self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.user_input_form = ( + json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + ) + self.dataset_query_variable = model_config.get("dataset_query_variable") + self.pre_prompt = model_config["pre_prompt"] + self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.retriever_resource = ( + json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + ) + self.prompt_type = model_config.get("prompt_type", "simple") + self.chat_prompt_config = ( + json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None + ) + self.completion_prompt_config = ( + json.dumps(model_config.get("completion_prompt_config")) + if model_config.get("completion_prompt_config") + else None + ) + self.dataset_configs = ( + json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None + ) + self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None return self def copy(self): @@ -431,21 +451,21 @@ class AppModelConfig(db.Model): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload + file_upload=self.file_upload, ) return new_app_model_config class RecommendedApp(db.Model): - __tablename__ = 'recommended_apps' + __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='recommended_app_pkey'), - db.Index('recommended_app_app_id_idx', 'app_id'), - db.Index('recommended_app_is_listed_idx', 'is_listed', 'language') + db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + db.Index("recommended_app_app_id_idx", "app_id"), + db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) @@ -456,8 +476,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, 
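[Editor's note] `from_model_config_dict` above applies the inverse pattern when persisting: `json.dumps` the sub-config if present, else store None. Sketched as a hypothetical helper (the real method writes each attribute out explicitly, as the hunk shows):

import json
from typing import Optional


def dumps_or_none(config: dict, key: str) -> Optional[str]:
    # e.g. self.agent_mode = json.dumps(model_config["agent_mode"]) if ... else None
    return json.dumps(config[key]) if config.get(key) else None


cfg = {"agent_mode": {"enabled": True}, "retriever_resource": None}
assert dumps_or_none(cfg, "agent_mode") == '{"enabled": true}'
assert dumps_or_none(cfg, "retriever_resource") is None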
nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -466,22 +486,22 @@ class RecommendedApp(db.Model): class InstalledApp(db.Model): - __tablename__ = 'installed_apps' + __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='installed_app_pkey'), - db.Index('installed_app_tenant_id_idx', 'tenant_id'), - db.Index('installed_app_app_id_idx', 'app_id'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + db.PrimaryKeyConstraint("id", name="installed_app_pkey"), + db.Index("installed_app_tenant_id_idx", "tenant_id"), + db.Index("installed_app_app_id_idx", "app_id"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) app_owner_tenant_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False, default=0) - is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -495,13 +515,13 @@ class InstalledApp(db.Model): class Conversation(db.Model): - __tablename__ = 'conversations' + __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='conversation_pkey'), - db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') + db.PrimaryKeyConstraint("id", name="conversation_pkey"), + db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -513,7 +533,7 @@ class Conversation(db.Model): inputs = db.Column(db.JSON) introduction = db.Column(db.Text) system_instruction = db.Column(db.Text) - system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) @@ -522,13 +542,15 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = 
mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") + message_annotations = db.relationship( + "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + ) - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def model_config(self): @@ -541,20 +563,21 @@ class Conversation(db.Model): if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - if 'model' in override_model_configs: + if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) model_config = app_model_config.to_dict() else: - model_config['configs'] = override_model_configs + model_config["configs"] = override_model_configs else: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == self.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + ) model_config = app_model_config.to_dict() - model_config['model_id'] = self.model_id - model_config['provider'] = self.model_provider + model_config["model_id"] = self.model_id + model_config["provider"] = self.model_provider return model_config @@ -567,7 +590,7 @@ class Conversation(db.Model): if first_message: return first_message.query else: - return '' + return "" @property def annotated(self): @@ -583,31 +606,51 @@ class Conversation(db.Model): @property def user_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def admin_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + 
MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def first_message(self): @@ -641,33 +684,33 @@ class Conversation(db.Model): class Message(db.Model): - __tablename__ = 'messages' + __tablename__ = "messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_pkey'), - db.Index('message_app_id_idx', 'app_id', 'created_at'), - db.Index('message_conversation_id_idx', 'conversation_id'), - db.Index('message_end_user_idx', 'app_id', 'from_source', 'from_end_user_id'), - db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), - db.Index('message_workflow_run_id_idx', 'conversation_id', 'workflow_run_id') + db.PrimaryKeyConstraint("id", name="message_pkey"), + db.Index("message_app_id_idx", "app_id", "created_at"), + db.Index("message_conversation_id_idx", "conversation_id"), + db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), + db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), + db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) inputs = db.Column(db.JSON) query = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) - message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) answer = db.Column(db.Text, nullable=False) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character 
varying")) @@ -677,9 +720,9 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @property @@ -687,7 +730,7 @@ class Message(db.Model): if not self.answer: return self.answer - pattern = r'\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)' + pattern = r"\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)" matches = re.findall(pattern, self.answer) if not matches: @@ -703,9 +746,9 @@ class Message(db.Model): re_sign_file_url_answer = self.answer for url in urls: - if 'files/tools' in url: + if "files/tools" in url: # get tool file id - tool_file_id_pattern = r'\/files\/tools\/([\.\w-]+)?\?timestamp=' + tool_file_id_pattern = r"\/files\/tools\/([\.\w-]+)?\?timestamp=" result = re.search(tool_file_id_pattern, url) if not result: continue @@ -713,25 +756,24 @@ class Message(db.Model): tool_file_id = result.group(1) # get extension - if '.' in tool_file_id: - split_result = tool_file_id.split('.') - extension = f'.{split_result[-1]}' + if "." in tool_file_id: + split_result = tool_file_id.split(".") + extension = f".{split_result[-1]}" if len(extension) > 10: - extension = '.bin' + extension = ".bin" tool_file_id = split_result[0] else: - extension = '.bin' + extension = ".bin" if not tool_file_id: continue sign_url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, - extension=extension + tool_file_id=tool_file_id, extension=extension ) else: # get upload file id - upload_file_id_pattern = r'\/files\/([\w-]+)\/image-preview?\?timestamp=' + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" result = re.search(upload_file_id_pattern, url) if not result: continue @@ -749,14 +791,20 @@ class Message(db.Model): @property def user_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'user').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .first() + ) return feedback @property def admin_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'admin').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") + .first() + ) return feedback @property @@ -771,11 +819,15 @@ class Message(db.Model): @property def annotation_hit_history(self): - annotation_history = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.message_id == self.id).first()) + annotation_history = ( + db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == 
self.id).first() + ) if annotation_history: - annotation = (db.session.query(MessageAnnotation). - filter(MessageAnnotation.id == annotation_history.annotation_id).first()) + annotation = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.id == annotation_history.annotation_id) + .first() + ) return annotation return None @@ -783,8 +835,9 @@ class Message(db.Model): def app_model_config(self): conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() if conversation: - return db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id).first() + return ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() + ) return None @@ -798,13 +851,21 @@ class Message(db.Model): @property def agent_thoughts(self): - return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ - .order_by(MessageAgentThought.position.asc()).all() + return ( + db.session.query(MessageAgentThought) + .filter(MessageAgentThought.message_id == self.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) @property def retriever_resources(self): - return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ - .order_by(DatasetRetrieverResource.position.asc()).all() + return ( + db.session.query(DatasetRetrieverResource) + .filter(DatasetRetrieverResource.message_id == self.id) + .order_by(DatasetRetrieverResource.position.asc()) + .all() + ) @property def message_files(self): @@ -817,39 +878,39 @@ class Message(db.Model): files = [] for message_file in message_files: url = message_file.url - if message_file.type == 'image': - if message_file.transfer_method == 'local_file': - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == message_file.upload_file_id - ).first()) - - url = UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=True + if message_file.type == "image": + if message_file.transfer_method == "local_file": + upload_file = ( + db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first() ) - if message_file.transfer_method == 'tool_file': + + url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True) + if message_file.transfer_method == "tool_file": # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." 
in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=tool_file_id, extension=extension) + url = ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=tool_file_id, extension=extension + ) - files.append({ - 'id': message_file.id, - 'type': message_file.type, - 'url': url, - 'belongs_to': message_file.belongs_to if message_file.belongs_to else 'user' - }) + files.append( + { + "id": message_file.id, + "type": message_file.type, + "url": url, + "belongs_to": message_file.belongs_to if message_file.belongs_to else "user", + } + ) return files @@ -857,64 +918,65 @@ class Message(db.Model): def workflow_run(self): if self.workflow_run_id: from .workflow import WorkflowRun + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None def to_dict(self) -> dict: return { - 'id': self.id, - 'app_id': self.app_id, - 'conversation_id': self.conversation_id, - 'inputs': self.inputs, - 'query': self.query, - 'message': self.message, - 'answer': self.answer, - 'status': self.status, - 'error': self.error, - 'message_metadata': self.message_metadata_dict, - 'from_source': self.from_source, - 'from_end_user_id': self.from_end_user_id, - 'from_account_id': self.from_account_id, - 'created_at': self.created_at.isoformat(), - 'updated_at': self.updated_at.isoformat(), - 'agent_based': self.agent_based, - 'workflow_run_id': self.workflow_run_id + "id": self.id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "inputs": self.inputs, + "query": self.query, + "message": self.message, + "answer": self.answer, + "status": self.status, + "error": self.error, + "message_metadata": self.message_metadata_dict, + "from_source": self.from_source, + "from_end_user_id": self.from_end_user_id, + "from_account_id": self.from_account_id, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + "agent_based": self.agent_based, + "workflow_run_id": self.workflow_run_id, } @classmethod def from_dict(cls, data: dict): return cls( - id=data['id'], - app_id=data['app_id'], - conversation_id=data['conversation_id'], - inputs=data['inputs'], - query=data['query'], - message=data['message'], - answer=data['answer'], - status=data['status'], - error=data['error'], - message_metadata=json.dumps(data['message_metadata']), - from_source=data['from_source'], - from_end_user_id=data['from_end_user_id'], - from_account_id=data['from_account_id'], - created_at=data['created_at'], - updated_at=data['updated_at'], - agent_based=data['agent_based'], - workflow_run_id=data['workflow_run_id'] + id=data["id"], + app_id=data["app_id"], + conversation_id=data["conversation_id"], + inputs=data["inputs"], + query=data["query"], + message=data["message"], + answer=data["answer"], + status=data["status"], + error=data["error"], + message_metadata=json.dumps(data["message_metadata"]), + from_source=data["from_source"], + from_end_user_id=data["from_end_user_id"], + from_account_id=data["from_account_id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + agent_based=data["agent_based"], + workflow_run_id=data["workflow_run_id"], ) class MessageFeedback(db.Model): - __tablename__ = 'message_feedbacks' + __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint('id', 
name='message_feedback_pkey'), - db.Index('message_feedback_app_idx', 'app_id'), - db.Index('message_feedback_message_idx', 'message_id', 'from_source'), - db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') + db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + db.Index("message_feedback_app_idx", "app_id"), + db.Index("message_feedback_message_idx", "message_id", "from_source"), + db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) @@ -923,8 +985,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def from_account(self): @@ -933,14 +995,14 @@ class MessageFeedback(db.Model): class MessageFile(db.Model): - __tablename__ = 'message_files' + __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_file_pkey'), - db.Index('message_file_message_idx', 'message_id'), - db.Index('message_file_created_by_idx', 'created_by') + db.PrimaryKeyConstraint("id", name="message_file_pkey"), + db.Index("message_file_message_idx", "message_id"), + db.Index("message_file_created_by_idx", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) transfer_method = db.Column(db.String(255), nullable=False) @@ -949,28 +1011,28 @@ class MessageFile(db.Model): upload_file_id = db.Column(StringUUID, nullable=True) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageAnnotation(db.Model): - __tablename__ = 'message_annotations' + __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), - db.Index('message_annotation_app_idx', 'app_id'), - db.Index('message_annotation_conversation_idx', 'conversation_id'), - db.Index('message_annotation_message_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + db.Index("message_annotation_app_idx", "app_id"), + db.Index("message_annotation_conversation_idx", "conversation_id"), + db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = 
db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def account(self): @@ -984,32 +1046,35 @@ class MessageAnnotation(db.Model): class AppAnnotationHitHistory(db.Model): - __tablename__ = 'app_annotation_hit_histories' + __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'), - db.Index('app_annotation_hit_histories_app_idx', 'app_id'), - db.Index('app_annotation_hit_histories_account_idx', 'account_id'), - db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'), - db.Index('app_annotation_hit_histories_message_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + db.Index("app_annotation_hit_histories_app_idx", "app_id"), + db.Index("app_annotation_hit_histories_account_idx", "account_id"), + db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) annotation_id = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - score = db.Column(Float, nullable=False, server_default=db.text('0')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) annotation_content = db.Column(db.Text, nullable=False) @property def account(self): - account = (db.session.query(Account) - .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) - .filter(MessageAnnotation.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) + .filter(MessageAnnotation.id == self.annotation_id) + .first() + ) return account @property @@ -1019,89 +1084,99 @@ class AppAnnotationHitHistory(db.Model): class AppAnnotationSetting(db.Model): - __tablename__ = 'app_annotation_settings' + __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'), - db.Index('app_annotation_settings_app_idx', 'app_id') + 
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) + score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def created_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def updated_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = (db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == self.collection_binding_id).first()) + + collection_binding_detail = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .first() + ) return collection_binding_detail class OperationLog(db.Model): - __tablename__ = 'operation_logs' + __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='operation_log_pkey'), - db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') + db.PrimaryKeyConstraint("id", name="operation_log_pkey"), + db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) class EndUser(UserMixin, db.Model): - __tablename__ = 'end_users' + __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='end_user_pkey'), - db.Index('end_user_session_id_idx', 'session_id', 'type'), - db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), + db.PrimaryKeyConstraint("id", name="end_user_pkey"), + db.Index("end_user_session_id_idx", "session_id", "type"), + db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) - is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class Site(db.Model): - __tablename__ = 'sites' + __tablename__ = "sites" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='site_pkey'), - db.Index('site_app_id_idx', 'app_id'), - db.Index('site_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="site_pkey"), + db.Index("site_app_id_idx", "app_id"), + db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) icon_type = db.Column(db.String(255), nullable=True) @@ -1110,19 +1185,20 @@ class Site(db.Model): description = db.Column(db.Text) default_language = db.Column(db.String(255), nullable=False) chat_color_theme = db.Column(db.String(255)) - chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) copyright = db.Column(db.String(255)) privacy_policy = db.Column(db.String(255)) - show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) custom_disclaimer = db.Column(db.String(255), nullable=True) customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) - prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = 
db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = db.Column(db.String(255)) @staticmethod @@ -1136,26 +1212,25 @@ class Site(db.Model): @property def app_base_url(self): - return ( - dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip('/')) + return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/") class ApiToken(db.Model): - __tablename__ = 'api_tokens' + __tablename__ = "api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_token_pkey'), - db.Index('api_token_app_id_type_idx', 'app_id', 'type'), - db.Index('api_token_token_idx', 'token', 'type'), - db.Index('api_token_tenant_idx', 'tenant_id', 'type') + db.PrimaryKeyConstraint("id", name="api_token_pkey"), + db.Index("api_token_app_id_type_idx", "app_id", "type"), + db.Index("api_token_token_idx", "token", "type"), + db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @staticmethod def generate_api_key(prefix, n): @@ -1168,13 +1243,13 @@ class ApiToken(db.Model): class UploadFile(db.Model): - __tablename__ = 'upload_files' + __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='upload_file_pkey'), - db.Index('upload_file_tenant_idx', 'tenant_id') + db.PrimaryKeyConstraint("id", name="upload_file_pkey"), + db.Index("upload_file_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) storage_type = db.Column(db.String(255), nullable=False) key = db.Column(db.String(255), nullable=False) @@ -1184,38 +1259,38 @@ class UploadFile(db.Model): mime_type = db.Column(db.String(255), nullable=True) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + used = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by = db.Column(StringUUID, nullable=True) used_at = db.Column(db.DateTime, nullable=True) hash = db.Column(db.String(255), nullable=True) class ApiRequest(db.Model): - __tablename__ = 'api_requests' + __tablename__ = "api_requests" __table_args__ = ( - 
db.PrimaryKeyConstraint('id', name='api_request_pkey'), - db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') + db.PrimaryKeyConstraint("id", name="api_request_pkey"), + db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False) path = db.Column(db.String(255), nullable=False) request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageChain(db.Model): - __tablename__ = 'message_chains' + __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_chain_pkey'), - db.Index('message_chain_message_id_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_chain_pkey"), + db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) input = db.Column(db.Text, nullable=True) @@ -1224,14 +1299,14 @@ class MessageChain(db.Model): class MessageAgentThought(db.Model): - __tablename__ = 'message_agent_thoughts' + __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_agent_thought_pkey'), - db.Index('message_agent_thought_message_id_idx', 'message_id'), - db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), + db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + db.Index("message_agent_thought_message_id_idx", "message_id"), + db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) message_chain_id = db.Column(StringUUID, nullable=True) position = db.Column(db.Integer, nullable=False) @@ -1246,12 +1321,12 @@ class MessageAgentThought(db.Model): message = db.Column(db.Text, nullable=True) message_token = db.Column(db.Integer, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) answer_token = db.Column(db.Integer, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) tokens = db.Column(db.Integer, nullable=True) total_price = db.Column(db.Numeric, nullable=True) currency = db.Column(db.String, nullable=True) @@ -1308,9 +1383,7 @@ class 
MessageAgentThought(db.Model): result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: return {} @@ -1331,22 +1404,20 @@ class MessageAgentThought(db.Model): result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: if self.observation: return dict.fromkeys(tools, self.observation) class DatasetRetrieverResource(db.Model): - __tablename__ = 'dataset_retriever_resources' + __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'), - db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -1367,53 +1438,53 @@ class DatasetRetrieverResource(db.Model): class Tag(db.Model): - __tablename__ = 'tags' + __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_pkey'), - db.Index('tag_type_idx', 'type'), - db.Index('tag_name_idx', 'name'), + db.PrimaryKeyConstraint("id", name="tag_pkey"), + db.Index("tag_type_idx", "type"), + db.Index("tag_name_idx", "name"), ) - TAG_TYPE_LIST = ['knowledge', 'app'] + TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TagBinding(db.Model): - __tablename__ = 'tag_bindings' + __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_binding_pkey'), - db.Index('tag_bind_target_id_idx', 'target_id'), - db.Index('tag_bind_tag_id_idx', 'tag_id'), + db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + db.Index("tag_bind_target_id_idx", "target_id"), + db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TraceAppConfig(db.Model): - __tablename__ = 'trace_app_config' + __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tracing_app_config_pkey'), - db.Index('trace_app_config_app_id_idx', 'app_id'), + db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + 
db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) - is_active = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): @@ -1425,11 +1496,11 @@ class TraceAppConfig(db.Model): def to_dict(self): return { - 'id': self.id, - 'app_id': self.app_id, - 'tracing_provider': self.tracing_provider, - 'tracing_config': self.tracing_config_dict, + "id": self.id, + "app_id": self.app_id, + "tracing_provider": self.tracing_provider, + "tracing_config": self.tracing_config_dict, "is_active": self.is_active, "created_at": self.created_at.__str__() if self.created_at else None, - 'updated_at': self.updated_at.__str__() if self.updated_at else None, + "updated_at": self.updated_at.__str__() if self.updated_at else None, } diff --git a/api/models/provider.py b/api/models/provider.py index 5d92ee6eb6..ff63e03a92 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,8 +6,8 @@ from .types import StringUUID class ProviderType(Enum): - CUSTOM = 'custom' - SYSTEM = 'system' + CUSTOM = "custom" + SYSTEM = "system" @staticmethod def value_of(value): @@ -18,13 +18,13 @@ class ProviderType(Enum): class ProviderQuotaType(Enum): - PAID = 'paid' + PAID = "paid" """hosted paid quota""" - FREE = 'free' + FREE = "free" """third-party free quota""" - TRIAL = 'trial' + TRIAL = "trial" """hosted trial quota""" @staticmethod @@ -39,27 +39,30 @@ class Provider(db.Model): """ Provider model representing the API providers and their configurations. 
""" - __tablename__ = 'providers' + + __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_pkey'), - db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') + db.PrimaryKeyConstraint("id", name="provider_pkey"), + db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used = db.Column(db.DateTime, nullable=True) quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def __repr__(self): return f"" @@ -67,8 +70,8 @@ class Provider(db.Model): @property def token_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_config is not None @property @@ -86,118 +89,123 @@ class ProviderModel(db.Model): """ Provider model representing the API provider_models and their configurations. 
""" - __tablename__ = 'provider_models' + + __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_pkey'), - db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + db.PrimaryKeyConstraint("id", name="provider_model_pkey"), + db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantDefaultModel(db.Model): - __tablename__ = 'tenant_default_models' + __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), - db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPreferredModelProvider(db.Model): - __tablename__ = 'tenant_preferred_model_providers' + __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), - db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, 
server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderOrder(db.Model): - __tablename__ = 'provider_orders' + __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_order_pkey'), - db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="provider_order_pkey"), + db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) + quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) currency = db.Column(db.String(40)) total_amount = db.Column(db.Integer) payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderModelSetting(db.Model): """ Provider model settings for record the model enabled status and load balancing status. 
""" - __tablename__ = 'provider_model_settings' + + __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), - db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class LoadBalancingModelConfig(db.Model): """ Configurations for load balancing models. """ - __tablename__ = 'load_balancing_model_configs' + + __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), - db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/source.py b/api/models/source.py index adc00028be..07695f06e6 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -8,48 +8,48 @@ from .types import StringUUID class DataSourceOauthBinding(db.Model): - __tablename__ = 'data_source_oauth_bindings' + __tablename__ = 
"data_source_oauth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='source_binding_pkey'), - db.Index('source_binding_tenant_id_idx', 'tenant_id'), - db.Index('source_info_idx', "source_info", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="source_binding_pkey"), + db.Index("source_binding_tenant_id_idx", "tenant_id"), + db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(db.Model): - __tablename__ = 'data_source_api_key_auth_bindings' + __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), - db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), - db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), + db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'category': self.category, - 'provider': self.provider, - 'credentials': json.loads(self.credentials), - 'created_at': self.created_at.timestamp(), - 'updated_at': self.updated_at.timestamp(), - 'disabled': self.disabled + "id": self.id, + "tenant_id": self.tenant_id, + "category": self.category, + "provider": self.provider, + "credentials": json.loads(self.credentials), + "created_at": self.created_at.timestamp(), + "updated_at": self.updated_at.timestamp(), + "disabled": self.disabled, } diff --git a/api/models/task.py b/api/models/task.py index 
618d831d8e..57b147c78d 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -8,15 +8,18 @@ from extensions.ext_database import db class CeleryTask(db.Model): """Task result/status.""" - __tablename__ = 'celery_taskmeta' + __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence('task_id_sequence'), - primary_key=True, autoincrement=True) + id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = db.Column(db.String(155), unique=True) status = db.Column(db.String(50), default=states.PENDING) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column( + db.DateTime, + default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + nullable=True, + ) traceback = db.Column(db.Text, nullable=True) name = db.Column(db.String(155), nullable=True) args = db.Column(db.LargeBinary, nullable=True) @@ -29,11 +32,9 @@ class CeleryTask(db.Model): class CeleryTaskSet(db.Model): """TaskSet result.""" - __tablename__ = 'celery_tasksetmeta' + __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence('taskset_id_sequence'), - autoincrement=True, primary_key=True) + id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) diff --git a/api/models/tool.py b/api/models/tool.py index 79a70c6b1f..a81bb65174 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -7,7 +7,7 @@ from .types import StringUUID class ToolProviderName(Enum): - SERPAPI = 'serpapi' + SERPAPI = "serpapi" @staticmethod def value_of(value): @@ -18,25 +18,25 @@ class ToolProviderName(Enum): class ToolProvider(db.Model): - __tablename__ = 'tool_providers' + __tablename__ = "tool_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_provider_pkey'), - db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), + db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) tool_name = db.Column(db.String(40), nullable=False) encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def 
credentials_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_credentials is not None @property diff --git a/api/models/tools.py b/api/models/tools.py index 069dc5bad0..6b69a219b1 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -15,15 +15,16 @@ class BuiltinToolProvider(db.Model): """ This table stores the tool provider information for built-in tools for each tenant. """ - __tablename__ = 'tool_builtin_providers' + + __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), # one tenant can only have one tool provider with the same name - db.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) # id of the tool provider - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the tenant tenant_id = db.Column(StringUUID, nullable=True) # who created this tool provider @@ -32,27 +33,29 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def credentials(self) -> dict: return json.loads(self.encrypted_credentials) + class PublishedAppTool(db.Model): """ The table stores the apps published as a tool for each person. 
""" - __tablename__ = 'tool_published_apps' + + __tablename__ = "tool_published_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), - db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') + db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) # id of the tool provider - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the app - app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False) + app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) # who published this tool user_id = db.Column(StringUUID, nullable=False) # description of the tool, stored in i18n format, for human @@ -67,28 +70,30 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: return I18nObject(**json.loads(self.description)) - + @property def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + class ApiToolProvider(db.Model): """ The table stores the api providers. """ - __tablename__ = 'tool_api_providers' + + __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_api_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider name = db.Column(db.String(40), nullable=False) # icon @@ -111,21 +116,21 @@ class ApiToolProvider(db.Model): # custom_disclaimer custom_disclaimer = db.Column(db.String(255), nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] - + @property def credentials(self) -> dict: return json.loads(self.credentials_str) - + @property def user(self) -> Account: return db.session.query(Account).filter(Account.id == self.user_id).first() @@ -134,17 +139,19 @@ class ApiToolProvider(db.Model): def tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + class ToolLabelBinding(db.Model): """ The table stores the labels for 
tools. """ - __tablename__ = 'tool_label_bindings' + + __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), - db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), + db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id tool_id = db.Column(db.String(64), nullable=False) # tool type @@ -152,28 +159,30 @@ class ToolLabelBinding(db.Model): # label name label_name = db.Column(db.String(40), nullable=False) + class WorkflowToolProvider(db.Model): """ The table stores the workflow providers. """ - __tablename__ = 'tool_workflow_providers' + + __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), + db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider name = db.Column(db.String(40), nullable=False) # label of the workflow provider - label = db.Column(db.String(255), nullable=False, server_default='') + label = db.Column(db.String(255), nullable=False, server_default="") # icon icon = db.Column(db.String(255), nullable=False) # app id of the workflow provider app_id = db.Column(StringUUID, nullable=False) # version of the workflow provider - version = db.Column(db.String(255), nullable=False, server_default='') + version = db.Column(db.String(255), nullable=False, server_default="") # who created this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -181,17 +190,17 @@ class WorkflowToolProvider(db.Model): # description of the provider description = db.Column(db.Text, nullable=False) # parameter configuration - parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') + parameter_configuration = db.Column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy = db.Column(db.String(255), nullable=True, server_default='') + privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property def user(self) -> Account: return db.session.query(Account).filter(Account.id == self.user_id).first() @@ -199,28 +208,25 @@ class WorkflowToolProvider(db.Model): @property def tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == 
self.tenant_id).first() - + @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(**config) - for config in json.loads(self.parameter_configuration) - ] - + return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + @property def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke """ - __tablename__ = "tool_model_invokes" - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), - ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + __tablename__ = "tool_model_invokes" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -238,29 +244,31 @@ class ToolModelInvoke(db.Model): # invoke response model_response = db.Column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + class ToolConversationVariables(db.Model): """ store the conversation variables from tool invoke """ + __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey'), + db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index('user_id_idx', 'user_id'), - db.Index('conversation_id_idx', 'conversation_id'), + db.Index("user_id_idx", "user_id"), + db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -270,25 +278,27 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def variables(self) -> dict: return json.loads(self.variables_str) - + + class ToolFile(db.Model): """ store the file created by agent """ + __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_file_pkey'), + db.PrimaryKeyConstraint("id", name="tool_file_pkey"), # add index for conversation_id - db.Index('tool_file_conversation_id_idx', 'conversation_id'), + db.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -300,4 +310,4 @@ class ToolFile(db.Model): # mime type mimetype = db.Column(db.String(255), nullable=False) # original url - original_url = db.Column(db.String(2048), nullable=True) \ No newline at end of file + original_url = db.Column(db.String(2048), nullable=True) diff --git a/api/models/types.py b/api/models/types.py index 1614ec2018..cb6773e70c 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -9,13 +9,13 @@ class StringUUID(TypeDecorator): def process_bind_param(self, value, dialect): if value is None: return value - elif dialect.name == 'postgresql': + elif dialect.name == "postgresql": return str(value) else: return value.hex def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': + if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) @@ -23,4 +23,4 @@ class StringUUID(TypeDecorator): def process_result_value(self, value, dialect): if value is None: return value - return str(value) \ No newline at end of file + return str(value) diff --git a/api/models/web.py b/api/models/web.py index 0e901d5f84..bc088c185d 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,3 @@ - from extensions.ext_database import db from .model import Message @@ -6,18 +5,18 @@ from .types import StringUUID class SavedMessage(db.Model): - __tablename__ = 'saved_messages' + __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='saved_message_pkey'), - db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="saved_message_pkey"), + db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def message(self): @@ -25,15 +24,15 @@ class SavedMessage(db.Model): class PinnedConversation(db.Model): - __tablename__ = 'pinned_conversations' + __tablename__ = "pinned_conversations" __table_args__ = ( - 
db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), - db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/workflow.py b/api/models/workflow.py index cdd5e1992d..d52749f0ff 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,11 +22,12 @@ class CreatedByRole(Enum): """ Created By Role Enum """ - ACCOUNT = 'account' - END_USER = 'end_user' + + ACCOUNT = "account" + END_USER = "end_user" @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': + def value_of(cls, value: str) -> "CreatedByRole": """ Get value of given mode. @@ -36,18 +37,19 @@ class CreatedByRole(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid created by role value {value}') + raise ValueError(f"invalid created by role value {value}") class WorkflowType(Enum): """ Workflow Type Enum """ - WORKFLOW = 'workflow' - CHAT = 'chat' + + WORKFLOW = "workflow" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'WorkflowType': + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -57,10 +59,10 @@ class WorkflowType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow type value {value}') + raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. 
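# --- A minimal sketch (illustrative, not part of this patch) of the `value_of`
# lookup pattern shared by CreatedByRole, WorkflowType, and the other enums in
# api/models/workflow.py: members carry the exact string persisted in the
# database, and an unknown string raises instead of silently returning None.
# `ExampleMode` is an assumed stand-in name, not an identifier from the codebase.
from enum import Enum


class ExampleMode(Enum):
    WORKFLOW = "workflow"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "ExampleMode":
        # Linear scan over members, matching on the stored string value.
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid example mode value {value}")


# Usage: ExampleMode.value_of("chat") is ExampleMode.CHAT,
# while ExampleMode.value_of("unknown") raises ValueError.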
@@ -68,6 +70,7 @@ class WorkflowType(Enum): :return: workflow type """ from models.model import AppMode + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT @@ -105,13 +108,13 @@ class Workflow(db.Model): - updated_at (timestamp) `optional` Last update time """ - __tablename__ = 'workflows' + __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), + db.PrimaryKeyConstraint("id", name="workflow_pkey"), + db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = db.Column(StringUUID, nullable=False) type: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -119,15 +122,31 @@ class Workflow(db.Model): graph: Mapped[str] = db.Column(db.Text) features: Mapped[str] = db.Column(db.Text) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) updated_by: Mapped[str] = db.Column(StringUUID) updated_at: Mapped[datetime] = db.Column(db.DateTime) - _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') - _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + _environment_variables: Mapped[str] = db.Column( + "environment_variables", db.Text, nullable=False, server_default="{}" + ) + _conversation_variables: Mapped[str] = db.Column( + "conversation_variables", db.Text, nullable=False, server_default="{}" + ) - def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, - features: str, created_by: str, environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable]): + def __init__( + self, + *, + tenant_id: str, + app_id: str, + type: str, + version: str, + graph: str, + features: str, + created_by: str, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + ): self.tenant_id = tenant_id self.app_id = app_id self.type = type @@ -160,22 +179,20 @@ class Workflow(db.Model): return [] graph_dict = self.graph_dict - if 'nodes' not in graph_dict: + if "nodes" not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) + start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) if not start_node: return [] # get user_input_form from start node - variables = start_node.get('data', {}).get('variables', []) + variables = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] for variable in variables: - old_structure_variables.append({ - variable['type']: variable - }) + old_structure_variables.append({variable["type"]: variable}) return old_structure_variables @@ -188,25 +205,24 @@ class Workflow(db.Model): :return: hash """ - entity = { - 'graph': self.graph_dict, - 'features': self.features_dict - } + 
entity = {"graph": self.graph_dict, "features": self.features_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property def tool_published(self) -> bool: from models.tools import WorkflowToolProvider - return db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.app_id == self.app_id - ).first() is not None + + return ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.app_id == self.app_id).first() + is not None + ) @property def environment_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: - self._environment_variables = '{}' + self._environment_variables = "{}" tenant_id = contexts.tenant_id.get() @@ -215,9 +231,7 @@ class Workflow(db.Model): # decrypt secret variables value decrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -230,19 +244,17 @@ class Workflow(db.Model): value = list(value) if any(var for var in value if not var.id): - raise ValueError('environment variable require a unique id') + raise ValueError("environment variable require a unique id") # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). origin_variables_dictionary = {var.id: var for var in self.environment_variables} for i, variable in enumerate(value): if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE: - value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name}) + value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value encrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -256,15 +268,15 @@ class Workflow(db.Model): def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: environment_variables = list(self.environment_variables) environment_variables = [ - v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={'value': ''}) + v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] result = { - 'graph': self.graph_dict, - 'features': self.features_dict, - 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], - 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], + "graph": self.graph_dict, + "features": self.features_dict, + "environment_variables": [var.model_dump(mode="json") for var in environment_variables], + "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], } return result @@ -272,7 +284,7 @@ class Workflow(db.Model): def conversation_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._conversation_variables` when instance created. 
if self._conversation_variables is None: - self._conversation_variables = '{}' + self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] @@ -290,11 +302,12 @@ class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum """ - DEBUGGING = 'debugging' - APP_RUN = 'app-run' + + DEBUGGING = "debugging" + APP_RUN = "app-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom": """ Get value of given mode. @@ -304,20 +317,21 @@ class WorkflowRunTriggeredFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run triggered from value {value}') + raise ValueError(f"invalid workflow run triggered from value {value}") class WorkflowRunStatus(Enum): """ Workflow Run Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' - STOPPED = 'stopped' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + STOPPED = "stopped" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunStatus': + def value_of(cls, value: str) -> "WorkflowRunStatus": """ Get value of given mode. @@ -327,7 +341,7 @@ class WorkflowRunStatus(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run status value {value}') + raise ValueError(f"invalid workflow run status value {value}") class WorkflowRun(db.Model): @@ -368,14 +382,14 @@ class WorkflowRun(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_runs' + __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), - db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), - db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'), + db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) sequence_number = db.Column(db.Integer, nullable=False) @@ -388,26 +402,25 @@ class WorkflowRun(db.Model): status = db.Column(db.String(255), nullable=False) outputs = db.Column(db.Text) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - total_steps = db.Column(db.Integer, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) + total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) @property def created_by_account(self): 
created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def graph_dict(self): @@ -422,12 +435,12 @@ class WorkflowRun(db.Model): return json.loads(self.outputs) if self.outputs else None @property - def message(self) -> Optional['Message']: + def message(self) -> Optional["Message"]: from models.model import Message - return db.session.query(Message).filter( - Message.app_id == self.app_id, - Message.workflow_run_id == self.id - ).first() + + return ( + db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + ) @property def workflow(self): @@ -435,51 +448,51 @@ class WorkflowRun(db.Model): def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'app_id': self.app_id, - 'sequence_number': self.sequence_number, - 'workflow_id': self.workflow_id, - 'type': self.type, - 'triggered_from': self.triggered_from, - 'version': self.version, - 'graph': self.graph_dict, - 'inputs': self.inputs_dict, - 'status': self.status, - 'outputs': self.outputs_dict, - 'error': self.error, - 'elapsed_time': self.elapsed_time, - 'total_tokens': self.total_tokens, - 'total_steps': self.total_steps, - 'created_by_role': self.created_by_role, - 'created_by': self.created_by, - 'created_at': self.created_at, - 'finished_at': self.finished_at, + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "sequence_number": self.sequence_number, + "workflow_id": self.workflow_id, + "type": self.type, + "triggered_from": self.triggered_from, + "version": self.version, + "graph": self.graph_dict, + "inputs": self.inputs_dict, + "status": self.status, + "outputs": self.outputs_dict, + "error": self.error, + "elapsed_time": self.elapsed_time, + "total_tokens": self.total_tokens, + "total_steps": self.total_steps, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + "finished_at": self.finished_at, } @classmethod - def from_dict(cls, data: dict) -> 'WorkflowRun': + def from_dict(cls, data: dict) -> "WorkflowRun": return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - app_id=data.get('app_id'), - sequence_number=data.get('sequence_number'), - workflow_id=data.get('workflow_id'), - type=data.get('type'), - triggered_from=data.get('triggered_from'), - version=data.get('version'), - graph=json.dumps(data.get('graph')), - inputs=json.dumps(data.get('inputs')), - status=data.get('status'), - outputs=json.dumps(data.get('outputs')), - error=data.get('error'), - elapsed_time=data.get('elapsed_time'), - total_tokens=data.get('total_tokens'), - total_steps=data.get('total_steps'), - created_by_role=data.get('created_by_role'), - created_by=data.get('created_by'), - created_at=data.get('created_at'), - finished_at=data.get('finished_at'), + id=data.get("id"), + tenant_id=data.get("tenant_id"), + app_id=data.get("app_id"), + sequence_number=data.get("sequence_number"), + workflow_id=data.get("workflow_id"), 
+ type=data.get("type"), + triggered_from=data.get("triggered_from"), + version=data.get("version"), + graph=json.dumps(data.get("graph")), + inputs=json.dumps(data.get("inputs")), + status=data.get("status"), + outputs=json.dumps(data.get("outputs")), + error=data.get("error"), + elapsed_time=data.get("elapsed_time"), + total_tokens=data.get("total_tokens"), + total_steps=data.get("total_steps"), + created_by_role=data.get("created_by_role"), + created_by=data.get("created_by"), + created_at=data.get("created_at"), + finished_at=data.get("finished_at"), ) @@ -487,11 +500,12 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): """ Workflow Node Execution Triggered From Enum """ - SINGLE_STEP = 'single-step' - WORKFLOW_RUN = 'workflow-run' + + SINGLE_STEP = "single-step" + WORKFLOW_RUN = "workflow-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": """ Get value of given mode. @@ -501,19 +515,20 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution triggered from value {value}') + raise ValueError(f"invalid workflow node execution triggered from value {value}") class WorkflowNodeExecutionStatus(Enum): """ Workflow Node Execution Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": """ Get value of given mode. @@ -523,7 +538,7 @@ class WorkflowNodeExecutionStatus(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution status value {value}') + raise ValueError(f"invalid workflow node execution status value {value}") class WorkflowNodeExecution(db.Model): @@ -574,16 +589,31 @@ class WorkflowNodeExecution(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_node_executions' + __tablename__ = "workflow_node_executions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), - db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'workflow_run_id'), - db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'node_id'), + db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + db.Index( + "workflow_node_execution_workflow_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "workflow_run_id", + ), + db.Index( + "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id" + ), + db.Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -591,6 +621,7 @@ class WorkflowNodeExecution(db.Model): workflow_run_id = db.Column(StringUUID) index = db.Column(db.Integer, nullable=False) predecessor_node_id = db.Column(db.String(255)) + node_execution_id = db.Column(db.String(255), 
nullable=True) node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) @@ -599,9 +630,9 @@ class WorkflowNodeExecution(db.Model): outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -609,15 +640,14 @@ class WorkflowNodeExecution(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def inputs_dict(self): @@ -638,15 +668,17 @@ class WorkflowNodeExecution(db.Model): @property def extras(self): from core.tools.tool_manager import ToolManager + extras = {} if self.execution_metadata_dict: from core.workflow.entities.node_entities import NodeType - if self.node_type == NodeType.TOOL.value and 'tool_info' in self.execution_metadata_dict: - tool_info = self.execution_metadata_dict['tool_info'] - extras['icon'] = ToolManager.get_tool_icon( + + if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: + tool_info = self.execution_metadata_dict["tool_info"] + extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, - provider_type=tool_info['provider_type'], - provider_id=tool_info['provider_id'] + provider_type=tool_info["provider_type"], + provider_id=tool_info["provider_id"], ) return extras @@ -656,12 +688,13 @@ class WorkflowAppLogCreatedFrom(Enum): """ Workflow App Log Created From Enum """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - INSTALLED_APP = 'installed-app' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. 
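# --- A minimal sketch (illustrative, not part of this patch) of the
# created_by_role dispatch that WorkflowRun, WorkflowNodeExecution, and
# WorkflowAppLog all implement above: a single created_by UUID may reference
# either an Account or an EndUser row, and the persisted role string decides
# which lookup is allowed to resolve. The in-memory `accounts` mapping below is
# an assumption standing in for the db.session.get(Account, ...) query.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Role(Enum):
    ACCOUNT = "account"
    END_USER = "end_user"


@dataclass
class ExampleLog:
    created_by_role: str
    created_by: str  # UUID string pointing at an account *or* an end user

    def created_by_account(self, accounts: dict) -> Optional[dict]:
        # Resolve the UUID only when the stored role marks it as an account;
        # rows created by end users deliberately return None here.
        role = Role(self.created_by_role)  # raises ValueError on unknown roles
        return accounts.get(self.created_by) if role == Role.ACCOUNT else None


# Usage: ExampleLog("account", "uuid-1").created_by_account({"uuid-1": {"name": "alice"}})
# returns the account record; ExampleLog("end_user", "uuid-2") returns None instead.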
@@ -671,7 +704,7 @@ class WorkflowAppLogCreatedFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow app log created from value {value}') + raise ValueError(f"invalid workflow app log created from value {value}") class WorkflowAppLog(db.Model): @@ -703,13 +736,13 @@ class WorkflowAppLog(db.Model): - created_at (timestamp) Creation time """ - __tablename__ = 'workflow_app_logs' + __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'), - db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -717,7 +750,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def workflow_run(self): @@ -726,26 +759,27 @@ class WorkflowAppLog(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None class ConversationVariable(db.Model): - __tablename__ = 'workflow_conversation_variables' + __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: self.id = id @@ -754,7 +788,7 @@ class ConversationVariable(db.Model): self.data = data @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": obj = cls( id=variable.id, 
app_id=app_id, diff --git a/api/poetry.lock b/api/poetry.lock index 7d26dbdc57..103423e5c7 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -364,27 +364,27 @@ alibabacloud-tea = ">=0.0.1" [[package]] name = "aliyun-python-sdk-core" -version = "2.15.1" +version = "2.15.2" description = "The core module of Aliyun Python SDK." optional = false python-versions = "*" files = [ - {file = "aliyun-python-sdk-core-2.15.1.tar.gz", hash = "sha256:518550d07f537cd3afac3b6c93b5c997ce3440e4d0c054e3acbdaa8261e90adf"}, + {file = "aliyun-python-sdk-core-2.15.2.tar.gz", hash = "sha256:54f66a53e193c61c5e16ea4505a0cab43543f8ad2ef22833f69c4d5e5151c17d"}, ] [package.dependencies] -cryptography = ">=2.6.0" +cryptography = ">=3.0.0" jmespath = ">=0.9.3,<1.0.0" [[package]] name = "aliyun-python-sdk-kms" -version = "2.16.4" +version = "2.16.5" description = "The kms module of Aliyun Python sdk." optional = false python-versions = "*" files = [ - {file = "aliyun-python-sdk-kms-2.16.4.tar.gz", hash = "sha256:0d5bb165c07b6a972939753a128507393f48011792ee0ec4f59b6021eabd9752"}, - {file = "aliyun_python_sdk_kms-2.16.4-py2.py3-none-any.whl", hash = "sha256:6d412663ef8c35dc3bb42be6a3ee76a9bc07acdadca6dd26815131062bedf4c5"}, + {file = "aliyun-python-sdk-kms-2.16.5.tar.gz", hash = "sha256:f328a8a19d83ecbb965ffce0ec1e9930755216d104638cd95ecd362753b813b3"}, + {file = "aliyun_python_sdk_kms-2.16.5-py2.py3-none-any.whl", hash = "sha256:24b6cdc4fd161d2942619479c8d050c63ea9cd22b044fe33b60bbb60153786f0"}, ] [package.dependencies] @@ -520,22 +520,22 @@ files = [ [[package]] name = "attrs" -version = "24.2.0" +version = "23.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, - {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] [[package]] name = "authlib" @@ -553,13 +553,13 @@ cryptography = "*" [[package]] name = "azure-ai-inference" 
-version = "1.0.0b3" +version = "1.0.0b4" description = "Microsoft Azure Ai Inference Client Library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "azure-ai-inference-1.0.0b3.tar.gz", hash = "sha256:1e99dc74c3b335a457500311bbbadb348f54dc4c12252a93cb8ab78d6d217ff0"}, - {file = "azure_ai_inference-1.0.0b3-py3-none-any.whl", hash = "sha256:6734ca7334c809a170beb767f1f1455724ab3f006cb60045e42a833c0e764403"}, + {file = "azure-ai-inference-1.0.0b4.tar.gz", hash = "sha256:5464404bef337338d4af6eefde3af903400ddb8e5c9e6820f902303542fa0f72"}, + {file = "azure_ai_inference-1.0.0b4-py3-none-any.whl", hash = "sha256:e2c949f91845a8cd96cb9a61ffd432b5b0f4ce236b9be8c29d10f38e0a327412"}, ] [package.dependencies] @@ -1140,89 +1140,89 @@ zstd = ["zstandard (==0.22.0)"] [[package]] name = "certifi" -version = "2024.7.4" +version = "2024.8.30" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, - {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, + {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"}, + {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, ] [[package]] name = "cffi" -version = "1.17.0" +version = "1.17.1" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" files = [ - {file = "cffi-1.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f9338cc05451f1942d0d8203ec2c346c830f8e86469903d5126c1f0a13a2bcbb"}, - {file = "cffi-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0ce71725cacc9ebf839630772b07eeec220cbb5f03be1399e0457a1464f8e1a"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c815270206f983309915a6844fe994b2fa47e5d05c4c4cef267c3b30e34dbe42"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6bdcd415ba87846fd317bee0774e412e8792832e7805938987e4ede1d13046d"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a98748ed1a1df4ee1d6f927e151ed6c1a09d5ec21684de879c7ea6aa96f58f2"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0a048d4f6630113e54bb4b77e315e1ba32a5a31512c31a273807d0027a7e69ab"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24aa705a5f5bd3a8bcfa4d123f03413de5d86e497435693b638cbffb7d5d8a1b"}, - {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:856bf0924d24e7f93b8aee12a3a1095c34085600aa805693fb7f5d1962393206"}, - {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:4304d4416ff032ed50ad6bb87416d802e67139e31c0bde4628f36a47a3164bfa"}, - {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:331ad15c39c9fe9186ceaf87203a9ecf5ae0ba2538c9e898e3a6967e8ad3db6f"}, - {file = "cffi-1.17.0-cp310-cp310-win32.whl", hash = "sha256:669b29a9eca6146465cc574659058ed949748f0809a2582d1f1a324eb91054dc"}, - {file = "cffi-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:48b389b1fd5144603d61d752afd7167dfd205973a43151ae5045b35793232aa2"}, - {file = 
"cffi-1.17.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c5d97162c196ce54af6700949ddf9409e9833ef1003b4741c2b39ef46f1d9720"}, - {file = "cffi-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ba5c243f4004c750836f81606a9fcb7841f8874ad8f3bf204ff5e56332b72b9"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb9333f58fc3a2296fb1d54576138d4cf5d496a2cc118422bd77835e6ae0b9cb"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:435a22d00ec7d7ea533db494da8581b05977f9c37338c80bc86314bec2619424"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1df34588123fcc88c872f5acb6f74ae59e9d182a2707097f9e28275ec26a12d"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df8bb0010fdd0a743b7542589223a2816bdde4d94bb5ad67884348fa2c1c67e8"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8b5b9712783415695663bd463990e2f00c6750562e6ad1d28e072a611c5f2a6"}, - {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ffef8fd58a36fb5f1196919638f73dd3ae0db1a878982b27a9a5a176ede4ba91"}, - {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e67d26532bfd8b7f7c05d5a766d6f437b362c1bf203a3a5ce3593a645e870b8"}, - {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45f7cd36186db767d803b1473b3c659d57a23b5fa491ad83c6d40f2af58e4dbb"}, - {file = "cffi-1.17.0-cp311-cp311-win32.whl", hash = "sha256:a9015f5b8af1bb6837a3fcb0cdf3b874fe3385ff6274e8b7925d81ccaec3c5c9"}, - {file = "cffi-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:b50aaac7d05c2c26dfd50c3321199f019ba76bb650e346a6ef3616306eed67b0"}, - {file = "cffi-1.17.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aec510255ce690d240f7cb23d7114f6b351c733a74c279a84def763660a2c3bc"}, - {file = "cffi-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2770bb0d5e3cc0e31e7318db06efcbcdb7b31bcb1a70086d3177692a02256f59"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db9a30ec064129d605d0f1aedc93e00894b9334ec74ba9c6bdd08147434b33eb"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a47eef975d2b8b721775a0fa286f50eab535b9d56c70a6e62842134cf7841195"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3e0992f23bbb0be00a921eae5363329253c3b86287db27092461c887b791e5e"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6107e445faf057c118d5050560695e46d272e5301feffda3c41849641222a828"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb862356ee9391dc5a0b3cbc00f416b48c1b9a52d252d898e5b7696a5f9fe150"}, - {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c1c13185b90bbd3f8b5963cd8ce7ad4ff441924c31e23c975cb150e27c2bf67a"}, - {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17c6d6d3260c7f2d94f657e6872591fe8733872a86ed1345bda872cfc8c74885"}, - {file = "cffi-1.17.0-cp312-cp312-win32.whl", hash = "sha256:c3b8bd3133cd50f6b637bb4322822c94c5ce4bf0d724ed5ae70afce62187c492"}, - {file = "cffi-1.17.0-cp312-cp312-win_amd64.whl", hash = 
"sha256:dca802c8db0720ce1c49cce1149ff7b06e91ba15fa84b1d59144fef1a1bc7ac2"}, - {file = "cffi-1.17.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6ce01337d23884b21c03869d2f68c5523d43174d4fc405490eb0091057943118"}, - {file = "cffi-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cab2eba3830bf4f6d91e2d6718e0e1c14a2f5ad1af68a89d24ace0c6b17cced7"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14b9cbc8f7ac98a739558eb86fabc283d4d564dafed50216e7f7ee62d0d25377"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b00e7bcd71caa0282cbe3c90966f738e2db91e64092a877c3ff7f19a1628fdcb"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41f4915e09218744d8bae14759f983e466ab69b178de38066f7579892ff2a555"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4760a68cab57bfaa628938e9c2971137e05ce48e762a9cb53b76c9b569f1204"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:011aff3524d578a9412c8b3cfaa50f2c0bd78e03eb7af7aa5e0df59b158efb2f"}, - {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:a003ac9edc22d99ae1286b0875c460351f4e101f8c9d9d2576e78d7e048f64e0"}, - {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ef9528915df81b8f4c7612b19b8628214c65c9b7f74db2e34a646a0a2a0da2d4"}, - {file = "cffi-1.17.0-cp313-cp313-win32.whl", hash = "sha256:70d2aa9fb00cf52034feac4b913181a6e10356019b18ef89bc7c12a283bf5f5a"}, - {file = "cffi-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:b7b6ea9e36d32582cda3465f54c4b454f62f23cb083ebc7a94e2ca6ef011c3a7"}, - {file = "cffi-1.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:964823b2fc77b55355999ade496c54dde161c621cb1f6eac61dc30ed1b63cd4c"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:516a405f174fd3b88829eabfe4bb296ac602d6a0f68e0d64d5ac9456194a5b7e"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dec6b307ce928e8e112a6bb9921a1cb00a0e14979bf28b98e084a4b8a742bd9b"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4094c7b464cf0a858e75cd14b03509e84789abf7b79f8537e6a72152109c76e"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2404f3de742f47cb62d023f0ba7c5a916c9c653d5b368cc966382ae4e57da401"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa9d43b02a0c681f0bfbc12d476d47b2b2b6a3f9287f11ee42989a268a1833c"}, - {file = "cffi-1.17.0-cp38-cp38-win32.whl", hash = "sha256:0bb15e7acf8ab35ca8b24b90af52c8b391690ef5c4aec3d31f38f0d37d2cc499"}, - {file = "cffi-1.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:93a7350f6706b31f457c1457d3a3259ff9071a66f312ae64dc024f049055f72c"}, - {file = "cffi-1.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a2ddbac59dc3716bc79f27906c010406155031a1c801410f1bafff17ea304d2"}, - {file = "cffi-1.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6327b572f5770293fc062a7ec04160e89741e8552bf1c358d1a23eba68166759"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbc183e7bef690c9abe5ea67b7b60fdbca81aa8da43468287dae7b5c046107d4"}, - {file = 
"cffi-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bdc0f1f610d067c70aa3737ed06e2726fd9d6f7bfee4a351f4c40b6831f4e82"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6d872186c1617d143969defeadac5a904e6e374183e07977eedef9c07c8953bf"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d46ee4764b88b91f16661a8befc6bfb24806d885e27436fdc292ed7e6f6d058"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f76a90c345796c01d85e6332e81cab6d70de83b829cf1d9762d0a3da59c7932"}, - {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0e60821d312f99d3e1569202518dddf10ae547e799d75aef3bca3a2d9e8ee693"}, - {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:eb09b82377233b902d4c3fbeeb7ad731cdab579c6c6fda1f763cd779139e47c3"}, - {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:24658baf6224d8f280e827f0a50c46ad819ec8ba380a42448e24459daf809cf4"}, - {file = "cffi-1.17.0-cp39-cp39-win32.whl", hash = "sha256:0fdacad9e0d9fc23e519efd5ea24a70348305e8d7d85ecbb1a5fa66dc834e7fb"}, - {file = "cffi-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cbc78dc018596315d4e7841c8c3a7ae31cc4d638c9b627f87d52e8abaaf2d29"}, - {file = "cffi-1.17.0.tar.gz", hash = "sha256:f3157624b7558b914cb039fd1af735e5e8049a87c817cc215109ad1c8779df76"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"}, + {file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"}, + {file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"}, + {file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"}, + {file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"}, + {file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"}, + {file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"}, + {file = 
"cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"}, + {file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"}, + {file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"}, + {file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"}, + {file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"}, + {file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"}, + {file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"}, + {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, + {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, ] [package.dependencies] @@ -1416,6 +1416,17 @@ typer = ">=0.9.0" typing-extensions = ">=4.5.0" uvicorn = {version = ">=0.18.3", extras = ["standard"]} +[[package]] +name = "circuitbreaker" +version = "2.0.0" +description = "Python Circuit Breaker pattern implementation" +optional = false +python-versions = "*" +files = [ + {file = "circuitbreaker-2.0.0-py2.py3-none-any.whl", hash = "sha256:c8c6f044b616cd5066368734ce4488020392c962b4bd2869d406d883c36d9859"}, + {file = "circuitbreaker-2.0.0.tar.gz", hash = "sha256:28110761ca81a2accbd6b33186bc8c433e69b0933d85e89f280028dbb8c1dd14"}, +] + [[package]] name = "click" version = "8.1.7" @@ -1708,6 +1719,17 @@ lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"] numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"] zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] +[[package]] +name = "cloudpickle" +version = "2.2.1" +description = "Extended pickling support for Python objects" +optional = false +python-versions = ">=3.6" +files = [ + {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"}, + {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"}, +] + [[package]] name = "cloudscraper" version = "1.2.71" @@ -1774,66 +1796,87 @@ cron = ["capturer (>=2.4)"] [[package]] name = "contourpy" -version = "1.2.1" +version = "1.3.0" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false python-versions = ">=3.9" files = [ - {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, - {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"}, - {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"}, - {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"}, - {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"}, - {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, - {file = 
"contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, - {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, - {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"}, - {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"}, - {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"}, - {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, + {file = "contourpy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7"}, + {file = "contourpy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41"}, + {file = 
"contourpy-1.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223"}, + {file = "contourpy-1.3.0-cp310-cp310-win32.whl", hash = "sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f"}, + {file = "contourpy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb"}, + {file = "contourpy-1.3.0-cp311-cp311-win32.whl", hash = "sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c"}, + {file = "contourpy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35"}, + {file = "contourpy-1.3.0-cp312-cp312-win32.whl", hash = "sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb"}, + {file = 
"contourpy-1.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8"}, + {file = "contourpy-1.3.0-cp313-cp313-win32.whl", hash = "sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294"}, + {file = "contourpy-1.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800"}, + {file = "contourpy-1.3.0-cp39-cp39-win32.whl", hash = "sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5"}, + {file = "contourpy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb"}, + {file = "contourpy-1.3.0.tar.gz", hash = "sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4"}, ] [package.dependencies] -numpy = ">=1.20" +numpy = ">=1.23" [package.extras] bokeh = ["bokeh", "selenium"] docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] -mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pillow"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pillow"] test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] -test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] +test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] [[package]] name = "cos-python-sdk-v5" @@ -1936,38 +1979,43 @@ files = [ [[package]] name = "cryptography" -version = "43.0.0" +version = "42.0.8" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
optional = false python-versions = ">=3.7" files = [ - {file = "cryptography-43.0.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:64c3f16e2a4fc51c0d06af28441881f98c5d91009b8caaff40cf3548089e9c74"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3dcdedae5c7710b9f97ac6bba7e1052b95c7083c9d0e9df96e02a1932e777895"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d9a1eca329405219b605fac09ecfc09ac09e595d6def650a437523fcd08dd22"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ea9e57f8ea880eeea38ab5abf9fbe39f923544d7884228ec67d666abd60f5a47"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9a8d6802e0825767476f62aafed40532bd435e8a5f7d23bd8b4f5fd04cc80ecf"}, - {file = "cryptography-43.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cc70b4b581f28d0a254d006f26949245e3657d40d8857066c2ae22a61222ef55"}, - {file = "cryptography-43.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4a997df8c1c2aae1e1e5ac49c2e4f610ad037fc5a3aadc7b64e39dea42249431"}, - {file = "cryptography-43.0.0-cp37-abi3-win32.whl", hash = "sha256:6e2b11c55d260d03a8cf29ac9b5e0608d35f08077d8c087be96287f43af3ccdc"}, - {file = "cryptography-43.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:31e44a986ceccec3d0498e16f3d27b2ee5fdf69ce2ab89b52eaad1d2f33d8778"}, - {file = "cryptography-43.0.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:7b3f5fe74a5ca32d4d0f302ffe6680fcc5c28f8ef0dc0ae8f40c0f3a1b4fca66"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac1955ce000cb29ab40def14fd1bbfa7af2017cca696ee696925615cafd0dce5"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:299d3da8e00b7e2b54bb02ef58d73cd5f55fb31f33ebbf33bd00d9aa6807df7e"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ee0c405832ade84d4de74b9029bedb7b31200600fa524d218fc29bfa371e97f5"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb013933d4c127349b3948aa8aaf2f12c0353ad0eccd715ca789c8a0f671646f"}, - {file = "cryptography-43.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fdcb265de28585de5b859ae13e3846a8e805268a823a12a4da2597f1f5afc9f0"}, - {file = "cryptography-43.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2905ccf93a8a2a416f3ec01b1a7911c3fe4073ef35640e7ee5296754e30b762b"}, - {file = "cryptography-43.0.0-cp39-abi3-win32.whl", hash = "sha256:47ca71115e545954e6c1d207dd13461ab81f4eccfcb1345eac874828b5e3eaaf"}, - {file = "cryptography-43.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:0663585d02f76929792470451a5ba64424acc3cd5227b03921dab0e2f27b1709"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c6d112bf61c5ef44042c253e4859b3cbbb50df2f78fa8fae6747a7814484a70"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:844b6d608374e7d08f4f6e6f9f7b951f9256db41421917dfb2d003dde4cd6b66"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:51956cf8730665e2bdf8ddb8da0056f699c1a5715648c1b0144670c1ba00b48f"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:aae4d918f6b180a8ab8bf6511a419473d107df4dbb4225c7b48c5c9602c38c7f"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:232ce02943a579095a339ac4b390fbbe97f5b5d5d107f8a08260ea2768be8cc2"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5bcb8a5620008a8034d39bce21dc3e23735dfdb6a33a06974739bfa04f853947"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:08a24a7070b2b6804c1940ff0f910ff728932a9d0e80e7814234269f9d46d069"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e9c5266c432a1e23738d178e51c2c7a5e2ddf790f248be939448c0ba2021f9d1"}, - {file = "cryptography-43.0.0.tar.gz", hash = "sha256:b88075ada2d51aa9f18283532c9f60e72170041bba88d7f37e49cbb10275299e"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:81d8a521705787afe7a18d5bfb47ea9d9cc068206270aad0b96a725022e18d2e"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:961e61cefdcb06e0c6d7e3a1b22ebe8b996eb2bf50614e89384be54c48c6b63d"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3ec3672626e1b9e55afd0df6d774ff0e953452886e06e0f1eb7eb0c832e8902"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e599b53fd95357d92304510fb7bda8523ed1f79ca98dce2f43c115950aa78801"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5226d5d21ab681f432a9c1cf8b658c0cb02533eece706b155e5fbd8a0cdd3949"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6b7c4f03ce01afd3b76cf69a5455caa9cfa3de8c8f493e0d3ab7d20611c8dae9"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:2346b911eb349ab547076f47f2e035fc8ff2c02380a7cbbf8d87114fa0f1c583"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:ad803773e9df0b92e0a817d22fd8a3675493f690b96130a5e24f1b8fabbea9c7"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2f66d9cd9147ee495a8374a45ca445819f8929a3efcd2e3df6428e46c3cbb10b"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d45b940883a03e19e944456a558b67a41160e367a719833c53de6911cabba2b7"}, + {file = "cryptography-42.0.8-cp37-abi3-win32.whl", hash = "sha256:a0c5b2b0585b6af82d7e385f55a8bc568abff8923af147ee3c07bd8b42cda8b2"}, + {file = "cryptography-42.0.8-cp37-abi3-win_amd64.whl", hash = "sha256:57080dee41209e556a9a4ce60d229244f7a66ef52750f813bfbe18959770cfba"}, + {file = "cryptography-42.0.8-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:dea567d1b0e8bc5764b9443858b673b734100c2871dc93163f58c46a97a83d28"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4783183f7cb757b73b2ae9aed6599b96338eb957233c58ca8f49a49cc32fd5e"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0608251135d0e03111152e41f0cc2392d1e74e35703960d4190b2e0f4ca9c70"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dc0fdf6787f37b1c6b08e6dfc892d9d068b5bdb671198c72072828b80bd5fe4c"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9c0c1716c8447ee7dbf08d6db2e5c41c688544c61074b54fc4564196f55c25a7"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fff12c88a672ab9c9c1cf7b0c80e3ad9e2ebd9d828d955c126be4fd3e5578c9e"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_x86_64.whl", hash = 
"sha256:cafb92b2bc622cd1aa6a1dce4b93307792633f4c5fe1f46c6b97cf67073ec961"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:31f721658a29331f895a5a54e7e82075554ccfb8b163a18719d342f5ffe5ecb1"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b297f90c5723d04bcc8265fc2a0f86d4ea2e0f7ab4b6994459548d3a6b992a14"}, + {file = "cryptography-42.0.8-cp39-abi3-win32.whl", hash = "sha256:2f88d197e66c65be5e42cd72e5c18afbfae3f741742070e3019ac8f4ac57262c"}, + {file = "cryptography-42.0.8-cp39-abi3-win_amd64.whl", hash = "sha256:fa76fbb7596cc5839320000cdd5d0955313696d9511debab7ee7278fc8b5c84a"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ba4f0a211697362e89ad822e667d8d340b4d8d55fae72cdd619389fb5912eefe"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:81884c4d096c272f00aeb1f11cf62ccd39763581645b0812e99a91505fa48e0c"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c9bb2ae11bfbab395bdd072985abde58ea9860ed84e59dbc0463a5d0159f5b71"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7016f837e15b0a1c119d27ecd89b3515f01f90a8615ed5e9427e30d9cdbfed3d"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5a94eccb2a81a309806027e1670a358b99b8fe8bfe9f8d329f27d72c094dde8c"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dec9b018df185f08483f294cae6ccac29e7a6e0678996587363dc352dc65c842"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:343728aac38decfdeecf55ecab3264b015be68fc2816ca800db649607aeee648"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:013629ae70b40af70c9a7a5db40abe5d9054e6f4380e50ce769947b73bf3caad"}, + {file = "cryptography-42.0.8.tar.gz", hash = "sha256:8d09d05439ce7baa8e9e95b07ec5b6c886f548deb7e0f69ef25f64b3bce842f2"}, ] [package.dependencies] @@ -1980,7 +2028,7 @@ nox = ["nox"] pep8test = ["check-sdist", "click", "mypy", "ruff"] sdist = ["build"] ssh = ["bcrypt (>=3.1.5)"] -test = ["certifi", "cryptography-vectors (==43.0.0)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] [[package]] @@ -2114,6 +2162,21 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distro" version = "1.9.0" @@ -2125,6 +2188,28 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "docker" +version = "7.1.0" +description = "A Python library for the Docker Engine API." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, + {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, +] + +[package.dependencies] +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] +docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + [[package]] name = "docstring-parser" version = "0.16" @@ -2211,13 +2296,13 @@ files = [ [[package]] name = "duckduckgo-search" -version = "6.2.10" +version = "6.2.11" description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine." optional = false python-versions = ">=3.8" files = [ - {file = "duckduckgo_search-6.2.10-py3-none-any.whl", hash = "sha256:266c1528dcbc90931b7c800a2c1041a0cb447c83c485414d77a7e443be717ed6"}, - {file = "duckduckgo_search-6.2.10.tar.gz", hash = "sha256:53057368480ca496fc4e331a34648124711580cf43fbb65336eaa6fd2ee37cec"}, + {file = "duckduckgo_search-6.2.11-py3-none-any.whl", hash = "sha256:6fb7069b79e8928f487001de6859034ade19201bdcd257ec198802430e374bfe"}, + {file = "duckduckgo_search-6.2.11.tar.gz", hash = "sha256:6b6ef1b552c5e67f23e252025d2504caf6f9fc14f70e86c6dd512200f386c673"}, ] [package.dependencies] @@ -2304,6 +2389,19 @@ django = ["dj-database-url", "dj-email-url", "django-cache-url"] lint = ["flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)"] tests = ["dj-database-url", "dj-email-url", "django-cache-url", "pytest"] +[[package]] +name = "esdk-obs-python" +version = "3.24.6.1" +description = "OBS Python SDK" +optional = false +python-versions = "*" +files = [ + {file = "esdk-obs-python-3.24.6.1.tar.gz", hash = "sha256:c45fed143e99d9256c8560c1d78f651eae0d2e809d16e962f8b286b773c33bf0"}, +] + +[package.dependencies] +pycryptodome = ">=3.10.1" + [[package]] name = "et-xmlfile" version = "1.1.0" @@ -2331,13 +2429,13 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.112.1" +version = "0.113.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.112.1-py3-none-any.whl", hash = "sha256:bcbd45817fc2a1cd5da09af66815b84ec0d3d634eb173d1ab468ae3103e183e4"}, - {file = "fastapi-0.112.1.tar.gz", hash = "sha256:b2537146f8c23389a7faa8b03d0bd38d4986e6983874557d95eed2acc46448ef"}, + {file = "fastapi-0.113.0-py3-none-any.whl", hash = "sha256:c8d364485b6361fb643d53920a18d58a696e189abcb901ec03b487e35774c476"}, + {file = "fastapi-0.113.0.tar.gz", hash = "sha256:b7cf9684dc154dfc93f8b718e5850577b529889096518df44defa41e73caf50f"}, ] [package.dependencies] @@ -2346,47 +2444,47 @@ starlette = ">=0.37.2,<0.39.0" typing-extensions = ">=4.8.0" [package.extras] -all = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] -standard = 
["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "fastavro" -version = "1.9.5" +version = "1.9.7" description = "Fast read/write of AVRO files" optional = false python-versions = ">=3.8" files = [ - {file = "fastavro-1.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:61253148e95dd2b6457247b441b7555074a55de17aef85f5165bfd5facf600fc"}, - {file = "fastavro-1.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b604935d671ad47d888efc92a106f98e9440874108b444ac10e28d643109c937"}, - {file = "fastavro-1.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0adbf4956fd53bd74c41e7855bb45ccce953e0eb0e44f5836d8d54ad843f9944"}, - {file = "fastavro-1.9.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:53d838e31457db8bf44460c244543f75ed307935d5fc1d93bc631cc7caef2082"}, - {file = "fastavro-1.9.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:07b6288e8681eede16ff077632c47395d4925c2f51545cd7a60f194454db2211"}, - {file = "fastavro-1.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:ef08cf247fdfd61286ac0c41854f7194f2ad05088066a756423d7299b688d975"}, - {file = "fastavro-1.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c52d7bb69f617c90935a3e56feb2c34d4276819a5c477c466c6c08c224a10409"}, - {file = "fastavro-1.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85e05969956003df8fa4491614bc62fe40cec59e94d06e8aaa8d8256ee3aab82"}, - {file = "fastavro-1.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06e6df8527493a9f0d9a8778df82bab8b1aa6d80d1b004e5aec0a31dc4dc501c"}, - {file = "fastavro-1.9.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:27820da3b17bc01cebb6d1687c9d7254b16d149ef458871aaa207ed8950f3ae6"}, - {file = "fastavro-1.9.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:195a5b8e33eb89a1a9b63fa9dce7a77d41b3b0cd785bac6044df619f120361a2"}, - {file = "fastavro-1.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:be612c109efb727bfd36d4d7ed28eb8e0506617b7dbe746463ebbf81e85eaa6b"}, - {file = "fastavro-1.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b133456c8975ec7d2a99e16a7e68e896e45c821b852675eac4ee25364b999c14"}, - {file = "fastavro-1.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf586373c3d1748cac849395aad70c198ee39295f92e7c22c75757b5c0300fbe"}, - {file = "fastavro-1.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:724ef192bc9c55d5b4c7df007f56a46a21809463499856349d4580a55e2b914c"}, - {file = "fastavro-1.9.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bfd11fe355a8f9c0416803afac298960eb4c603a23b1c74ff9c1d3e673ea7185"}, - {file = "fastavro-1.9.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9827d1654d7bcb118ef5efd3e5b2c9ab2a48d44dac5e8c6a2327bc3ac3caa828"}, - {file = 
"fastavro-1.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:d84b69dca296667e6137ae7c9a96d060123adbc0c00532cc47012b64d38b47e9"}, - {file = "fastavro-1.9.5-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:fb744e9de40fb1dc75354098c8db7da7636cba50a40f7bef3b3fb20f8d189d88"}, - {file = "fastavro-1.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:240df8bacd13ff5487f2465604c007d686a566df5cbc01d0550684eaf8ff014a"}, - {file = "fastavro-1.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3bb35c25bbc3904e1c02333bc1ae0173e0a44aa37a8e95d07e681601246e1f1"}, - {file = "fastavro-1.9.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:b47a54a9700de3eabefd36dabfb237808acae47bc873cada6be6990ef6b165aa"}, - {file = "fastavro-1.9.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:48c7b5e6d2f3bf7917af301c275b05c5be3dd40bb04e80979c9e7a2ab31a00d1"}, - {file = "fastavro-1.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:05d13f98d4e325be40387e27da9bd60239968862fe12769258225c62ec906f04"}, - {file = "fastavro-1.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5b47948eb196263f6111bf34e1cd08d55529d4ed46eb50c1bc8c7c30a8d18868"}, - {file = "fastavro-1.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85b7a66ad521298ad9373dfe1897a6ccfc38feab54a47b97922e213ae5ad8870"}, - {file = "fastavro-1.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44cb154f863ad80e41aea72a709b12e1533b8728c89b9b1348af91a6154ab2f5"}, - {file = "fastavro-1.9.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b5f7f2b1fe21231fd01f1a2a90e714ae267fe633cd7ce930c0aea33d1c9f4901"}, - {file = "fastavro-1.9.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:88fbbe16c61d90a89d78baeb5a34dc1c63a27b115adccdbd6b1fb6f787deacf2"}, - {file = "fastavro-1.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:753f5eedeb5ca86004e23a9ce9b41c5f25eb64a876f95edcc33558090a7f3e4b"}, - {file = "fastavro-1.9.5.tar.gz", hash = "sha256:6419ebf45f88132a9945c51fe555d4f10bb97c236288ed01894f957c6f914553"}, + {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb8749e419a85f251bf1ac87d463311874972554d25d4a0b19f6bdc56036d7cf"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b2f9bafa167cb4d1c3dd17565cb5bf3d8c0759e42620280d1760f1e778e07fc"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e87d04b235b29f7774d226b120da2ca4e60b9e6fdf6747daef7f13f218b3517a"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b525c363e267ed11810aaad8fbdbd1c3bd8837d05f7360977d72a65ab8c6e1fa"}, + {file = "fastavro-1.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:6312fa99deecc319820216b5e1b1bd2d7ebb7d6f221373c74acfddaee64e8e60"}, + {file = "fastavro-1.9.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ec8499dc276c2d2ef0a68c0f1ad11782b2b956a921790a36bf4c18df2b8d4020"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d9d96f98052615ab465c63ba8b76ed59baf2e3341b7b169058db104cbe2aa0"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919f3549e07a8a8645a2146f23905955c35264ac809f6c2ac18142bc5b9b6022"}, + {file = 
"fastavro-1.9.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9de1fa832a4d9016724cd6facab8034dc90d820b71a5d57c7e9830ffe90f31e4"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1d09227d1f48f13281bd5ceac958650805aef9a4ef4f95810128c1f9be1df736"}, + {file = "fastavro-1.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:2db993ae6cdc63e25eadf9f93c9e8036f9b097a3e61d19dca42536dcc5c4d8b3"}, + {file = "fastavro-1.9.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4e1289b731214a7315884c74b2ec058b6e84380ce9b18b8af5d387e64b18fc44"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eac69666270a76a3a1d0444f39752061195e79e146271a568777048ffbd91a27"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9be089be8c00f68e343bbc64ca6d9a13e5e5b0ba8aa52bcb231a762484fb270e"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d576eccfd60a18ffa028259500df67d338b93562c6700e10ef68bbd88e499731"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ee9bf23c157bd7dcc91ea2c700fa3bd924d9ec198bb428ff0b47fa37fe160659"}, + {file = "fastavro-1.9.7-cp312-cp312-win_amd64.whl", hash = "sha256:b6b2ccdc78f6afc18c52e403ee68c00478da12142815c1bd8a00973138a166d0"}, + {file = "fastavro-1.9.7-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:7313def3aea3dacface0a8b83f6d66e49a311149aa925c89184a06c1ef99785d"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:536f5644737ad21d18af97d909dba099b9e7118c237be7e4bd087c7abde7e4f0"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2af559f30383b79cf7d020a6b644c42ffaed3595f775fe8f3d7f80b1c43dfdc5"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:edc28ab305e3c424de5ac5eb87b48d1e07eddb6aa08ef5948fcda33cc4d995ce"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ec2e96bdabd58427fe683329b3d79f42c7b4f4ff6b3644664a345a655ac2c0a1"}, + {file = "fastavro-1.9.7-cp38-cp38-win_amd64.whl", hash = "sha256:3b683693c8a85ede496ebebe115be5d7870c150986e34a0442a20d88d7771224"}, + {file = "fastavro-1.9.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:58f76a5c9a312fbd37b84e49d08eb23094d36e10d43bc5df5187bc04af463feb"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56304401d2f4f69f5b498bdd1552c13ef9a644d522d5de0dc1d789cf82f47f73"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fcce036c6aa06269fc6a0428050fcb6255189997f5e1a728fc461e8b9d3e26b"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:17de68aae8c2525f5631d80f2b447a53395cdc49134f51b0329a5497277fc2d2"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7c911366c625d0a997eafe0aa83ffbc6fd00d8fd4543cb39a97c6f3b8120ea87"}, + {file = "fastavro-1.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:912283ed48578a103f523817fdf0c19b1755cea9b4a6387b73c79ecb8f8f84fc"}, + {file = "fastavro-1.9.7.tar.gz", hash = "sha256:13e11c6cb28626da85290933027cd419ce3f9ab8e45410ef24ce6b89d20a1f6c"}, ] [package.extras] @@ -2491,13 +2589,13 @@ flask = "*" [[package]] name = "flask-cors" -version = "4.0.1" +version = "4.0.2" description = "A Flask extension adding a decorator for CORS support" optional = false python-versions = "*" files = [ - {file = 
"Flask_Cors-4.0.1-py2.py3-none-any.whl", hash = "sha256:f2a704e4458665580c074b714c4627dd5a306b333deb9074d0b1794dfa2fb677"}, - {file = "flask_cors-4.0.1.tar.gz", hash = "sha256:eeb69b342142fdbf4766ad99357a7f3876a2ceb77689dc10ff912aac06c389e4"}, + {file = "Flask_Cors-4.0.2-py2.py3-none-any.whl", hash = "sha256:38364faf1a7a5d0a55bd1d2e2f83ee9e359039182f5e6a029557e1f56d92c09a"}, + {file = "flask_cors-4.0.2.tar.gz", hash = "sha256:493b98e2d1e2f1a4720a7af25693ef2fe32fbafec09a2f72c59f3e475eda61d2"}, ] [package.dependencies] @@ -2793,13 +2891,13 @@ files = [ [[package]] name = "fsspec" -version = "2024.6.1" +version = "2024.9.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"}, - {file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"}, + {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"}, + {file = "fsspec-2024.9.0.tar.gz", hash = "sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8"}, ] [package.extras] @@ -3199,79 +3297,38 @@ protobuf = ["protobuf (<5.0.0dev)"] [[package]] name = "google-crc32c" -version = "1.5.0" +version = "1.6.0" description = "A python wrapper of the C library 'Google CRC32C'" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" files = [ - {file = "google-crc32c-1.5.0.tar.gz", hash = "sha256:89284716bc6a5a415d4eaa11b1726d2d60a0cd12aadf5439828353662ede9dd7"}, - {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:596d1f98fc70232fcb6590c439f43b350cb762fb5d61ce7b0e9db4539654cc13"}, - {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:be82c3c8cfb15b30f36768797a640e800513793d6ae1724aaaafe5bf86f8f346"}, - {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:461665ff58895f508e2866824a47bdee72497b091c730071f2b7575d5762ab65"}, - {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2096eddb4e7c7bdae4bd69ad364e55e07b8316653234a56552d9c988bd2d61b"}, - {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:116a7c3c616dd14a3de8c64a965828b197e5f2d121fedd2f8c5585c547e87b02"}, - {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5829b792bf5822fd0a6f6eb34c5f81dd074f01d570ed7f36aa101d6fc7a0a6e4"}, - {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:64e52e2b3970bd891309c113b54cf0e4384762c934d5ae56e283f9a0afcd953e"}, - {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:02ebb8bf46c13e36998aeaad1de9b48f4caf545e91d14041270d9dca767b780c"}, - {file = "google_crc32c-1.5.0-cp310-cp310-win32.whl", hash = "sha256:2e920d506ec85eb4ba50cd4228c2bec05642894d4c73c59b3a2fe20346bd00ee"}, - {file = "google_crc32c-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:07eb3c611ce363c51a933bf6bd7f8e3878a51d124acfc89452a75120bc436289"}, - {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cae0274952c079886567f3f4f685bcaf5708f0a23a5f5216fdab71f81a6c0273"}, - {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1034d91442ead5a95b5aaef90dbfaca8633b0247d1e41621d1e9f9db88c36298"}, - {file = 
"google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c42c70cd1d362284289c6273adda4c6af8039a8ae12dc451dcd61cdabb8ab57"}, - {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8485b340a6a9e76c62a7dce3c98e5f102c9219f4cfbf896a00cf48caf078d438"}, - {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77e2fd3057c9d78e225fa0a2160f96b64a824de17840351b26825b0848022906"}, - {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f583edb943cf2e09c60441b910d6a20b4d9d626c75a36c8fcac01a6c96c01183"}, - {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a1fd716e7a01f8e717490fbe2e431d2905ab8aa598b9b12f8d10abebb36b04dd"}, - {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:72218785ce41b9cfd2fc1d6a017dc1ff7acfc4c17d01053265c41a2c0cc39b8c"}, - {file = "google_crc32c-1.5.0-cp311-cp311-win32.whl", hash = "sha256:66741ef4ee08ea0b2cc3c86916ab66b6aef03768525627fd6a1b34968b4e3709"}, - {file = "google_crc32c-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:ba1eb1843304b1e5537e1fca632fa894d6f6deca8d6389636ee5b4797affb968"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:98cb4d057f285bd80d8778ebc4fde6b4d509ac3f331758fb1528b733215443ae"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd8536e902db7e365f49e7d9029283403974ccf29b13fc7028b97e2295b33556"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19e0a019d2c4dcc5e598cd4a4bc7b008546b0358bd322537c74ad47a5386884f"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c65b9817512edc6a4ae7c7e987fea799d2e0ee40c53ec573a692bee24de876"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6ac08d24c1f16bd2bf5eca8eaf8304812f44af5cfe5062006ec676e7e1d50afc"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3359fc442a743e870f4588fcf5dcbc1bf929df1fad8fb9905cd94e5edb02e84c"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e986b206dae4476f41bcec1faa057851f3889503a70e1bdb2378d406223994a"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:de06adc872bcd8c2a4e0dc51250e9e65ef2ca91be023b9d13ebd67c2ba552e1e"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-win32.whl", hash = "sha256:d3515f198eaa2f0ed49f8819d5732d70698c3fa37384146079b3799b97667a94"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:67b741654b851abafb7bc625b6d1cdd520a379074e64b6a128e3b688c3c04740"}, - {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c02ec1c5856179f171e032a31d6f8bf84e5a75c45c33b2e20a3de353b266ebd8"}, - {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:edfedb64740750e1a3b16152620220f51d58ff1b4abceb339ca92e934775c27a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84e6e8cd997930fc66d5bb4fde61e2b62ba19d62b7abd7a69920406f9ecca946"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:024894d9d3cfbc5943f8f230e23950cd4906b2fe004c72e29b209420a1e6b05a"}, - {file = 
"google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:998679bf62b7fb599d2878aa3ed06b9ce688b8974893e7223c60db155f26bd8d"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:83c681c526a3439b5cf94f7420471705bbf96262f49a6fe546a6db5f687a3d4a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4c6fdd4fccbec90cc8a01fc00773fcd5fa28db683c116ee3cb35cd5da9ef6c37"}, - {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5ae44e10a8e3407dbe138984f21e536583f2bba1be9491239f942c2464ac0894"}, - {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:37933ec6e693e51a5b07505bd05de57eee12f3e8c32b07da7e73669398e6630a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-win32.whl", hash = "sha256:fe70e325aa68fa4b5edf7d1a4b6f691eb04bbccac0ace68e34820d283b5f80d4"}, - {file = "google_crc32c-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:74dea7751d98034887dbd821b7aae3e1d36eda111d6ca36c206c44478035709c"}, - {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c6c777a480337ac14f38564ac88ae82d4cd238bf293f0a22295b66eb89ffced7"}, - {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:759ce4851a4bb15ecabae28f4d2e18983c244eddd767f560165563bf9aefbc8d"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f13cae8cc389a440def0c8c52057f37359014ccbc9dc1f0827936bcd367c6100"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e560628513ed34759456a416bf86b54b2476c59144a9138165c9a1575801d0d9"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1674e4307fa3024fc897ca774e9c7562c957af85df55efe2988ed9056dc4e57"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:278d2ed7c16cfc075c91378c4f47924c0625f5fc84b2d50d921b18b7975bd210"}, - {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d5280312b9af0976231f9e317c20e4a61cd2f9629b7bfea6a693d1878a264ebd"}, - {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8b87e1a59c38f275c0e3676fc2ab6d59eccecfd460be267ac360cc31f7bcde96"}, - {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7c074fece789b5034b9b1404a1f8208fc2d4c6ce9decdd16e8220c5a793e6f61"}, - {file = "google_crc32c-1.5.0-cp39-cp39-win32.whl", hash = "sha256:7f57f14606cd1dd0f0de396e1e53824c371e9544a822648cd76c034d209b559c"}, - {file = "google_crc32c-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2355cba1f4ad8b6988a4ca3feed5bff33f6af2d7f134852cf279c2aebfde541"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f314013e7dcd5cf45ab1945d92e713eec788166262ae8deb2cfacd53def27325"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b747a674c20a67343cb61d43fdd9207ce5da6a99f629c6e2541aa0e89215bcd"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f24ed114432de109aa9fd317278518a5af2d31ac2ea6b952b2f7782b43da091"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8667b48e7a7ef66afba2c81e1094ef526388d35b873966d8a9a447974ed9178"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = 
"sha256:1c7abdac90433b09bad6c43a43af253e688c9cfc1c86d332aed13f9a7c7f65e2"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f998db4e71b645350b9ac28a2167e6632c239963ca9da411523bb439c5c514d"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c99616c853bb585301df6de07ca2cadad344fd1ada6d62bb30aec05219c45d2"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ad40e31093a4af319dadf503b2467ccdc8f67c72e4bcba97f8c10cb078207b5"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd67cf24a553339d5062eff51013780a00d6f97a39ca062781d06b3a73b15462"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:398af5e3ba9cf768787eef45c803ff9614cc3e22a5b2f7d7ae116df8b11e3314"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b1f8133c9a275df5613a451e73f36c2aea4fe13c5c8997e22cf355ebd7bd0728"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba053c5f50430a3fcfd36f75aff9caeba0440b2d076afdb79a318d6ca245f88"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:272d3892a1e1a2dbc39cc5cde96834c236d5327e2122d3aaa19f6614531bb6eb"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:635f5d4dd18758a1fbd1049a8e8d2fee4ffed124462d837d1a02a0e009c3ab31"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c672d99a345849301784604bfeaeba4db0c7aae50b95be04dd651fd2a7310b93"}, + {file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5bcc90b34df28a4b38653c36bb5ada35671ad105c99cfe915fb5bed7ad6924aa"}, + {file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d9e9913f7bd69e093b81da4535ce27af842e7bf371cde42d1ae9e9bd382dc0e9"}, + {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a184243544811e4a50d345838a883733461e67578959ac59964e43cca2c791e7"}, + {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:236c87a46cdf06384f614e9092b82c05f81bd34b80248021f729396a78e55d7e"}, + {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebab974b1687509e5c973b5c4b8b146683e101e102e17a86bd196ecaa4d099fc"}, + {file = "google_crc32c-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:50cf2a96da226dcbff8671233ecf37bf6e95de98b2a2ebadbfdf455e6d05df42"}, + {file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f7a1fc29803712f80879b0806cb83ab24ce62fc8daf0569f2204a0cfd7f68ed4"}, + {file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:40b05ab32a5067525670880eb5d169529089a26fe35dce8891127aeddc1950e8"}, + {file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e4b426c3702f3cd23b933436487eb34e01e00327fac20c9aebb68ccf34117d"}, + {file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51c4f54dd8c6dfeb58d1df5e4f7f97df8abf17a36626a217f169893d1d7f3e9f"}, + {file = "google_crc32c-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:bb8b3c75bd157010459b15222c3fd30577042a7060e29d42dabce449c087f2b3"}, + {file = 
"google_crc32c-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ed767bf4ba90104c1216b68111613f0d5926fb3780660ea1198fc469af410e9d"}, + {file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:62f6d4a29fea082ac4a3c9be5e415218255cf11684ac6ef5488eea0c9132689b"}, + {file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c87d98c7c4a69066fd31701c4e10d178a648c2cac3452e62c6b24dc51f9fcc00"}, + {file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5e7d2445d1a958c266bfa5d04c39932dc54093fa391736dbfdb0f1929c1fb3"}, + {file = "google_crc32c-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7aec8e88a3583515f9e0957fe4f5f6d8d4997e36d0f61624e70469771584c760"}, + {file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e2806553238cd076f0a55bddab37a532b53580e699ed8e5606d0de1f856b5205"}, + {file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:bb0966e1c50d0ef5bc743312cc730b533491d60585a9a08f897274e57c3f70e0"}, + {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:386122eeaaa76951a8196310432c5b0ef3b53590ef4c317ec7588ec554fec5d2"}, + {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2952396dc604544ea7476b33fe87faedc24d666fb0c2d5ac971a2b9576ab871"}, + {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35834855408429cecf495cac67ccbab802de269e948e27478b1e47dfb6465e57"}, + {file = "google_crc32c-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8797406499f28b5ef791f339594b0b5fdedf54e203b5066675c406ba69d705c"}, + {file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48abd62ca76a2cbe034542ed1b6aee851b6f28aaca4e6551b5599b6f3ef175cc"}, + {file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d"}, + {file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05e2d8c9a2f853ff116db9706b4a27350587f341eda835f46db3c0a8c8ce2f24"}, + {file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91ca8145b060679ec9176e6de4f89b07363d6805bd4760631ef254905503598d"}, + {file = "google_crc32c-1.6.0.tar.gz", hash = "sha256:6eceb6ad197656a1ff49ebfbbfa870678c75be4344feb35ac1edf694309413dc"}, ] [package.extras] @@ -3300,6 +3357,21 @@ typing-extensions = "*" [package.extras] dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] +[[package]] +name = "google-pasta" +version = "0.2.0" +description = "pasta is an AST-based Python refactoring library" +optional = false +python-versions = "*" +files = [ + {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"}, + {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"}, + {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "google-resumable-media" version = "2.7.2" @@ -3425,61 +3497,61 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [[package]] name = 
"grpcio" -version = "1.63.0" +version = "1.66.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, + {file = "grpcio-1.66.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:4877ba180591acdf127afe21ec1c7ff8a5ecf0fe2600f0d3c50e8c4a1cbc6492"}, + {file = 
"grpcio-1.66.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3750c5a00bd644c75f4507f77a804d0189d97a107eb1481945a0cf3af3e7a5ac"}, + {file = "grpcio-1.66.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:a013c5fbb12bfb5f927444b477a26f1080755a931d5d362e6a9a720ca7dbae60"}, + {file = "grpcio-1.66.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1b24c23d51a1e8790b25514157d43f0a4dce1ac12b3f0b8e9f66a5e2c4c132f"}, + {file = "grpcio-1.66.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ffb8ea674d68de4cac6f57d2498fef477cef582f1fa849e9f844863af50083"}, + {file = "grpcio-1.66.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:307b1d538140f19ccbd3aed7a93d8f71103c5d525f3c96f8616111614b14bf2a"}, + {file = "grpcio-1.66.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1c17ebcec157cfb8dd445890a03e20caf6209a5bd4ac5b040ae9dbc59eef091d"}, + {file = "grpcio-1.66.1-cp310-cp310-win32.whl", hash = "sha256:ef82d361ed5849d34cf09105d00b94b6728d289d6b9235513cb2fcc79f7c432c"}, + {file = "grpcio-1.66.1-cp310-cp310-win_amd64.whl", hash = "sha256:292a846b92cdcd40ecca46e694997dd6b9be6c4c01a94a0dfb3fcb75d20da858"}, + {file = "grpcio-1.66.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:c30aeceeaff11cd5ddbc348f37c58bcb96da8d5aa93fed78ab329de5f37a0d7a"}, + {file = "grpcio-1.66.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8a1e224ce6f740dbb6b24c58f885422deebd7eb724aff0671a847f8951857c26"}, + {file = "grpcio-1.66.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:a66fe4dc35d2330c185cfbb42959f57ad36f257e0cc4557d11d9f0a3f14311df"}, + {file = "grpcio-1.66.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3ba04659e4fce609de2658fe4dbf7d6ed21987a94460f5f92df7579fd5d0e22"}, + {file = "grpcio-1.66.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4573608e23f7e091acfbe3e84ac2045680b69751d8d67685ffa193a4429fedb1"}, + {file = "grpcio-1.66.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7e06aa1f764ec8265b19d8f00140b8c4b6ca179a6dc67aa9413867c47e1fb04e"}, + {file = "grpcio-1.66.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3885f037eb11f1cacc41f207b705f38a44b69478086f40608959bf5ad85826dd"}, + {file = "grpcio-1.66.1-cp311-cp311-win32.whl", hash = "sha256:97ae7edd3f3f91480e48ede5d3e7d431ad6005bfdbd65c1b56913799ec79e791"}, + {file = "grpcio-1.66.1-cp311-cp311-win_amd64.whl", hash = "sha256:cfd349de4158d797db2bd82d2020554a121674e98fbe6b15328456b3bf2495bb"}, + {file = "grpcio-1.66.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:a92c4f58c01c77205df6ff999faa008540475c39b835277fb8883b11cada127a"}, + {file = "grpcio-1.66.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fdb14bad0835914f325349ed34a51940bc2ad965142eb3090081593c6e347be9"}, + {file = "grpcio-1.66.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f03a5884c56256e08fd9e262e11b5cfacf1af96e2ce78dc095d2c41ccae2c80d"}, + {file = "grpcio-1.66.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ca2559692d8e7e245d456877a85ee41525f3ed425aa97eb7a70fc9a79df91a0"}, + {file = "grpcio-1.66.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84ca1be089fb4446490dd1135828bd42a7c7f8421e74fa581611f7afdf7ab761"}, + {file = "grpcio-1.66.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:d639c939ad7c440c7b2819a28d559179a4508783f7e5b991166f8d7a34b52815"}, + {file = "grpcio-1.66.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:b9feb4e5ec8dc2d15709f4d5fc367794d69277f5d680baf1910fc9915c633524"}, + {file = "grpcio-1.66.1-cp312-cp312-win32.whl", hash = "sha256:7101db1bd4cd9b880294dec41a93fcdce465bdbb602cd8dc5bd2d6362b618759"}, + {file = "grpcio-1.66.1-cp312-cp312-win_amd64.whl", hash = "sha256:b0aa03d240b5539648d996cc60438f128c7f46050989e35b25f5c18286c86734"}, + {file = "grpcio-1.66.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:ecfe735e7a59e5a98208447293ff8580e9db1e890e232b8b292dc8bd15afc0d2"}, + {file = "grpcio-1.66.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4825a3aa5648010842e1c9d35a082187746aa0cdbf1b7a2a930595a94fb10fce"}, + {file = "grpcio-1.66.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:f517fd7259fe823ef3bd21e508b653d5492e706e9f0ef82c16ce3347a8a5620c"}, + {file = "grpcio-1.66.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1fe60d0772831d96d263b53d83fb9a3d050a94b0e94b6d004a5ad111faa5b5b"}, + {file = "grpcio-1.66.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31a049daa428f928f21090403e5d18ea02670e3d5d172581670be006100db9ef"}, + {file = "grpcio-1.66.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f914386e52cbdeb5d2a7ce3bf1fdfacbe9d818dd81b6099a05b741aaf3848bb"}, + {file = "grpcio-1.66.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bff2096bdba686019fb32d2dde45b95981f0d1490e054400f70fc9a8af34b49d"}, + {file = "grpcio-1.66.1-cp38-cp38-win32.whl", hash = "sha256:aa8ba945c96e73de29d25331b26f3e416e0c0f621e984a3ebdb2d0d0b596a3b3"}, + {file = "grpcio-1.66.1-cp38-cp38-win_amd64.whl", hash = "sha256:161d5c535c2bdf61b95080e7f0f017a1dfcb812bf54093e71e5562b16225b4ce"}, + {file = "grpcio-1.66.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:d0cd7050397b3609ea51727b1811e663ffda8bda39c6a5bb69525ef12414b503"}, + {file = "grpcio-1.66.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0e6c9b42ded5d02b6b1fea3a25f036a2236eeb75d0579bfd43c0018c88bf0a3e"}, + {file = "grpcio-1.66.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:c9f80f9fad93a8cf71c7f161778ba47fd730d13a343a46258065c4deb4b550c0"}, + {file = "grpcio-1.66.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dd67ed9da78e5121efc5c510f0122a972216808d6de70953a740560c572eb44"}, + {file = "grpcio-1.66.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48b0d92d45ce3be2084b92fb5bae2f64c208fea8ceed7fccf6a7b524d3c4942e"}, + {file = "grpcio-1.66.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4d813316d1a752be6f5c4360c49f55b06d4fe212d7df03253dfdae90c8a402bb"}, + {file = "grpcio-1.66.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9c9bebc6627873ec27a70fc800f6083a13c70b23a5564788754b9ee52c5aef6c"}, + {file = "grpcio-1.66.1-cp39-cp39-win32.whl", hash = "sha256:30a1c2cf9390c894c90bbc70147f2372130ad189cffef161f0432d0157973f45"}, + {file = "grpcio-1.66.1-cp39-cp39-win_amd64.whl", hash = "sha256:17663598aadbedc3cacd7bbde432f541c8e07d2496564e22b214b22c7523dac8"}, + {file = "grpcio-1.66.1.tar.gz", hash = "sha256:35334f9c9745add3e357e3372756fd32d925bd52c41da97f4dfdafbde0bf0ee2"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] +protobuf = ["grpcio-tools (>=1.66.1)"] [[package]] name = "grpcio-status" @@ -3826,13 +3898,13 @@ test = ["Cython (>=0.29.24,<0.30.0)"] [[package]] name = "httpx" -version = "0.27.0" +version = "0.27.2" description = "The next generation HTTP client." 
optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, ] [package.dependencies] @@ -3849,6 +3921,7 @@ brotli = ["brotli", "brotlicffi"] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "huggingface-hub" @@ -3909,33 +3982,33 @@ files = [ [[package]] name = "idna" -version = "3.7" +version = "3.8" description = "Internationalized Domain Names in Applications (IDNA)" optional = false -python-versions = ">=3.5" +python-versions = ">=3.6" files = [ - {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, - {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, + {file = "idna-3.8-py3-none-any.whl", hash = "sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac"}, + {file = "idna-3.8.tar.gz", hash = "sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603"}, ] [[package]] name = "importlib-metadata" -version = "8.0.0" +version = "6.11.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, - {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -4117,115 +4190,125 @@ files = [ [[package]] name = "kiwisolver" -version = "1.4.5" +version = "1.4.7" description = "A fast implementation of the Cassowary constraint solver" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"}, - {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"}, - {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"}, - {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"}, - {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"}, - {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"}, - {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = 
"sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-win32.whl", hash = "sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-win_amd64.whl", hash = "sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b"}, - {file = 
"kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250"}, - {file = "kiwisolver-1.4.5-cp38-cp38-win32.whl", hash = "sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e"}, - {file = "kiwisolver-1.4.5-cp38-cp38-win_amd64.whl", hash = "sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"}, - {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"}, - {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"}, - {file = 
"kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"}, - {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = 
"sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win32.whl", hash = "sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_amd64.whl", hash = "sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_arm64.whl", hash = "sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win32.whl", hash = "sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_amd64.whl", hash = "sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_arm64.whl", hash = 
"sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win32.whl", hash = "sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win32.whl", hash = "sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = 
"sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win32.whl", hash = "sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win_amd64.whl", hash = "sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win32.whl", hash = "sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_amd64.whl", hash = "sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_arm64.whl", hash = "sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2"}, + {file = 
"kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0"}, + {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, ] [[package]] @@ -4302,13 +4385,13 @@ six = "*" [[package]] name = "langfuse" -version = "2.44.0" +version = "2.46.3" description = "A client library for accessing langfuse" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langfuse-2.44.0-py3-none-any.whl", hash = "sha256:adb73400a6ad6d597cc95c31381c82f81face3d5fb69391181f224a26f7e8562"}, - {file = "langfuse-2.44.0.tar.gz", hash = "sha256:dfa5378ff7022ae9fe5b8b842c0365347c98f9ef2b772dcee6a93a45442de28c"}, + {file = "langfuse-2.46.3-py3-none-any.whl", hash = "sha256:59dcca4b13ea5f5c7f5a9344266116c3b8b998ae63274e4e9d0dabb51a47d361"}, + {file = "langfuse-2.46.3.tar.gz", hash = "sha256:a68c2dba630f53ccd473205164082ac1b29a1cbdb73500004daee72b5b522624"}, ] [package.dependencies] @@ -4316,7 +4399,7 @@ anyio = ">=4.4.0,<5.0.0" backoff = ">=1.10.0" httpx = ">=0.15.4,<1.0" idna = ">=3.7,<4.0" -packaging = ">=23.2,<24.0" +packaging = ">=23.2,<25.0" pydantic = ">=1.10.7,<3.0" wrapt = ">=1.14,<2.0" @@ -4327,13 +4410,13 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langsmith" -version = "0.1.101" +version = "0.1.115" description = "Client 
library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.101-py3-none-any.whl", hash = "sha256:572e2c90709cda1ad837ac86cedda7295f69933f2124c658a92a35fb890477cc"}, - {file = "langsmith-0.1.101.tar.gz", hash = "sha256:caf4d95f314bb6cd3c4e0632eed821fd5cd5d0f18cb824772fce6d7a9113895b"}, + {file = "langsmith-0.1.115-py3-none-any.whl", hash = "sha256:04e35cfd4c2d4ff1ea10bb577ff43957b05ebb3d9eb4e06e200701f4a2b4ac9f"}, + {file = "langsmith-0.1.115.tar.gz", hash = "sha256:3b775377d858d32354f3ee0dd1ed637068cfe9a1f13e7b3bfa82db1615cdffc9"}, ] [package.dependencies] @@ -4909,6 +4992,22 @@ files = [ [package.extras] test = ["mypy (>=1.0)", "pytest (>=7.0.0)"] +[[package]] +name = "mock" +version = "4.0.3" +description = "Rolling backport of unittest.mock for all Pythons" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mock-4.0.3-py3-none-any.whl", hash = "sha256:122fcb64ee37cfad5b3f48d7a7d51875d7031aaf3d8be7c42e2bee25044eee62"}, + {file = "mock-4.0.3.tar.gz", hash = "sha256:7d3fbbde18228f4ff2f1f119a45cdffa458b4c0dee32eb4d2bb2f82554bac7bc"}, +] + +[package.extras] +build = ["blurb", "twine", "wheel"] +docs = ["sphinx"] +test = ["pytest (<5.4)", "pytest-cov"] + [[package]] name = "monotonic" version = "1.6" @@ -5108,6 +5207,30 @@ files = [ {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, ] +[[package]] +name = "multiprocess" +version = "0.70.16" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, +] + +[package.dependencies] +dill = ">=0.3.8" + [[package]] name = "multitasking" version = "0.0.11" @@ -5333,6 +5456,25 @@ rsa 
= ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "oci" +version = "2.133.0" +description = "Oracle Cloud Infrastructure Python SDK" +optional = false +python-versions = "*" +files = [ + {file = "oci-2.133.0-py3-none-any.whl", hash = "sha256:9706365481ca538c89b3a15e6b5c246801eccb06be831a7f21c40f2a2ee310a7"}, + {file = "oci-2.133.0.tar.gz", hash = "sha256:800418025bb98f587c65bbf89c6b6d61ef0f2249e0698d73439baf3251640b7f"}, +] + +[package.dependencies] +certifi = "*" +circuitbreaker = {version = ">=1.3.1,<3.0.0", markers = "python_version >= \"3.7\""} +cryptography = ">=3.2.1,<43.0.0" +pyOpenSSL = ">=17.5.0,<25.0.0" +python-dateutil = ">=2.5.3,<3.0.0" +pytz = ">=2016.10" + [[package]] name = "odfpy" version = "1.4.1" @@ -5362,36 +5504,36 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "onnxruntime" -version = "1.19.0" +version = "1.19.2" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = false python-versions = "*" files = [ - {file = "onnxruntime-1.19.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6ce22a98dfec7b646ae305f52d0ce14a189a758b02ea501860ca719f4b0ae04b"}, - {file = "onnxruntime-1.19.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:19019c72873f26927aa322c54cf2bf7312b23451b27451f39b88f57016c94f8b"}, - {file = "onnxruntime-1.19.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8eaa16df99171dc636e30108d15597aed8c4c2dd9dbfdd07cc464d57d73fb275"}, - {file = "onnxruntime-1.19.0-cp310-cp310-win32.whl", hash = "sha256:0eb0f8dbe596fd0f4737fe511fdbb17603853a7d204c5b2ca38d3c7808fc556b"}, - {file = "onnxruntime-1.19.0-cp310-cp310-win_amd64.whl", hash = "sha256:616092d54ba8023b7bc0a5f6d900a07a37cc1cfcc631873c15f8c1d6e9e184d4"}, - {file = "onnxruntime-1.19.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a2b53b3c287cd933e5eb597273926e899082d8c84ab96e1b34035764a1627e17"}, - {file = "onnxruntime-1.19.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e94984663963e74fbb468bde9ec6f19dcf890b594b35e249c4dc8789d08993c5"}, - {file = "onnxruntime-1.19.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f379d1f050cfb55ce015d53727b78ee362febc065c38eed81512b22b757da73"}, - {file = "onnxruntime-1.19.0-cp311-cp311-win32.whl", hash = "sha256:4ccb48faea02503275ae7e79e351434fc43c294c4cb5c4d8bcb7479061396614"}, - {file = "onnxruntime-1.19.0-cp311-cp311-win_amd64.whl", hash = "sha256:9cdc8d311289a84e77722de68bd22b8adfb94eea26f4be6f9e017350faac8b18"}, - {file = "onnxruntime-1.19.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:1b59eaec1be9a8613c5fdeaafe67f73a062edce3ac03bbbdc9e2d98b58a30617"}, - {file = "onnxruntime-1.19.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be4144d014a4b25184e63ce7a463a2e7796e2f3df931fccc6a6aefa6f1365dc5"}, - {file = "onnxruntime-1.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10d7e7d4ca7021ce7f29a66dbc6071addf2de5839135339bd855c6d9c2bba371"}, - {file = "onnxruntime-1.19.0-cp312-cp312-win32.whl", hash = "sha256:87f2c58b577a1fb31dc5d92b647ecc588fd5f1ea0c3ad4526f5f80a113357c8d"}, - {file = "onnxruntime-1.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:8a1f50d49676d7b69566536ff039d9e4e95fc482a55673719f46528218ecbb94"}, - {file = "onnxruntime-1.19.0-cp38-cp38-macosx_11_0_universal2.whl", hash = 
"sha256:71423c8c4b2d7a58956271534302ec72721c62a41efd0c4896343249b8399ab0"}, - {file = "onnxruntime-1.19.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9d63630d45e9498f96e75bbeb7fd4a56acb10155de0de4d0e18d1b6cbb0b358a"}, - {file = "onnxruntime-1.19.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3bfd15db1e8794d379a86c1a9116889f47f2cca40cc82208fc4f7e8c38e8522"}, - {file = "onnxruntime-1.19.0-cp38-cp38-win32.whl", hash = "sha256:3b098003b6b4cb37cc84942e5f1fe27f945dd857cbd2829c824c26b0ba4a247e"}, - {file = "onnxruntime-1.19.0-cp38-cp38-win_amd64.whl", hash = "sha256:cea067a6541d6787d903ee6843401c5b1332a266585160d9700f9f0939443886"}, - {file = "onnxruntime-1.19.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:c4fcff12dc5ca963c5f76b9822bb404578fa4a98c281e8c666b429192799a099"}, - {file = "onnxruntime-1.19.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6dcad8a4db908fbe70b98c79cea1c8b6ac3316adf4ce93453136e33a524ac59"}, - {file = "onnxruntime-1.19.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bc449907c6e8d99eee5ae5cc9c8fdef273d801dcd195393d3f9ab8ad3f49522"}, - {file = "onnxruntime-1.19.0-cp39-cp39-win32.whl", hash = "sha256:947febd48405afcf526e45ccff97ff23b15e530434705f734870d22ae7fcf236"}, - {file = "onnxruntime-1.19.0-cp39-cp39-win_amd64.whl", hash = "sha256:f60be47eff5ee77fd28a466b0fd41d7debc42a32179d1ddb21e05d6067d7b48b"}, + {file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"}, + {file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"}, + {file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"}, + {file 
= "onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"}, + {file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"}, + {file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"}, + {file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"}, ] [package.dependencies] @@ -5524,42 +5666,42 @@ kerberos = ["requests-kerberos"] [[package]] name = "opentelemetry-api" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Python API" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_api-1.26.0-py3-none-any.whl", hash = "sha256:7d7ea33adf2ceda2dd680b18b1677e4152000b37ca76e679da71ff103b943064"}, - {file = "opentelemetry_api-1.26.0.tar.gz", hash = "sha256:2bd639e4bed5b18486fef0b5a520aaffde5a18fc225e808a1ac4df363f43a1ce"}, + {file = "opentelemetry_api-1.27.0-py3-none-any.whl", hash = "sha256:953d5871815e7c30c81b56d910c707588000fff7a3ca1c73e6531911d53065e7"}, + {file = "opentelemetry_api-1.27.0.tar.gz", hash = "sha256:ed673583eaa5f81b5ce5e86ef7cdaf622f88ef65f0b9aab40b843dcae5bef342"}, ] [package.dependencies] deprecated = ">=1.2.6" -importlib-metadata = ">=6.0,<=8.0.0" +importlib-metadata = ">=6.0,<=8.4.0" [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Protobuf encoding" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_common-1.26.0-py3-none-any.whl", hash = "sha256:ee4d8f8891a1b9c372abf8d109409e5b81947cf66423fd998e56880057afbc71"}, - {file = "opentelemetry_exporter_otlp_proto_common-1.26.0.tar.gz", hash = "sha256:bdbe50e2e22a1c71acaa0c8ba6efaadd58882e5a5978737a44a4c4b10d304c92"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.27.0-py3-none-any.whl", hash = "sha256:675db7fffcb60946f3a5c43e17d1168a3307a94a930ecf8d2ea1f286f3d4f79a"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.27.0.tar.gz", hash = 
"sha256:159d27cf49f359e3798c4c3eb8da6ef4020e292571bd8c5604a2a573231dd5c8"}, ] [package.dependencies] -opentelemetry-proto = "1.26.0" +opentelemetry-proto = "1.27.0" [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_grpc-1.26.0-py3-none-any.whl", hash = "sha256:e2be5eff72ebcb010675b818e8d7c2e7d61ec451755b8de67a140bc49b9b0280"}, - {file = "opentelemetry_exporter_otlp_proto_grpc-1.26.0.tar.gz", hash = "sha256:a65b67a9a6b06ba1ec406114568e21afe88c1cdb29c464f2507d529eb906d8ae"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.27.0-py3-none-any.whl", hash = "sha256:56b5bbd5d61aab05e300d9d62a6b3c134827bbd28d0b12f2649c2da368006c9e"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.27.0.tar.gz", hash = "sha256:af6f72f76bcf425dfb5ad11c1a6d6eca2863b91e63575f89bb7b4b55099d968f"}, ] [package.dependencies] @@ -5567,19 +5709,19 @@ deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" grpcio = ">=1.0.0,<2.0.0" opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.26.0" -opentelemetry-proto = "1.26.0" -opentelemetry-sdk = ">=1.26.0,<1.27.0" +opentelemetry-exporter-otlp-proto-common = "1.27.0" +opentelemetry-proto = "1.27.0" +opentelemetry-sdk = ">=1.27.0,<1.28.0" [[package]] name = "opentelemetry-instrumentation" -version = "0.47b0" +version = "0.48b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation-0.47b0-py3-none-any.whl", hash = "sha256:88974ee52b1db08fc298334b51c19d47e53099c33740e48c4f084bd1afd052d5"}, - {file = "opentelemetry_instrumentation-0.47b0.tar.gz", hash = "sha256:96f9885e450c35e3f16a4f33145f2ebf620aea910c9fd74a392bbc0f807a350f"}, + {file = "opentelemetry_instrumentation-0.48b0-py3-none-any.whl", hash = "sha256:a69750dc4ba6a5c3eb67986a337185a25b739966d80479befe37b546fc870b44"}, + {file = "opentelemetry_instrumentation-0.48b0.tar.gz", hash = "sha256:94929685d906380743a71c3970f76b5f07476eea1834abd5dd9d17abfe23cc35"}, ] [package.dependencies] @@ -5589,55 +5731,55 @@ wrapt = ">=1.0.0,<2.0.0" [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.47b0" +version = "0.48b0" description = "ASGI instrumentation for OpenTelemetry" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_asgi-0.47b0-py3-none-any.whl", hash = "sha256:b798dc4957b3edc9dfecb47a4c05809036a4b762234c5071212fda39ead80ade"}, - {file = "opentelemetry_instrumentation_asgi-0.47b0.tar.gz", hash = "sha256:e78b7822c1bca0511e5e9610ec484b8994a81670375e570c76f06f69af7c506a"}, + {file = "opentelemetry_instrumentation_asgi-0.48b0-py3-none-any.whl", hash = "sha256:ddb1b5fc800ae66e85a4e2eca4d9ecd66367a8c7b556169d9e7b57e10676e44d"}, + {file = "opentelemetry_instrumentation_asgi-0.48b0.tar.gz", hash = "sha256:04c32174b23c7fa72ddfe192dad874954968a6a924608079af9952964ecdf785"}, ] [package.dependencies] asgiref = ">=3.0,<4.0" opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.47b0" -opentelemetry-semantic-conventions = "0.47b0" -opentelemetry-util-http = "0.47b0" +opentelemetry-instrumentation = "0.48b0" +opentelemetry-semantic-conventions = "0.48b0" +opentelemetry-util-http = "0.48b0" [package.extras] instruments = ["asgiref (>=3.0,<4.0)"] [[package]] name = 
"opentelemetry-instrumentation-fastapi" -version = "0.47b0" +version = "0.48b0" description = "OpenTelemetry FastAPI Instrumentation" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_fastapi-0.47b0-py3-none-any.whl", hash = "sha256:5ac28dd401160b02e4f544a85a9e4f61a8cbe5b077ea0379d411615376a2bd21"}, - {file = "opentelemetry_instrumentation_fastapi-0.47b0.tar.gz", hash = "sha256:0c7c10b5d971e99a420678ffd16c5b1ea4f0db3b31b62faf305fbb03b4ebee36"}, + {file = "opentelemetry_instrumentation_fastapi-0.48b0-py3-none-any.whl", hash = "sha256:afeb820a59e139d3e5d96619600f11ce0187658b8ae9e3480857dd790bc024f2"}, + {file = "opentelemetry_instrumentation_fastapi-0.48b0.tar.gz", hash = "sha256:21a72563ea412c0b535815aeed75fc580240f1f02ebc72381cfab672648637a2"}, ] [package.dependencies] opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.47b0" -opentelemetry-instrumentation-asgi = "0.47b0" -opentelemetry-semantic-conventions = "0.47b0" -opentelemetry-util-http = "0.47b0" +opentelemetry-instrumentation = "0.48b0" +opentelemetry-instrumentation-asgi = "0.48b0" +opentelemetry-semantic-conventions = "0.48b0" +opentelemetry-util-http = "0.48b0" [package.extras] -instruments = ["fastapi (>=0.58,<1.0)", "fastapi-slim (>=0.111.0,<0.112.0)"] +instruments = ["fastapi (>=0.58,<1.0)"] [[package]] name = "opentelemetry-proto" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Python Proto" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_proto-1.26.0-py3-none-any.whl", hash = "sha256:6c4d7b4d4d9c88543bcf8c28ae3f8f0448a753dc291c18c5390444c90b76a725"}, - {file = "opentelemetry_proto-1.26.0.tar.gz", hash = "sha256:c5c18796c0cab3751fc3b98dee53855835e90c0422924b484432ac852d93dc1e"}, + {file = "opentelemetry_proto-1.27.0-py3-none-any.whl", hash = "sha256:b133873de5581a50063e1e4b29cdcf0c5e253a8c2d8dc1229add20a4c3830ace"}, + {file = "opentelemetry_proto-1.27.0.tar.gz", hash = "sha256:33c9345d91dafd8a74fc3d7576c5a38f18b7fdf8d02983ac67485386132aedd6"}, ] [package.dependencies] @@ -5645,44 +5787,44 @@ protobuf = ">=3.19,<5.0" [[package]] name = "opentelemetry-sdk" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Python SDK" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_sdk-1.26.0-py3-none-any.whl", hash = "sha256:feb5056a84a88670c041ea0ded9921fca559efec03905dddeb3885525e0af897"}, - {file = "opentelemetry_sdk-1.26.0.tar.gz", hash = "sha256:c90d2868f8805619535c05562d699e2f4fb1f00dbd55a86dcefca4da6fa02f85"}, + {file = "opentelemetry_sdk-1.27.0-py3-none-any.whl", hash = "sha256:365f5e32f920faf0fd9e14fdfd92c086e317eaa5f860edba9cdc17a380d9197d"}, + {file = "opentelemetry_sdk-1.27.0.tar.gz", hash = "sha256:d525017dea0ccce9ba4e0245100ec46ecdc043f2d7b8315d56b19aff0904fa6f"}, ] [package.dependencies] -opentelemetry-api = "1.26.0" -opentelemetry-semantic-conventions = "0.47b0" +opentelemetry-api = "1.27.0" +opentelemetry-semantic-conventions = "0.48b0" typing-extensions = ">=3.7.4" [[package]] name = "opentelemetry-semantic-conventions" -version = "0.47b0" +version = "0.48b0" description = "OpenTelemetry Semantic Conventions" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_semantic_conventions-0.47b0-py3-none-any.whl", hash = "sha256:4ff9d595b85a59c1c1413f02bba320ce7ea6bf9e2ead2b0913c4395c7bbc1063"}, - {file = "opentelemetry_semantic_conventions-0.47b0.tar.gz", hash = "sha256:a8d57999bbe3495ffd4d510de26a97dadc1dace53e0275001b2c1b2f67992a7e"}, + {file = 
"opentelemetry_semantic_conventions-0.48b0-py3-none-any.whl", hash = "sha256:a0de9f45c413a8669788a38569c7e0a11ce6ce97861a628cca785deecdc32a1f"}, + {file = "opentelemetry_semantic_conventions-0.48b0.tar.gz", hash = "sha256:12d74983783b6878162208be57c9effcb89dc88691c64992d70bb89dc00daa1a"}, ] [package.dependencies] deprecated = ">=1.2.6" -opentelemetry-api = "1.26.0" +opentelemetry-api = "1.27.0" [[package]] name = "opentelemetry-util-http" -version = "0.47b0" +version = "0.48b0" description = "Web util for OpenTelemetry" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_util_http-0.47b0-py3-none-any.whl", hash = "sha256:3d3215e09c4a723b12da6d0233a31395aeb2bb33a64d7b15a1500690ba250f19"}, - {file = "opentelemetry_util_http-0.47b0.tar.gz", hash = "sha256:352a07664c18eef827eb8ddcbd64c64a7284a39dd1655e2f16f577eb046ccb32"}, + {file = "opentelemetry_util_http-0.48b0-py3-none-any.whl", hash = "sha256:76f598af93aab50328d2a69c786beaedc8b6a7770f7a818cc307eb353debfffb"}, + {file = "opentelemetry_util_http-0.48b0.tar.gz", hash = "sha256:60312015153580cc20f322e5cdc3d3ecad80a71743235bdb77716e742814623c"}, ] [[package]] @@ -5825,13 +5967,13 @@ files = [ [[package]] name = "packaging" -version = "23.2" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -5916,6 +6058,23 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pathos" +version = "0.3.2" +description = "parallel graph management and execution in heterogeneous computing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathos-0.3.2-py3-none-any.whl", hash = "sha256:d669275e6eb4b3fbcd2846d7a6d1bba315fe23add0c614445ba1408d8b38bafe"}, + {file = "pathos-0.3.2.tar.gz", hash = "sha256:4f2a42bc1e10ccf0fe71961e7145fc1437018b6b21bd93b2446abc3983e49a7a"}, +] + +[package.dependencies] +dill = ">=0.3.8" +multiprocess = ">=0.70.16" +pox = ">=0.3.4" +ppft = ">=1.7.6.8" + [[package]] name = "peewee" version = "3.17.6" @@ -6076,13 +6235,13 @@ type = ["mypy (>=1.8)"] [[package]] name = "plotly" -version = "5.23.0" +version = "5.24.0" description = "An open-source, interactive data visualization library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "plotly-5.23.0-py3-none-any.whl", hash = "sha256:76cbe78f75eddc10c56f5a4ee3e7ccaade7c0a57465546f02098c0caed6c2d1a"}, - {file = "plotly-5.23.0.tar.gz", hash = "sha256:89e57d003a116303a34de6700862391367dd564222ab71f8531df70279fc0193"}, + {file = "plotly-5.24.0-py3-none-any.whl", hash = "sha256:0e54efe52c8cef899f7daa41be9ed97dfb6be622613a2a8f56a86a0634b2b67e"}, + {file = "plotly-5.24.0.tar.gz", hash = "sha256:eae9f4f54448682442c92c1e97148e3ad0c52f0cf86306e1b76daba24add554a"}, ] [package.dependencies] @@ -6136,13 +6295,13 @@ tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", 
"p [[package]] name = "posthog" -version = "3.5.2" +version = "3.6.3" description = "Integrate PostHog into any python application." optional = false python-versions = "*" files = [ - {file = "posthog-3.5.2-py2.py3-none-any.whl", hash = "sha256:605b3d92369971cc99290b1fcc8534cbddac3726ef7972caa993454a5ecfb644"}, - {file = "posthog-3.5.2.tar.gz", hash = "sha256:a383a80c1f47e0243f5ce359e81e06e2e7b37eb39d1d6f8d01c3e64ed29df2ee"}, + {file = "posthog-3.6.3-py2.py3-none-any.whl", hash = "sha256:cdd6c5d8919fd6158bbc4103bccc7129c712d8104dc33828be02bada7b6320a4"}, + {file = "posthog-3.6.3.tar.gz", hash = "sha256:6e1104a20638eab2b5d9cde6b6202a2900d67436237b3ac3521614ec17686701"}, ] [package.dependencies] @@ -6155,7 +6314,32 @@ six = ">=1.5" [package.extras] dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] -test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] +test = ["coverage", "django", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] + +[[package]] +name = "pox" +version = "0.3.4" +description = "utilities for filesystem exploration and automated builds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pox-0.3.4-py3-none-any.whl", hash = "sha256:651b8ae8a7b341b7bfd267f67f63106daeb9805f1ac11f323d5280d2da93fdb6"}, + {file = "pox-0.3.4.tar.gz", hash = "sha256:16e6eca84f1bec3828210b06b052adf04cf2ab20c22fd6fbef5f78320c9a6fed"}, +] + +[[package]] +name = "ppft" +version = "1.7.6.8" +description = "distributed and parallel Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ppft-1.7.6.8-py3-none-any.whl", hash = "sha256:de2dd4b1b080923dd9627fbdea52649fd741c752fce4f3cf37e26f785df23d9b"}, + {file = "ppft-1.7.6.8.tar.gz", hash = "sha256:76a429a7d7b74c4d743f6dba8351e58d62b6432ed65df9fe204790160dab996d"}, +] + +[package.extras] +dill = ["dill (>=0.3.8)"] [[package]] name = "primp" @@ -6692,18 +6876,18 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] name = "pymilvus" -version = "2.4.5" +version = "2.4.6" description = "Python Sdk for Milvus" optional = false python-versions = ">=3.8" files = [ - {file = "pymilvus-2.4.5-py3-none-any.whl", hash = "sha256:dc4f2d1eac8db9cf3951de39566a1a244695760bb94d8310fbfc73d6d62bb267"}, - {file = "pymilvus-2.4.5.tar.gz", hash = "sha256:1a497fe9b41d6bf62b1d5e1c412960922dde1598576fcbb8818040c8af11149f"}, + {file = "pymilvus-2.4.6-py3-none-any.whl", hash = "sha256:b4c43472edc313b845d313be50610e19054e6954b2c5c3b515565c596c2d3d97"}, + {file = "pymilvus-2.4.6.tar.gz", hash = "sha256:6ac3eb91c92cc01bbe444fe83f895f02d7b2546d96ac67998630bf31ac074d66"}, ] [package.dependencies] environs = "<=9.5.0" -grpcio = ">=1.49.1,<=1.63.0" +grpcio = ">=1.49.1" milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" @@ -6730,6 +6914,24 @@ files = [ ed25519 = ["PyNaCl (>=1.4.0)"] rsa = ["cryptography"] +[[package]] +name = "pyopenssl" +version = "24.2.1" +description = "Python wrapper module around the OpenSSL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyOpenSSL-24.2.1-py3-none-any.whl", hash = "sha256:967d5719b12b243588573f39b0c677637145c7a1ffedcd495a487e58177fbb8d"}, + {file = "pyopenssl-24.2.1.tar.gz", hash = "sha256:4247f0dbe3748d560dcbb2ff3ea01af0f9a1a001ef5f7c4c647956ed8cbf0e95"}, +] + +[package.dependencies] +cryptography = ">=41.0.5,<44" + +[package.extras] +docs = ["sphinx 
(!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"] +test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"] + [[package]] name = "pypandoc" version = "1.13" @@ -6743,13 +6945,13 @@ files = [ [[package]] name = "pyparsing" -version = "3.1.2" +version = "3.1.4" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.6.8" files = [ - {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, - {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, + {file = "pyparsing-3.1.4-py3-none-any.whl", hash = "sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c"}, + {file = "pyparsing-3.1.4.tar.gz", hash = "sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032"}, ] [package.extras] @@ -7264,119 +7466,119 @@ dev = ["pytest"] [[package]] name = "rapidfuzz" -version = "3.9.6" +version = "3.9.7" description = "rapid fuzzy string matching" optional = false python-versions = ">=3.8" files = [ - {file = "rapidfuzz-3.9.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a7ed0d0b9c85720f0ae33ac5efc8dc3f60c1489dad5c29d735fbdf2f66f0431f"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f3deff6ab7017ed21b9aec5874a07ad13e6b2a688af055837f88b743c7bfd947"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3f9fc060160507b2704f7d1491bd58453d69689b580cbc85289335b14fe8ca"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e86c2b3827fa6169ad6e7d4b790ce02a20acefb8b78d92fa4249589bbc7a2c"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f982e1aafb4bd8207a5e073b1efef9e68a984e91330e1bbf364f9ed157ed83f0"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9196a51d0ec5eaaaf5bca54a85b7b1e666fc944c332f68e6427503af9fb8c49e"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb5a514064e02585b1cc09da2fe406a6dc1a7e5f3e92dd4f27c53e5f1465ec81"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e3a4244f65dbc3580b1275480118c3763f9dc29fc3dd96610560cb5e140a4d4a"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f6ebb910a702e41641e1e1dada3843bc11ba9107a33c98daef6945a885a40a07"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:624fbe96115fb39addafa288d583b5493bc76dab1d34d0ebba9987d6871afdf9"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1c59f1c1507b7a557cf3c410c76e91f097460da7d97e51c985343798e9df7a3c"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f6f0256cb27b6a0fb2e1918477d1b56473cd04acfa245376a342e7c15806a396"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-win32.whl", hash = "sha256:24d473d00d23a30a85802b502b417a7f5126019c3beec91a6739fe7b95388b24"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:248f6d2612e661e2b5f9a22bbd5862a1600e720da7bb6ad8a55bb1548cdfa423"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-win_arm64.whl", hash = "sha256:e03fdf0e74f346ed7e798135df5f2a0fb8d6b96582b00ebef202dcf2171e1d1d"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:52e4675f642fbc85632f691b67115a243cd4d2a47bdcc4a3d9a79e784518ff97"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1f93a2f13038700bd245b927c46a2017db3dcd4d4ff94687d74b5123689b873b"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b70500bca460264b8141d8040caee22e9cf0418c5388104ff0c73fb69ee28f"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1e037fb89f714a220f68f902fc6300ab7a33349f3ce8ffae668c3b3a40b0b06"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6792f66d59b86ccfad5e247f2912e255c85c575789acdbad8e7f561412ffed8a"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68d9cffe710b67f1969cf996983608cee4490521d96ea91d16bd7ea5dc80ea98"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daaeeea76da17fa0bbe7fb05cba8ed8064bb1a0edf8360636557f8b6511961"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d214e063bffa13e3b771520b74f674b22d309b5720d4df9918ff3e0c0f037720"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ed443a2062460f44c0346cb9d269b586496b808c2419bbd6057f54061c9b9c75"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5b0c9b227ee0076fb2d58301c505bb837a290ae99ee628beacdb719f0626d749"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:82c9722b7dfaa71e8b61f8c89fed0482567fb69178e139fe4151fc71ed7df782"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c18897c95c0a288347e29537b63608a8f63a5c3cb6da258ac46fcf89155e723e"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-win32.whl", hash = "sha256:3e910cf08944da381159587709daaad9e59d8ff7bca1f788d15928f3c3d49c2a"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:59c4a61fab676d37329fc3a671618a461bfeef53a4d0b8b12e3bc24a14e166f8"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-win_arm64.whl", hash = "sha256:8b4afea244102332973377fddbe54ce844d0916e1c67a5123432291717f32ffa"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:70591b28b218fff351b88cdd7f2359a01a71f9f7f5a2e465ce3715ed4b3c422b"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee2d8355c7343c631a03e57540ea06e8717c19ecf5ff64ea07e0498f7f161457"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:708fb675de0f47b9635d1cc6fbbf80d52cb710d0a1abbfae5c84c46e3abbddc3"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d66c247c2d3bb7a9b60567c395a15a929d0ebcc5f4ceedb55bfa202c38c6e0c"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:15146301b32e6e3d2b7e8146db1a26747919d8b13690c7f83a4cb5dc111b3a08"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7a03da59b6c7c97e657dd5cd4bcaab5fe4a2affd8193958d6f4d938bee36679"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d2c2fe19e392dbc22695b6c3b2510527e2b774647e79936bbde49db7742d6f1"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:91aaee4c94cb45930684f583ffc4e7c01a52b46610971cede33586cf8a04a12e"}, - {file = 
"rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3f5702828c10768f9281180a7ff8597da1e5002803e1304e9519dd0f06d79a85"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ccd1763b608fb4629a0b08f00b3c099d6395e67c14e619f6341b2c8429c2f310"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc7a0d4b2cb166bc46d02c8c9f7551cde8e2f3c9789df3827309433ee9771163"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7496f53d40560a58964207b52586783633f371683834a8f719d6d965d223a2eb"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-win32.whl", hash = "sha256:5eb1a9272ca71bc72be5415c2fa8448a6302ea4578e181bb7da9db855b367df0"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-win_amd64.whl", hash = "sha256:0d21fc3c0ca507a1180152a6dbd129ebaef48facde3f943db5c1055b6e6be56a"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-win_arm64.whl", hash = "sha256:43bb27a57c29dc5fa754496ba6a1a508480d21ae99ac0d19597646c16407e9f3"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:83a5ac6547a9d6eedaa212975cb8f2ce2aa07e6e30833b40e54a52b9f9999aa4"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:10f06139142ecde67078ebc9a745965446132b998f9feebffd71acdf218acfcc"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74720c3f24597f76c7c3e2c4abdff55f1664f4766ff5b28aeaa689f8ffba5fab"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce2bce52b5c150878e558a0418c2b637fb3dbb6eb38e4eb27d24aa839920483e"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1611199f178793ca9a060c99b284e11f6d7d124998191f1cace9a0245334d219"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0308b2ad161daf502908a6e21a57c78ded0258eba9a8f5e2545e2dafca312507"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eda91832201b86e3b70835f91522587725bec329ec68f2f7faf5124091e5ca7"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ece873c093aedd87fc07c2a7e333d52e458dc177016afa1edaf157e82b6914d8"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d97d3c9d209d5c30172baea5966f2129e8a198fec4a1aeb2f92abb6e82a2edb1"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:6c4550d0db4931f5ebe9f0678916d1b06f06f5a99ba0b8a48b9457fd8959a7d4"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b6b8dd4af6324fc325d9483bec75ecf9be33e590928c9202d408e4eafff6a0a6"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:16122ae448bc89e2bea9d81ce6cb0f751e4e07da39bd1e70b95cae2493857853"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-win32.whl", hash = "sha256:71cc168c305a4445109cd0d4925406f6e66bcb48fde99a1835387c58af4ecfe9"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-win_amd64.whl", hash = "sha256:59ee78f2ecd53fef8454909cda7400fe2cfcd820f62b8a5d4dfe930102268054"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-win_arm64.whl", hash = "sha256:58b4ce83f223605c358ae37e7a2d19a41b96aa65b1fede99cc664c9053af89ac"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9f469dbc9c4aeaac7dd005992af74b7dff94aa56a3ea063ce64e4b3e6736dd2f"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:a9ed7ad9adb68d0fe63a156fe752bbf5f1403ed66961551e749641af2874da92"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ffe48ffbeedf78d120ddfb9d583f2ca906712159a4e9c3c743c9f33e7b1775"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8502ccdea9084d54b6f737d96a3b60a84e3afed9d016686dc979b49cdac71613"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a4bec4956e06b170ca896ba055d08d4c457dac745548172443982956a80e118"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c0488b1c273be39e109ff885ccac0448b2fa74dea4c4dc676bcf756c15f16d6"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0542c036cb6acf24edd2c9e0411a67d7ba71e29e4d3001a082466b86fc34ff30"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0a96b52c9f26857bf009e270dcd829381e7a634f7ddd585fa29b87d4c82146d9"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:6edd3cd7c4aa8c68c716d349f531bd5011f2ca49ddade216bb4429460151559f"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:50b2fb55d7ed58c66d49c9f954acd8fc4a3f0e9fd0ff708299bd8abb68238d0e"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:32848dfe54391636b84cda1823fd23e5a6b1dbb8be0e9a1d80e4ee9903820994"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:29146cb7a1bf69c87e928b31bffa54f066cb65639d073b36e1425f98cccdebc6"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-win32.whl", hash = "sha256:aed13e5edacb0ecadcc304cc66e93e7e77ff24f059c9792ee602c0381808e10c"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:af440e36b828922256d0b4d79443bf2cbe5515fc4b0e9e96017ec789b36bb9fc"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:efa674b407424553024522159296690d99d6e6b1192cafe99ca84592faff16b4"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0b40ff76ee19b03ebf10a0a87938f86814996a822786c41c3312d251b7927849"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16a6c7997cb5927ced6f617122eb116ba514ec6b6f60f4803e7925ef55158891"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3f42504bdc8d770987fc3d99964766d42b2a03e4d5b0f891decdd256236bae0"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9462aa2be9f60b540c19a083471fdf28e7cf6434f068b631525b5e6251b35e"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1629698e68f47609a73bf9e73a6da3a4cac20bc710529215cbdf111ab603665b"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68bc7621843d8e9a7fd1b1a32729465bf94b47b6fb307d906da168413331f8d6"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c6254c50f15bc2fcc33cb93a95a81b702d9e6590f432a7f7822b8c7aba9ae288"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7e535a114fa575bc143e175e4ca386a467ec8c42909eff500f5f0f13dc84e3e0"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:d50acc0e9d67e4ba7a004a14c42d1b1e8b6ca1c515692746f4f8e7948c673167"}, - {file = 
"rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:fa742ec60bec53c5a211632cf1d31b9eb5a3c80f1371a46a23ac25a1fa2ab209"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c256fa95d29cbe5aa717db790b231a9a5b49e5983d50dc9df29d364a1db5e35b"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-win32.whl", hash = "sha256:89acbf728b764421036c173a10ada436ecca22999851cdc01d0aa904c70d362d"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:c608fcba8b14d86c04cb56b203fed31a96e8a1ebb4ce99e7b70313c5bf8cf497"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-win_arm64.whl", hash = "sha256:d41c00ded0e22e9dba88ff23ebe0dc9d2a5f21ba2f88e185ea7374461e61daa9"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a65c2f63218ea2dedd56fc56361035e189ca123bd9c9ce63a9bef6f99540d681"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:680dc78a5f889d3b89f74824b89fe357f49f88ad10d2c121e9c3ad37bac1e4eb"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8ca862927a0b05bd825e46ddf82d0724ea44b07d898ef639386530bf9b40f15"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2116fa1fbff21fa52cd46f3cfcb1e193ba1d65d81f8b6e123193451cd3d6c15e"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dcb7d9afd740370a897c15da61d3d57a8d54738d7c764a99cedb5f746d6a003"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1a5bd6401bb489e14cbb5981c378d53ede850b7cc84b2464cad606149cc4e17d"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:29fda70b9d03e29df6fc45cc27cbcc235534b1b0b2900e0a3ae0b43022aaeef5"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:88144f5f52ae977df9352029488326afadd7a7f42c6779d486d1f82d43b2b1f2"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:715aeaabafba2709b9dd91acb2a44bad59d60b4616ef90c08f4d4402a3bbca60"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af26ebd3714224fbf9bebbc27bdbac14f334c15f5d7043699cd694635050d6ca"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101bd2df438861a005ed47c032631b7857dfcdb17b82beeeb410307983aac61d"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:2185e8e29809b97ad22a7f99281d1669a89bdf5fa1ef4ef1feca36924e675367"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9e53c72d08f0e9c6e4a369e52df5971f311305b4487690c62e8dd0846770260c"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a0cb157162f0cdd62e538c7bd298ff669847fc43a96422811d5ab933f4c16c3a"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bb5ff2bd48132ed5e7fbb8f619885facb2e023759f2519a448b2c18afe07e5d"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6dc37f601865e8407e3a8037ffbc3afe0b0f837b2146f7632bd29d087385babe"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a657eee4b94668faf1fa2703bdd803654303f7e468eb9ba10a664d867ed9e779"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-win_amd64.whl", hash = 
"sha256:51be6ab5b1d5bb32abd39718f2a5e3835502e026a8272d139ead295c224a6f5e"}, - {file = "rapidfuzz-3.9.6.tar.gz", hash = "sha256:5cf2a7d621e4515fee84722e93563bf77ff2cbe832a77a48b81f88f9e23b9e8d"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ccf68e30b80e903f2309f90a438dbd640dd98e878eeb5ad361a288051ee5b75c"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:696a79018ef989bf1c9abd9005841cee18005ccad4748bad8a4c274c47b6241a"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4eebf6c93af0ae866c22b403a84747580bb5c10f0d7b51c82a87f25405d4dcb"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e9125377fa3d21a8abd4fbdbcf1c27be73e8b1850f0b61b5b711364bf3b59db"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c12d180b17a22d107c8747de9c68d0b9c1d15dcda5445ff9bf9f4ccfb67c3e16"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1318d42610c26dcd68bd3279a1bf9e3605377260867c9a8ed22eafc1bd93a7c"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd5fa6e3c6e0333051c1f3a49f0807b3366f4131c8d6ac8c3e05fd0d0ce3755c"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fcf79b686962d7bec458a0babc904cb4fa319808805e036b9d5a531ee6b9b835"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8b01153c7466d0bad48fba77a303d5a768e66f24b763853469f47220b3de4661"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:94baaeea0b4f8632a6da69348b1e741043eba18d4e3088d674d3f76586b6223d"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6c5b32875646cb7f60c193ade99b2e4b124f19583492115293cd00f6fb198b17"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:110b6294396bc0a447648627479c9320f095c2034c0537f687592e0f58622638"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-win32.whl", hash = "sha256:3445a35c4c8d288f2b2011eb61bce1227c633ce85a3154e727170f37c0266bb2"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:0d1415a732ee75e74a90af12020b77a0b396b36c60afae1bde3208a78cd2c9fc"}, + {file = "rapidfuzz-3.9.7-cp310-cp310-win_arm64.whl", hash = "sha256:836f4d88b8bd0fff2ebe815dcaab8aa6c8d07d1d566a7e21dd137cf6fe11ed5b"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d098ce6162eb5e48fceb0745455bc950af059df6113eec83e916c129fca11408"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:048d55d36c02c6685a2b2741688503c3d15149694506655b6169dcfd3b6c2585"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c33211cfff9aec425bb1bfedaf94afcf337063aa273754f22779d6dadebef4c2"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6d9db2fa4e9be171e9bb31cf2d2575574774966b43f5b951062bb2e67885852"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d4e049d5ad61448c9a020d1061eba20944c4887d720c4069724beb6ea1692507"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cfa74aac64c85898b93d9c80bb935a96bf64985e28d4ee0f1a3d1f3bf11a5106"}, + {file = 
"rapidfuzz-3.9.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:965693c2e9efd425b0f059f5be50ef830129f82892fa1858e220e424d9d0160f"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8501000a5eb8037c4b56857724797fe5a8b01853c363de91c8d0d0ad56bef319"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8d92c552c6b7577402afdd547dcf5d31ea6c8ae31ad03f78226e055cfa37f3c6"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1ee2086f490cb501d86b7e386c1eb4e3a0ccbb0c99067089efaa8c79012c8952"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:1de91e7fd7f525e10ea79a6e62c559d1b0278ec097ad83d9da378b6fab65a265"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a4da514d13f4433e16960a17f05b67e0af30ac771719c9a9fb877e5004f74477"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-win32.whl", hash = "sha256:a40184c67db8252593ec518e17fb8a6e86d7259dc9f2d6c0bf4ff4db8cf1ad4b"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:c4f28f1930b09a2c300357d8465b388cecb7e8b2f454a5d5425561710b7fd07f"}, + {file = "rapidfuzz-3.9.7-cp311-cp311-win_arm64.whl", hash = "sha256:675b75412a943bb83f1f53e2e54fd18c80ef15ed642dc6eb0382d1949419d904"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1ef6a1a8f0b12f8722f595f15c62950c9a02d5abc64742561299ffd49f6c6944"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32532af1d70c6ec02ea5ac7ee2766dfff7c8ae8c761abfe8da9e527314e634e8"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae1a38bade755aa9dd95a81cda949e1bf9cd92b79341ccc5e2189c9e7bdfc5ec"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d73ee2df41224c87336448d279b5b6a3a75f36e41dd3dcf538c0c9cce36360d8"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be3a1fc3e2ab3bdf93dc0c83c00acca8afd2a80602297d96cf4a0ba028333cdf"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:603f48f621272a448ff58bb556feb4371252a02156593303391f5c3281dfaeac"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:268f8e1ca50fc61c0736f3fe9d47891424adf62d96ed30196f30f4bd8216b41f"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f8bf3f0d02935751d8660abda6044821a861f6229f7d359f98bcdcc7e66c39b"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b997ff3b39d4cee9fb025d6c46b0a24bd67595ce5a5b652a97fb3a9d60beb651"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca66676c8ef6557f9b81c5b2b519097817a7c776a6599b8d6fcc3e16edd216fe"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:35d3044cb635ca6b1b2b7b67b3597bd19f34f1753b129eb6d2ae04cf98cd3945"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5a93c9e60904cb76e7aefef67afffb8b37c4894f81415ed513db090f29d01101"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-win32.whl", hash = "sha256:579d107102c0725f7c79b4e79f16d3cf4d7c9208f29c66b064fa1fd4641d5155"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-win_amd64.whl", hash = "sha256:953b3780765c8846866faf891ee4290f6a41a6dacf4fbcd3926f78c9de412ca6"}, + {file = "rapidfuzz-3.9.7-cp312-cp312-win_arm64.whl", hash = 
"sha256:7c20c1474b068c4bd45bf2fd0ad548df284f74e9a14a68b06746c56e3aa8eb70"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fde81b1da9a947f931711febe2e2bee694e891f6d3e6aa6bc02c1884702aea19"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:47e92c155a14f44511ea8ebcc6bc1535a1fe8d0a7d67ad3cc47ba61606df7bcf"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8772b745668260c5c4d069c678bbaa68812e6c69830f3771eaad521af7bc17f8"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:578302828dd97ee2ba507d2f71d62164e28d2fc7bc73aad0d2d1d2afc021a5d5"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc3e6081069eea61593f1d6839029da53d00c8c9b205c5534853eaa3f031085c"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b1c2d504eddf97bc0f2eba422c8915576dbf025062ceaca2d68aecd66324ad9"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fb76e5a21034f0307c51c5a2fc08856f698c53a4c593b17d291f7d6e9d09ca3"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d4ba2318ef670ce505f42881a5d2af70f948124646947341a3c6ccb33cd70369"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:057bb03f39e285047d7e9412e01ecf31bb2d42b9466a5409d715d587460dd59b"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a8feac9006d5c9758438906f093befffc4290de75663dbb2098461df7c7d28dd"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:95b8292383e717e10455f2c917df45032b611141e43d1adf70f71b1566136b11"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e9fbf659537d246086d0297628b3795dc3e4a384101ecc01e5791c827b8d7345"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-win32.whl", hash = "sha256:1dc516ac6d32027be2b0196bedf6d977ac26debd09ca182376322ad620460feb"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-win_amd64.whl", hash = "sha256:b4f86e09d3064dca0b014cd48688964036a904a2d28048f00c8f4640796d06a8"}, + {file = "rapidfuzz-3.9.7-cp313-cp313-win_arm64.whl", hash = "sha256:19c64d8ddb2940b42a4567b23f1681af77f50a5ff6c9b8e85daba079c210716e"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fbda3dd68d8b28ccb20ffb6f756fefd9b5ba570a772bedd7643ed441f5793308"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2379e0b2578ad3ac7004f223251550f08bca873ff76c169b09410ec562ad78d8"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d1eff95362f993b0276fd3839aee48625b09aac8938bb0c23b40d219cba5dc5"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd9360e30041690912525a210e48a897b49b230768cc8af1c702e5395690464f"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a93cd834b3c315ab437f0565ee3a2f42dd33768dc885ccbabf9710b131cf70d2"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff196996240db7075f62c7bc4506f40a3c80cd4ae3ab0e79ac6892283a90859"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948dcee7aaa1cd14358b2a7ef08bf0be42bf89049c3a906669874a715fc2c937"}, + {file = 
"rapidfuzz-3.9.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d95751f505a301af1aaf086c19f34536056d6c8efa91b2240de532a3db57b543"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:90db86fa196eecf96cb6db09f1083912ea945c50c57188039392d810d0b784e1"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:3171653212218a162540a3c8eb8ae7d3dcc8548540b69eaecaf3b47c14d89c90"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:36dd6e820379c37a1ffefc8a52b648758e867cd9d78ee5b5dc0c9a6a10145378"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:7b702de95666a1f7d5c6b47eacadfe2d2794af3742d63d2134767d13e5d1c713"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-win32.whl", hash = "sha256:9030e7238c0df51aed5c9c5ed8eee2bdd47a2ae788e562c1454af2851c3d1906"}, + {file = "rapidfuzz-3.9.7-cp38-cp38-win_amd64.whl", hash = "sha256:f847fb0fbfb72482b1c05c59cbb275c58a55b73708a7f77a83f8035ee3c86497"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:97f2ce529d2a70a60c290f6ab269a2bbf1d3b47b9724dccc84339b85f7afb044"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e2957fdad10bb83b1982b02deb3604a3f6911a5e545f518b59c741086f92d152"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d5262383634626eb45c536017204b8163a03bc43bda880cf1bdd7885db9a163"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:364587827d7cbd41afa0782adc2d2d19e3f07d355b0750a02a8e33ad27a9c368"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecc24af7f905f3d6efb371a01680116ffea8d64e266618fb9ad1602a9b4f7934"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9dc86aa6b29d174713c5f4caac35ffb7f232e3e649113e8d13812b35ab078228"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3dcfbe7266e74a707173a12a7b355a531f2dcfbdb32f09468e664330da14874"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b23806fbdd6b510ba9ac93bb72d503066263b0fba44b71b835be9f063a84025f"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:5551d68264c1bb6943f542da83a4dc8940ede52c5847ef158698799cc28d14f5"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:13d8675a1fa7e2b19650ca7ef9a6ec01391d4bb12ab9e0793e8eb024538b4a34"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9b6a5de507b9be6de688dae40143b656f7a93b10995fb8bd90deb555e7875c60"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:111a20a3c090cf244d9406e60500b6c34b2375ba3a5009e2b38fd806fe38e337"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-win32.whl", hash = "sha256:22589c0b8ccc6c391ce7f776c93a8c92c96ab8d34e1a19f1bd2b12a235332632"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:6f83221db5755b8f34222e40607d87f1176a8d5d4dbda4a55a0f0b67d588a69c"}, + {file = "rapidfuzz-3.9.7-cp39-cp39-win_arm64.whl", hash = "sha256:3665b92e788578c3bb334bd5b5fa7ee1a84bafd68be438e3110861d1578c63a0"}, + {file = "rapidfuzz-3.9.7-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d7df9c2194c7ec930b33c991c55dbd0c10951bd25800c0b7a7b571994ebbced5"}, + {file = "rapidfuzz-3.9.7-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = 
"sha256:68bd888eafd07b09585dcc8bc2716c5ecdb7eed62827470664d25588982b2873"}, + {file = "rapidfuzz-3.9.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1230e0f9026851a6a432beaa0ce575dda7b39fe689b576f99a0704fbb81fc9c"}, + {file = "rapidfuzz-3.9.7-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3b36e1c61b796ae1777f3e9e11fd39898b09d351c9384baf6e3b7e6191d8ced"}, + {file = "rapidfuzz-3.9.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dba13d86806fcf3fe9c9919f58575e0090eadfb89c058bde02bcc7ab24e4548"}, + {file = "rapidfuzz-3.9.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1f1a33e84056b7892c721d84475d3bde49a145126bc4c6efe0d6d0d59cb31c29"}, + {file = "rapidfuzz-3.9.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3492c7a42b7fa9f0051d7fcce9893e95ed91c97c9ec7fb64346f3e070dd318ed"}, + {file = "rapidfuzz-3.9.7-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:ece45eb2af8b00f90d10f7419322e8804bd42fb1129026f9bfe712c37508b514"}, + {file = "rapidfuzz-3.9.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcd14cf4876f04b488f6e54a7abd3e9b31db5f5a6aba0ce90659917aaa8c088"}, + {file = "rapidfuzz-3.9.7-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:521c58c72ed8a612b25cda378ff10dee17e6deb4ee99a070b723519a345527b9"}, + {file = "rapidfuzz-3.9.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18669bb6cdf7d40738526d37e550df09ba065b5a7560f3d802287988b6cb63cf"}, + {file = "rapidfuzz-3.9.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7abe2dbae81120a64bb4f8d3fcafe9122f328c9f86d7f327f174187a5af4ed86"}, + {file = "rapidfuzz-3.9.7-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a3c0783910911f4f24655826d007c9f4360f08107410952c01ee3df98c713eb2"}, + {file = "rapidfuzz-3.9.7-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:03126f9a040ff21d2a110610bfd6b93b79377ce8b4121edcb791d61b7df6eec5"}, + {file = "rapidfuzz-3.9.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:591908240f4085e2ade5b685c6e8346e2ed44932cffeaac2fb32ddac95b55c7f"}, + {file = "rapidfuzz-3.9.7-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9012d86c6397edbc9da4ac0132de7f8ee9d6ce857f4194d5684c4ddbcdd1c5c"}, + {file = "rapidfuzz-3.9.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df596ddd3db38aa513d4c0995611267b3946e7cbe5a8761b50e9306dfec720ee"}, + {file = "rapidfuzz-3.9.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3ed5adb752f4308fcc8f4fb6f8eb7aa4082f9d12676fda0a74fa5564242a8107"}, + {file = "rapidfuzz-3.9.7.tar.gz", hash = "sha256:f1c7296534c1afb6f495aa95871f14ccdc197c6db42965854e483100df313030"}, ] [package.extras] @@ -7629,13 +7831,13 @@ requests = "2.31.0" [[package]] name = "rich" -version = "13.7.1" +version = "13.8.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, - {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, + {file = "rich-13.8.0-py3-none-any.whl", hash = "sha256:2e85306a063b9492dffc86278197a60cbece75bcb766022f3436f567cae11bdc"}, + {file = "rich-13.8.0.tar.gz", hash = 
"sha256:a5ac1f1cd448ade0d59cc3356f7db7a7ccda2c8cbae9c7a90c28ff463d3e91f4"}, ] [package.dependencies] @@ -7773,29 +7975,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.6.1" +version = "0.6.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.1-py3-none-linux_armv6l.whl", hash = "sha256:b4bb7de6a24169dc023f992718a9417380301b0c2da0fe85919f47264fb8add9"}, - {file = "ruff-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:45efaae53b360c81043e311cdec8a7696420b3d3e8935202c2846e7a97d4edae"}, - {file = "ruff-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bc60c7d71b732c8fa73cf995efc0c836a2fd8b9810e115be8babb24ae87e0850"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c7477c3b9da822e2db0b4e0b59e61b8a23e87886e727b327e7dcaf06213c5cf"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a0af7ab3f86e3dc9f157a928e08e26c4b40707d0612b01cd577cc84b8905cc9"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:392688dbb50fecf1bf7126731c90c11a9df1c3a4cdc3f481b53e851da5634fa5"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5278d3e095ccc8c30430bcc9bc550f778790acc211865520f3041910a28d0024"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe6d5f65d6f276ee7a0fc50a0cecaccb362d30ef98a110f99cac1c7872df2f18"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e0dd11e2ae553ee5c92a81731d88a9883af8db7408db47fc81887c1f8b672e"}, - {file = "ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d812615525a34ecfc07fd93f906ef5b93656be01dfae9a819e31caa6cfe758a1"}, - {file = "ruff-0.6.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faaa4060f4064c3b7aaaa27328080c932fa142786f8142aff095b42b6a2eb631"}, - {file = "ruff-0.6.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:99d7ae0df47c62729d58765c593ea54c2546d5de213f2af2a19442d50a10cec9"}, - {file = "ruff-0.6.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9eb18dfd7b613eec000e3738b3f0e4398bf0153cb80bfa3e351b3c1c2f6d7b15"}, - {file = "ruff-0.6.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c62bc04c6723a81e25e71715aa59489f15034d69bf641df88cb38bdc32fd1dbb"}, - {file = "ruff-0.6.1-py3-none-win32.whl", hash = "sha256:9fb4c4e8b83f19c9477a8745e56d2eeef07a7ff50b68a6998f7d9e2e3887bdc4"}, - {file = "ruff-0.6.1-py3-none-win_amd64.whl", hash = "sha256:c2ebfc8f51ef4aca05dad4552bbcf6fe8d1f75b2f6af546cc47cc1c1ca916b5b"}, - {file = "ruff-0.6.1-py3-none-win_arm64.whl", hash = "sha256:3bc81074971b0ffad1bd0c52284b22411f02a11a012082a76ac6da153536e014"}, - {file = "ruff-0.6.1.tar.gz", hash = "sha256:af3ffd8c6563acb8848d33cd19a69b9bfe943667f0419ca083f8ebe4224a3436"}, + {file = "ruff-0.6.4-py3-none-linux_armv6l.whl", hash = "sha256:c4b153fc152af51855458e79e835fb6b933032921756cec9af7d0ba2aa01a258"}, + {file = "ruff-0.6.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bedff9e4f004dad5f7f76a9d39c4ca98af526c9b1695068198b3bda8c085ef60"}, + {file = "ruff-0.6.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d02a4127a86de23002e694d7ff19f905c51e338c72d8e09b56bfb60e1681724f"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7862f42fc1a4aca1ea3ffe8a11f67819d183a5693b228f0bb3a531f5e40336fc"}, + {file = 
"ruff-0.6.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eebe4ff1967c838a1a9618a5a59a3b0a00406f8d7eefee97c70411fefc353617"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:932063a03bac394866683e15710c25b8690ccdca1cf192b9a98260332ca93408"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:50e30b437cebef547bd5c3edf9ce81343e5dd7c737cb36ccb4fe83573f3d392e"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c44536df7b93a587de690e124b89bd47306fddd59398a0fb12afd6133c7b3818"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ea086601b22dc5e7693a78f3fcfc460cceabfdf3bdc36dc898792aba48fbad6"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b52387d3289ccd227b62102c24714ed75fbba0b16ecc69a923a37e3b5e0aaaa"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0308610470fcc82969082fc83c76c0d362f562e2f0cdab0586516f03a4e06ec6"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:803b96dea21795a6c9d5bfa9e96127cc9c31a1987802ca68f35e5c95aed3fc0d"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:66dbfea86b663baab8fcae56c59f190caba9398df1488164e2df53e216248baa"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:34d5efad480193c046c86608dbba2bccdc1c5fd11950fb271f8086e0c763a5d1"}, + {file = "ruff-0.6.4-py3-none-win32.whl", hash = "sha256:f0f8968feea5ce3777c0d8365653d5e91c40c31a81d95824ba61d871a11b8523"}, + {file = "ruff-0.6.4-py3-none-win_amd64.whl", hash = "sha256:549daccee5227282289390b0222d0fbee0275d1db6d514550d65420053021a58"}, + {file = "ruff-0.6.4-py3-none-win_arm64.whl", hash = "sha256:ac4b75e898ed189b3708c9ab3fc70b79a433219e1e87193b4f2b77251d058d14"}, + {file = "ruff-0.6.4.tar.gz", hash = "sha256:ac3b5bfbee99973f80aa1b7cbd1c9cbce200883bdd067300c22a6cc1c7fba212"}, ] [[package]] @@ -7817,121 +8019,121 @@ crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] [[package]] name = "safetensors" -version = "0.4.4" +version = "0.4.5" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.4.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2adb497ada13097f30e386e88c959c0fda855a5f6f98845710f5bb2c57e14f12"}, - {file = "safetensors-0.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7db7fdc2d71fd1444d85ca3f3d682ba2df7d61a637dfc6d80793f439eae264ab"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d4f0eed76b430f009fbefca1a0028ddb112891b03cb556d7440d5cd68eb89a9"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:57d216fab0b5c432aabf7170883d7c11671622bde8bd1436c46d633163a703f6"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d9b76322e49c056bcc819f8bdca37a2daa5a6d42c07f30927b501088db03309"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32f0d1f6243e90ee43bc6ee3e8c30ac5b09ca63f5dd35dbc985a1fc5208c451a"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d464bdc384874601a177375028012a5f177f1505279f9456fea84bbc575c7f"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:63144e36209ad8e4e65384dbf2d52dd5b1866986079c00a72335402a38aacdc5"}, - {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:051d5ecd490af7245258000304b812825974d5e56f14a3ff7e1b8b2ba6dc2ed4"}, - {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:51bc8429d9376224cd3cf7e8ce4f208b4c930cd10e515b6ac6a72cbc3370f0d9"}, - {file = "safetensors-0.4.4-cp310-none-win32.whl", hash = "sha256:fb7b54830cee8cf9923d969e2df87ce20e625b1af2fd194222ab902d3adcc29c"}, - {file = "safetensors-0.4.4-cp310-none-win_amd64.whl", hash = "sha256:4b3e8aa8226d6560de8c2b9d5ff8555ea482599c670610758afdc97f3e021e9c"}, - {file = "safetensors-0.4.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bbaa31f2cb49013818bde319232ccd72da62ee40f7d2aa532083eda5664e85ff"}, - {file = "safetensors-0.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9fdcb80f4e9fbb33b58e9bf95e7dbbedff505d1bcd1c05f7c7ce883632710006"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55c14c20be247b8a1aeaf3ab4476265e3ca83096bb8e09bb1a7aa806088def4f"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:949aaa1118660f992dbf0968487b3e3cfdad67f948658ab08c6b5762e90cc8b6"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c11a4ab7debc456326a2bac67f35ee0ac792bcf812c7562a4a28559a5c795e27"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0cea44bba5c5601b297bc8307e4075535b95163402e4906b2e9b82788a2a6df"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9d752c97f6bbe327352f76e5b86442d776abc789249fc5e72eacb49e6916482"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03f2bb92e61b055ef6cc22883ad1ae898010a95730fa988c60a23800eb742c2c"}, - {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf3f91a9328a941acc44eceffd4e1f5f89b030985b2966637e582157173b98"}, - {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:20d218ec2b6899d29d6895419a58b6e44cc5ff8f0cc29fac8d236a8978ab702e"}, - {file = "safetensors-0.4.4-cp311-none-win32.whl", hash = "sha256:8079486118919f600c603536e2490ca37b3dbd3280e3ad6eaacfe6264605ac8a"}, - {file = "safetensors-0.4.4-cp311-none-win_amd64.whl", hash = "sha256:2f8c2eb0615e2e64ee27d478c7c13f51e5329d7972d9e15528d3e4cfc4a08f0d"}, - {file = "safetensors-0.4.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baec5675944b4a47749c93c01c73d826ef7d42d36ba8d0dba36336fa80c76426"}, - {file = "safetensors-0.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f15117b96866401825f3e94543145028a2947d19974429246ce59403f49e77c6"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a13a9caea485df164c51be4eb0c87f97f790b7c3213d635eba2314d959fe929"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b54bc4ca5f9b9bba8cd4fb91c24b2446a86b5ae7f8975cf3b7a277353c3127c"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08332c22e03b651c8eb7bf5fc2de90044f3672f43403b3d9ac7e7e0f4f76495e"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bb62841e839ee992c37bb75e75891c7f4904e772db3691c59daaca5b4ab960e1"}, - 
{file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e5b927acc5f2f59547270b0309a46d983edc44be64e1ca27a7fcb0474d6cd67"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a69c71b1ae98a8021a09a0b43363b0143b0ce74e7c0e83cacba691b62655fb8"}, - {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23654ad162c02a5636f0cd520a0310902c4421aab1d91a0b667722a4937cc445"}, - {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0677c109d949cf53756859160b955b2e75b0eefe952189c184d7be30ecf7e858"}, - {file = "safetensors-0.4.4-cp312-none-win32.whl", hash = "sha256:a51d0ddd4deb8871c6de15a772ef40b3dbd26a3c0451bb9e66bc76fc5a784e5b"}, - {file = "safetensors-0.4.4-cp312-none-win_amd64.whl", hash = "sha256:2d065059e75a798bc1933c293b68d04d79b586bb7f8c921e0ca1e82759d0dbb1"}, - {file = "safetensors-0.4.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9d625692578dd40a112df30c02a1adf068027566abd8e6a74893bb13d441c150"}, - {file = "safetensors-0.4.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7cabcf39c81e5b988d0adefdaea2eb9b4fd9bd62d5ed6559988c62f36bfa9a89"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8359bef65f49d51476e9811d59c015f0ddae618ee0e44144f5595278c9f8268c"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a32c662e7df9226fd850f054a3ead0e4213a96a70b5ce37b2d26ba27004e013"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c329a4dcc395364a1c0d2d1574d725fe81a840783dda64c31c5a60fc7d41472c"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:239ee093b1db877c9f8fe2d71331a97f3b9c7c0d3ab9f09c4851004a11f44b65"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd574145d930cf9405a64f9923600879a5ce51d9f315443a5f706374841327b6"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f6784eed29f9e036acb0b7769d9e78a0dc2c72c2d8ba7903005350d817e287a4"}, - {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:65a4a6072436bf0a4825b1c295d248cc17e5f4651e60ee62427a5bcaa8622a7a"}, - {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:df81e3407630de060ae8313da49509c3caa33b1a9415562284eaf3d0c7705f9f"}, - {file = "safetensors-0.4.4-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:e4a0f374200e8443d9746e947ebb346c40f83a3970e75a685ade0adbba5c48d9"}, - {file = "safetensors-0.4.4-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:181fb5f3dee78dae7fd7ec57d02e58f7936498d587c6b7c1c8049ef448c8d285"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb4ac1d8f6b65ec84ddfacd275079e89d9df7c92f95675ba96c4f790a64df6e"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:76897944cd9239e8a70955679b531b9a0619f76e25476e57ed373322d9c2075d"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a9e9d1a27e51a0f69e761a3d581c3af46729ec1c988fa1f839e04743026ae35"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:005ef9fc0f47cb9821c40793eb029f712e97278dae84de91cb2b4809b856685d"}, - {file = 
"safetensors-0.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26987dac3752688c696c77c3576f951dbbdb8c57f0957a41fb6f933cf84c0b62"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c05270b290acd8d249739f40d272a64dd597d5a4b90f27d830e538bc2549303c"}, - {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:068d3a33711fc4d93659c825a04480ff5a3854e1d78632cdc8f37fee917e8a60"}, - {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:063421ef08ca1021feea8b46951251b90ae91f899234dd78297cbe7c1db73b99"}, - {file = "safetensors-0.4.4-cp37-none-win32.whl", hash = "sha256:d52f5d0615ea83fd853d4e1d8acf93cc2e0223ad4568ba1e1f6ca72e94ea7b9d"}, - {file = "safetensors-0.4.4-cp37-none-win_amd64.whl", hash = "sha256:88a5ac3280232d4ed8e994cbc03b46a1807ce0aa123867b40c4a41f226c61f94"}, - {file = "safetensors-0.4.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:3467ab511bfe3360967d7dc53b49f272d59309e57a067dd2405b4d35e7dcf9dc"}, - {file = "safetensors-0.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2ab4c96d922e53670ce25fbb9b63d5ea972e244de4fa1dd97b590d9fd66aacef"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87df18fce4440477c3ef1fd7ae17c704a69a74a77e705a12be135ee0651a0c2d"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e5fe345b2bc7d88587149ac11def1f629d2671c4c34f5df38aed0ba59dc37f8"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9f1a3e01dce3cd54060791e7e24588417c98b941baa5974700eeb0b8eb65b0a0"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c6bf35e9a8998d8339fd9a05ac4ce465a4d2a2956cc0d837b67c4642ed9e947"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:166c0c52f6488b8538b2a9f3fbc6aad61a7261e170698779b371e81b45f0440d"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87e9903b8668a16ef02c08ba4ebc91e57a49c481e9b5866e31d798632805014b"}, - {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a9c421153aa23c323bd8483d4155b4eee82c9a50ac11cccd83539104a8279c64"}, - {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a4b8617499b2371c7353302c5116a7e0a3a12da66389ce53140e607d3bf7b3d3"}, - {file = "safetensors-0.4.4-cp38-none-win32.whl", hash = "sha256:c6280f5aeafa1731f0a3709463ab33d8e0624321593951aefada5472f0b313fd"}, - {file = "safetensors-0.4.4-cp38-none-win_amd64.whl", hash = "sha256:6ceed6247fc2d33b2a7b7d25d8a0fe645b68798856e0bc7a9800c5fd945eb80f"}, - {file = "safetensors-0.4.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5cf6c6f6193797372adf50c91d0171743d16299491c75acad8650107dffa9269"}, - {file = "safetensors-0.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:419010156b914a3e5da4e4adf992bee050924d0fe423c4b329e523e2c14c3547"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88f6fd5a5c1302ce79993cc5feeadcc795a70f953c762544d01fb02b2db4ea33"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d468cffb82d90789696d5b4d8b6ab8843052cba58a15296691a7a3df55143cd2"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:9353c2af2dd467333d4850a16edb66855e795561cd170685178f706c80d2c71e"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:83c155b4a33368d9b9c2543e78f2452090fb030c52401ca608ef16fa58c98353"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9850754c434e636ce3dc586f534bb23bcbd78940c304775bee9005bf610e98f1"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:275f500b4d26f67b6ec05629a4600645231bd75e4ed42087a7c1801bff04f4b3"}, - {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5c2308de665b7130cd0e40a2329278226e4cf083f7400c51ca7e19ccfb3886f3"}, - {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e06a9ebc8656e030ccfe44634f2a541b4b1801cd52e390a53ad8bacbd65f8518"}, - {file = "safetensors-0.4.4-cp39-none-win32.whl", hash = "sha256:ef73df487b7c14b477016947c92708c2d929e1dee2bacdd6fff5a82ed4539537"}, - {file = "safetensors-0.4.4-cp39-none-win_amd64.whl", hash = "sha256:83d054818a8d1198d8bd8bc3ea2aac112a2c19def2bf73758321976788706398"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1d1f34c71371f0e034004a0b583284b45d233dd0b5f64a9125e16b8a01d15067"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a8043a33d58bc9b30dfac90f75712134ca34733ec3d8267b1bd682afe7194f5"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8db8f0c59c84792c12661f8efa85de160f80efe16b87a9d5de91b93f9e0bce3c"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfc1fc38e37630dd12d519bdec9dcd4b345aec9930bb9ce0ed04461f49e58b52"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5c9d86d9b13b18aafa88303e2cd21e677f5da2a14c828d2c460fe513af2e9a5"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:43251d7f29a59120a26f5a0d9583b9e112999e500afabcfdcb91606d3c5c89e3"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2c42e9b277513b81cf507e6121c7b432b3235f980cac04f39f435b7902857f91"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3daacc9a4e3f428a84dd56bf31f20b768eb0b204af891ed68e1f06db9edf546f"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:218bbb9b883596715fc9997bb42470bf9f21bb832c3b34c2bf744d6fa8f2bbba"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bd5efc26b39f7fc82d4ab1d86a7f0644c8e34f3699c33f85bfa9a717a030e1b"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56ad9776b65d8743f86698a1973292c966cf3abff627efc44ed60e66cc538ddd"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:30f23e6253c5f43a809dea02dc28a9f5fa747735dc819f10c073fe1b605e97d4"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:5512078d00263de6cb04e9d26c9ae17611098f52357fea856213e38dc462f81f"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b96c3d9266439d17f35fc2173111d93afc1162f168e95aed122c1ca517b1f8f1"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = 
"sha256:08d464aa72a9a13826946b4fb9094bb4b16554bbea2e069e20bd903289b6ced9"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:210160816d5a36cf41f48f38473b6f70d7bcb4b0527bedf0889cc0b4c3bb07db"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb276a53717f2bcfb6df0bcf284d8a12069002508d4c1ca715799226024ccd45"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a2c28c6487f17d8db0089e8b2cdc13de859366b94cc6cdc50e1b0a4147b56551"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7915f0c60e4e6e65d90f136d85dd3b429ae9191c36b380e626064694563dbd9f"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:00eea99ae422fbfa0b46065acbc58b46bfafadfcec179d4b4a32d5c45006af6c"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bb1ed4fcb0b3c2f3ea2c5767434622fe5d660e5752f21ac2e8d737b1e5e480bb"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:73fc9a0a4343188bdb421783e600bfaf81d0793cd4cce6bafb3c2ed567a74cd5"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c37e6b714200824c73ca6eaf007382de76f39466a46e97558b8dc4cf643cfbf"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75698c5c5c542417ac4956acfc420f7d4a2396adca63a015fd66641ea751759"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca1a209157f242eb183e209040097118472e169f2e069bfbd40c303e24866543"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:177f2b60a058f92a3cec7a1786c9106c29eca8987ecdfb79ee88126e5f47fa31"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ee9622e84fe6e4cd4f020e5fda70d6206feff3157731df7151d457fdae18e541"}, - {file = "safetensors-0.4.4.tar.gz", hash = "sha256:5fe3e9b705250d0172ed4e100a811543108653fb2b66b9e702a088ad03772a07"}, + {file = "safetensors-0.4.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a63eaccd22243c67e4f2b1c3e258b257effc4acd78f3b9d397edc8cf8f1298a7"}, + {file = "safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:23fc9b4ec7b602915cbb4ec1a7c1ad96d2743c322f20ab709e2c35d1b66dad27"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0f1dd769f064adc33831f5e97ad07babbd728427f98e3e1db6902e369122737"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6d156bdb26732feada84f9388a9f135528c1ef5b05fae153da365ad4319c4c5"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b"}, + {file = 
"safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6"}, + {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163"}, + {file = "safetensors-0.4.5-cp310-none-win32.whl", hash = "sha256:7389129c03fadd1ccc37fd1ebbc773f2b031483b04700923c3511d2a939252cc"}, + {file = "safetensors-0.4.5-cp310-none-win_amd64.whl", hash = "sha256:e98ef5524f8b6620c8cdef97220c0b6a5c1cef69852fcd2f174bb96c2bb316b1"}, + {file = "safetensors-0.4.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:21f848d7aebd5954f92538552d6d75f7c1b4500f51664078b5b49720d180e47c"}, + {file = "safetensors-0.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb07000b19d41e35eecef9a454f31a8b4718a185293f0d0b1c4b61d6e4487971"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39371fc551c1072976073ab258c3119395294cf49cdc1f8476794627de3130df"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6c19feda32b931cae0acd42748a670bdf56bee6476a046af20181ad3fee4090"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943"}, + {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0"}, + {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f"}, + {file = "safetensors-0.4.5-cp311-none-win32.whl", hash = "sha256:a01e232e6d3d5cf8b1667bc3b657a77bdab73f0743c26c1d3c5dd7ce86bd3a92"}, + {file = "safetensors-0.4.5-cp311-none-win_amd64.whl", hash = "sha256:cbd39cae1ad3e3ef6f63a6f07296b080c951f24cec60188378e43d3713000c04"}, + {file = "safetensors-0.4.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:473300314e026bd1043cef391bb16a8689453363381561b8a3e443870937cc1e"}, + {file = "safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:801183a0f76dc647f51a2d9141ad341f9665602a7899a693207a82fb102cc53e"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed"}, + {file = 
"safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3685ce7ed036f916316b567152482b7e959dc754fcc4a8342333d222e05f407c"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1"}, + {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4"}, + {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646"}, + {file = "safetensors-0.4.5-cp312-none-win32.whl", hash = "sha256:c859c7ed90b0047f58ee27751c8e56951452ed36a67afee1b0a87847d065eec6"}, + {file = "safetensors-0.4.5-cp312-none-win_amd64.whl", hash = "sha256:b5a8810ad6a6f933fff6c276eae92c1da217b39b4d8b1bc1c0b8af2d270dc532"}, + {file = "safetensors-0.4.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:25e5f8e2e92a74f05b4ca55686234c32aac19927903792b30ee6d7bd5653d54e"}, + {file = "safetensors-0.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:81efb124b58af39fcd684254c645e35692fea81c51627259cdf6d67ff4458916"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76ded72f69209c9780fdb23ea89e56d35c54ae6abcdec67ccb22af8e696e449a"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2783956926303dcfeb1de91a4d1204cd4089ab441e622e7caee0642281109db3"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35"}, + {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523"}, + {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142"}, + {file = "safetensors-0.4.5-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:77d9b228da8374c7262046a36c1f656ba32a93df6cc51cd4453af932011e77f1"}, + {file = "safetensors-0.4.5-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:500cac01d50b301ab7bb192353317035011c5ceeef0fca652f9f43c000bb7f8d"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75331c0c746f03158ded32465b7d0b0e24c5a22121743662a2393439c43a45cf"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670e95fe34e0d591d0529e5e59fd9d3d72bc77b1444fcaa14dccda4f36b5a38b"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:098923e2574ff237c517d6e840acada8e5b311cb1fa226019105ed82e9c3b62f"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13ca0902d2648775089fa6a0c8fc9e6390c5f8ee576517d33f9261656f851e3f"}, + {file = 
"safetensors-0.4.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f0032bedc869c56f8d26259fe39cd21c5199cd57f2228d817a0e23e8370af25"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4b15f51b4f8f2a512341d9ce3475cacc19c5fdfc5db1f0e19449e75f95c7dc8"}, + {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f6594d130d0ad933d885c6a7b75c5183cb0e8450f799b80a39eae2b8508955eb"}, + {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:60c828a27e852ded2c85fc0f87bf1ec20e464c5cd4d56ff0e0711855cc2e17f8"}, + {file = "safetensors-0.4.5-cp37-none-win32.whl", hash = "sha256:6d3de65718b86c3eeaa8b73a9c3d123f9307a96bbd7be9698e21e76a56443af5"}, + {file = "safetensors-0.4.5-cp37-none-win_amd64.whl", hash = "sha256:5a2d68a523a4cefd791156a4174189a4114cf0bf9c50ceb89f261600f3b2b81a"}, + {file = "safetensors-0.4.5-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:e7a97058f96340850da0601a3309f3d29d6191b0702b2da201e54c6e3e44ccf0"}, + {file = "safetensors-0.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:63bfd425e25f5c733f572e2246e08a1c38bd6f2e027d3f7c87e2e43f228d1345"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3664ac565d0e809b0b929dae7ccd74e4d3273cd0c6d1220c6430035befb678e"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:313514b0b9b73ff4ddfb4edd71860696dbe3c1c9dc4d5cc13dbd74da283d2cbf"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31fa33ee326f750a2f2134a6174773c281d9a266ccd000bd4686d8021f1f3dac"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09566792588d77b68abe53754c9f1308fadd35c9f87be939e22c623eaacbed6b"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309aaec9b66cbf07ad3a2e5cb8a03205663324fea024ba391594423d0f00d9fe"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53946c5813b8f9e26103c5efff4a931cc45d874f45229edd68557ffb35ffb9f8"}, + {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:868f9df9e99ad1e7f38c52194063a982bc88fedc7d05096f4f8160403aaf4bd6"}, + {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9cc9449bd0b0bc538bd5e268221f0c5590bc5c14c1934a6ae359d44410dc68c4"}, + {file = "safetensors-0.4.5-cp38-none-win32.whl", hash = "sha256:83c4f13a9e687335c3928f615cd63a37e3f8ef072a3f2a0599fa09f863fb06a2"}, + {file = "safetensors-0.4.5-cp38-none-win_amd64.whl", hash = "sha256:b98d40a2ffa560653f6274e15b27b3544e8e3713a44627ce268f419f35c49478"}, + {file = "safetensors-0.4.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cf727bb1281d66699bef5683b04d98c894a2803442c490a8d45cd365abfbdeb2"}, + {file = "safetensors-0.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96f1d038c827cdc552d97e71f522e1049fef0542be575421f7684756a748e457"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:139fbee92570ecea774e6344fee908907db79646d00b12c535f66bc78bd5ea2c"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c36302c1c69eebb383775a89645a32b9d266878fab619819ce660309d6176c9b"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d641f5b8149ea98deb5ffcf604d764aad1de38a8285f86771ce1abf8e74c4891"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b4db6a61d968de73722b858038c616a1bebd4a86abe2688e46ca0cc2d17558f2"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b75a616e02f21b6f1d5785b20cecbab5e2bd3f6358a90e8925b813d557666ec1"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:788ee7d04cc0e0e7f944c52ff05f52a4415b312f5efd2ee66389fb7685ee030c"}, + {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87bc42bd04fd9ca31396d3ca0433db0be1411b6b53ac5a32b7845a85d01ffc2e"}, + {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4037676c86365a721a8c9510323a51861d703b399b78a6b4486a54a65a975fca"}, + {file = "safetensors-0.4.5-cp39-none-win32.whl", hash = "sha256:1500418454529d0ed5c1564bda376c4ddff43f30fce9517d9bee7bcce5a8ef50"}, + {file = "safetensors-0.4.5-cp39-none-win_amd64.whl", hash = "sha256:9d1a94b9d793ed8fe35ab6d5cea28d540a46559bafc6aae98f30ee0867000cab"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d42ffd4c2259f31832cb17ff866c111684c87bd930892a1ba53fed28370c918c"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7db3006a4915151ce1913652e907cdede299b974641a83fbc092102ac41b644"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f68bf99ea970960a237f416ea394e266e0361895753df06e3e06e6ea7907d98b"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8158938cf3324172df024da511839d373c40fbfaa83e9abf467174b2910d7b4c"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:540ce6c4bf6b58cb0fd93fa5f143bc0ee341c93bb4f9287ccd92cf898cc1b0dd"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bfeaa1a699c6b9ed514bd15e6a91e74738b71125a9292159e3d6b7f0a53d2cde"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:01c8f00da537af711979e1b42a69a8ec9e1d7112f208e0e9b8a35d2c381085ef"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a0dd565f83b30f2ca79b5d35748d0d99dd4b3454f80e03dfb41f0038e3bdf180"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = 
"sha256:023b6e5facda76989f4cba95a861b7e656b87e225f61811065d5c501f78cdb3f"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9633b663393d5796f0b60249549371e392b75a0b955c07e9c6f8708a87fc841f"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78dd8adfb48716233c45f676d6e48534d34b4bceb50162c13d1f0bdf6f78590a"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e8deb16c4321d61ae72533b8451ec4a9af8656d1c61ff81aa49f966406e4b68"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:52452fa5999dc50c4decaf0c53aa28371f7f1e0fe5c2dd9129059fbe1e1599c7"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d5f23198821e227cfc52d50fa989813513db381255c6d100927b012f0cfec63d"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f4beb84b6073b1247a773141a6331117e35d07134b3bb0383003f39971d414bb"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:68814d599d25ed2fdd045ed54d370d1d03cf35e02dce56de44c651f828fb9b7b"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b6453c54c57c1781292c46593f8a37254b8b99004c68d6c3ce229688931a22"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adaa9c6dead67e2dd90d634f89131e43162012479d86e25618e821a03d1eb1dc"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73e7d408e9012cd17511b382b43547850969c7979efc2bc353f317abaf23c84c"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:775409ce0fcc58b10773fdb4221ed1eb007de10fe7adbdf8f5e8a56096b6f0bc"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:834001bed193e4440c4a3950a31059523ee5090605c907c66808664c932b549c"}, + {file = "safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310"}, ] [package.extras] @@ -7947,6 +8149,84 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "sagemaker" +version = "2.231.0" +description = "Open source library for training and deploying models on Amazon SageMaker." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker-2.231.0-py3-none-any.whl", hash = "sha256:5b6d84484a58c6ac8b22af42c6c5e0ea3c5f42d719345fe6aafba42f93635000"}, + {file = "sagemaker-2.231.0.tar.gz", hash = "sha256:d49ee9c35725832dd9810708938af723201b831e82924a3a6ac1c4260a3d8239"}, +] + +[package.dependencies] +attrs = ">=23.1.0,<24" +boto3 = ">=1.34.142,<2.0" +cloudpickle = "2.2.1" +docker = "*" +google-pasta = "*" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "*" +numpy = ">=1.9.0,<2.0" +packaging = ">=20.0" +pandas = "*" +pathos = "*" +platformdirs = "*" +protobuf = ">=3.12,<5.0" +psutil = "*" +pyyaml = ">=6.0,<7.0" +requests = "*" +sagemaker-core = ">=1.0.0,<2.0.0" +schema = "*" +smdebug-rulesconfig = "1.0.1" +tblib = ">=1.7.0,<4" +tqdm = "*" +urllib3 = ">=1.26.8,<3.0.0" + +[package.extras] +all = ["accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "fastapi (>=0.111.0)", "nest-asyncio", "pyspark (==3.3.1)", "pyyaml (>=5.4.1,<7)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)"] +feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3-3"] +huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "fastapi (>=0.111.0)", "nest-asyncio", "sagemaker-schema-inference-artifacts (>=0.0.5)", "uvicorn (>=0.30.1)"] +local = ["docker (>=5.0.2,<8.0.0)", "pyyaml (>=5.4.1,<7)", "urllib3 (>=1.26.8,<3.0.0)"] +scipy = ["scipy (==1.10.1)"] +test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "build[virtualenv] (==1.2.1)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "fastapi (>=0.111.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "jinja2 (==3.1.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "nest-asyncio", "numpy (>=1.24.0)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "pyyaml (==6.0)", "pyyaml (>=5.4.1,<7)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)", "xgboost (>=1.6.2,<=1.7.6)"] + +[[package]] +name = "sagemaker-core" +version = "1.0.2" +description = "An python package for sagemaker core functionalities" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker_core-1.0.2-py3-none-any.whl", hash = "sha256:ce8d38a4a32efa83e4bc037a8befc7e29f87cd3eaf99acc4472b607f75a0f45a"}, + {file = "sagemaker_core-1.0.2.tar.gz", hash = "sha256:8fb942aac5e7ed928dab512ffe6facf8c6bdd4595df63c59c0bd0795ea434f8d"}, +] + +[package.dependencies] +boto3 = ">=1.34.0,<2.0.0" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "<5.0.0" +mock = ">4.0,<5.0" +platformdirs = ">=4.0.0,<5.0.0" +pydantic = ">=1.7.0,<3.0.0" +PyYAML = ">=6.0,<7.0" +rich = ">=13.0.0,<14.0.0" + +[package.extras] +codegen = ["black (>=24.3.0,<25.0.0)", "pandas (>=2.0.0,<3.0.0)", "pylint (>=3.0.0,<4.0.0)", "pytest (>=8.0.0,<9.0.0)"] + 
+[[package]] +name = "schema" +version = "0.7.7" +description = "Simple data validation library" +optional = false +python-versions = "*" +files = [ + {file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"}, + {file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"}, +] + [[package]] name = "scikit-learn" version = "1.5.1" @@ -8094,19 +8374,23 @@ tornado = ["tornado (>=5)"] [[package]] name = "setuptools" -version = "73.0.1" +version = "74.1.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-73.0.1-py3-none-any.whl", hash = "sha256:b208925fcb9f7af924ed2dc04708ea89791e24bde0d3020b27df0e116088b34e"}, - {file = "setuptools-73.0.1.tar.gz", hash = "sha256:d59a3e788ab7e012ab2c4baed1b376da6366883ee20d7a5fc426816e3d7b1193"}, + {file = "setuptools-74.1.2-py3-none-any.whl", hash = "sha256:5f4c08aa4d3ebcb57a50c33b1b07e94315d7fc7230f7115e47fc99776c8ce308"}, + {file = "setuptools-74.1.2.tar.gz", hash = "sha256:95b40ed940a1c67eb70fc099094bd6e99c6ee7c23aa2306f4d2697ba7916f9c6"}, ] [package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] [[package]] name = "sgmllib3k" @@ -8215,6 +8499,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "smdebug-rulesconfig" +version = "1.0.1" +description = "SMDebug RulesConfig" +optional = false +python-versions = ">=2.7" +files = [ + {file = "smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl", hash = 
"sha256:104da3e6931ecf879dfc687ca4bbb3bee5ea2bc27f4478e9dbb3ee3655f1ae61"}, + {file = "smdebug_rulesconfig-1.0.1.tar.gz", hash = "sha256:7a19e6eb2e6bcfefbc07e4a86ef7a88f32495001a038bf28c7d8e77ab793fcd6"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -8250,60 +8545,60 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.32" +version = "2.0.34" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0c9045ecc2e4db59bfc97b20516dfdf8e41d910ac6fb667ebd3a79ea54084619"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1467940318e4a860afd546ef61fefb98a14d935cd6817ed07a228c7f7c62f389"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5954463675cb15db8d4b521f3566a017c8789222b8316b1e6934c811018ee08b"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167e7497035c303ae50651b351c28dc22a40bb98fbdb8468cdc971821b1ae533"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b27dfb676ac02529fb6e343b3a482303f16e6bc3a4d868b73935b8792edb52d0"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bf2360a5e0f7bd75fa80431bf8ebcfb920c9f885e7956c7efde89031695cafb8"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-win32.whl", hash = "sha256:306fe44e754a91cd9d600a6b070c1f2fadbb4a1a257b8781ccf33c7067fd3e4d"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-win_amd64.whl", hash = "sha256:99db65e6f3ab42e06c318f15c98f59a436f1c78179e6a6f40f529c8cc7100b22"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:21b053be28a8a414f2ddd401f1be8361e41032d2ef5884b2f31d31cb723e559f"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b178e875a7a25b5938b53b006598ee7645172fccafe1c291a706e93f48499ff5"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723a40ee2cc7ea653645bd4cf024326dea2076673fc9d3d33f20f6c81db83e1d"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:295ff8689544f7ee7e819529633d058bd458c1fd7f7e3eebd0f9268ebc56c2a0"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:49496b68cd190a147118af585173ee624114dfb2e0297558c460ad7495f9dfe2"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:acd9b73c5c15f0ec5ce18128b1fe9157ddd0044abc373e6ecd5ba376a7e5d961"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-win32.whl", hash = "sha256:9365a3da32dabd3e69e06b972b1ffb0c89668994c7e8e75ce21d3e5e69ddef28"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-win_amd64.whl", hash = "sha256:8bd63d051f4f313b102a2af1cbc8b80f061bf78f3d5bd0843ff70b5859e27924"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bab3db192a0c35e3c9d1560eb8332463e29e5507dbd822e29a0a3c48c0a8d92"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:19d98f4f58b13900d8dec4ed09dd09ef292208ee44cc9c2fe01c1f0a2fe440e9"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd33c61513cb1b7371fd40cf221256456d26a56284e7d19d1f0b9f1eb7dd7e8"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6ba0497c1d066dd004e0f02a92426ca2df20fac08728d03f67f6960271feec"}, 
- {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b6be53e4fde0065524f1a0a7929b10e9280987b320716c1509478b712a7688c"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:916a798f62f410c0b80b63683c8061f5ebe237b0f4ad778739304253353bc1cb"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-win32.whl", hash = "sha256:31983018b74908ebc6c996a16ad3690301a23befb643093fcfe85efd292e384d"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-win_amd64.whl", hash = "sha256:4363ed245a6231f2e2957cccdda3c776265a75851f4753c60f3004b90e69bfeb"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8afd5b26570bf41c35c0121801479958b4446751a3971fb9a480c1afd85558e"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c750987fc876813f27b60d619b987b057eb4896b81117f73bb8d9918c14f1cad"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0102afff4890f651ed91120c1120065663506b760da4e7823913ebd3258be"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:78c03d0f8a5ab4f3034c0e8482cfcc415a3ec6193491cfa1c643ed707d476f16"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:3bd1cae7519283ff525e64645ebd7a3e0283f3c038f461ecc1c7b040a0c932a1"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-win32.whl", hash = "sha256:01438ebcdc566d58c93af0171c74ec28efe6a29184b773e378a385e6215389da"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-win_amd64.whl", hash = "sha256:4979dc80fbbc9d2ef569e71e0896990bc94df2b9fdbd878290bd129b65ab579c"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c742be912f57586ac43af38b3848f7688863a403dfb220193a882ea60e1ec3a"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62e23d0ac103bcf1c5555b6c88c114089587bc64d048fef5bbdb58dfd26f96da"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:251f0d1108aab8ea7b9aadbd07fb47fb8e3a5838dde34aa95a3349876b5a1f1d"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef18a84e5116340e38eca3e7f9eeaaef62738891422e7c2a0b80feab165905f"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3eb6a97a1d39976f360b10ff208c73afb6a4de86dd2a6212ddf65c4a6a2347d5"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0c1c9b673d21477cec17ab10bc4decb1322843ba35b481585facd88203754fc5"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-win32.whl", hash = "sha256:c41a2b9ca80ee555decc605bd3c4520cc6fef9abde8fd66b1cf65126a6922d65"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-win_amd64.whl", hash = "sha256:8a37e4d265033c897892279e8adf505c8b6b4075f2b40d77afb31f7185cd6ecd"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:52fec964fba2ef46476312a03ec8c425956b05c20220a1a03703537824b5e8e1"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:328429aecaba2aee3d71e11f2477c14eec5990fb6d0e884107935f7fb6001632"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85a01b5599e790e76ac3fe3aa2f26e1feba56270023d6afd5550ed63c68552b3"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaf04784797dcdf4c0aa952c8d234fa01974c4729db55c45732520ce12dd95b4"}, - {file = 
"SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4488120becf9b71b3ac718f4138269a6be99a42fe023ec457896ba4f80749525"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:14e09e083a5796d513918a66f3d6aedbc131e39e80875afe81d98a03312889e6"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-win32.whl", hash = "sha256:0d322cc9c9b2154ba7e82f7bf25ecc7c36fbe2d82e2933b3642fc095a52cfc78"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-win_amd64.whl", hash = "sha256:7dd8583df2f98dea28b5cd53a1beac963f4f9d087888d75f22fcc93a07cf8d84"}, - {file = "SQLAlchemy-2.0.32-py3-none-any.whl", hash = "sha256:e567a8793a692451f706b363ccf3c45e056b67d90ead58c3bc9471af5d212202"}, - {file = "SQLAlchemy-2.0.32.tar.gz", hash = "sha256:c1b88cc8b02b6a5f0efb0345a03672d4c897dc7d92585176f88c67346f565ea8"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d0b2cf8791ab5fb9e3aa3d9a79a0d5d51f55b6357eecf532a120ba3b5524db"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:243f92596f4fd4c8bd30ab8e8dd5965afe226363d75cab2468f2c707f64cd83b"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ea54f7300553af0a2a7235e9b85f4204e1fc21848f917a3213b0e0818de9a24"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:173f5f122d2e1bff8fbd9f7811b7942bead1f5e9f371cdf9e670b327e6703ebd"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:196958cde924a00488e3e83ff917be3b73cd4ed8352bbc0f2989333176d1c54d"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd90c221ed4e60ac9d476db967f436cfcecbd4ef744537c0f2d5291439848768"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-win32.whl", hash = "sha256:3166dfff2d16fe9be3241ee60ece6fcb01cf8e74dd7c5e0b64f8e19fab44911b"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-win_amd64.whl", hash = "sha256:6831a78bbd3c40f909b3e5233f87341f12d0b34a58f14115c9e94b4cdaf726d3"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7db3db284a0edaebe87f8f6642c2b2c27ed85c3e70064b84d1c9e4ec06d5d84"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:430093fce0efc7941d911d34f75a70084f12f6ca5c15d19595c18753edb7c33b"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79cb400c360c7c210097b147c16a9e4c14688a6402445ac848f296ade6283bbc"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1b30f31a36c7f3fee848391ff77eebdd3af5750bf95fbf9b8b5323edfdb4ec"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fddde2368e777ea2a4891a3fb4341e910a056be0bb15303bf1b92f073b80c02"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80bd73ea335203b125cf1d8e50fef06be709619eb6ab9e7b891ea34b5baa2287"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-win32.whl", hash = "sha256:6daeb8382d0df526372abd9cb795c992e18eed25ef2c43afe518c73f8cccb721"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-win_amd64.whl", hash = "sha256:5bc08e75ed11693ecb648b7a0a4ed80da6d10845e44be0c98c03f2f880b68ff4"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:53e68b091492c8ed2bd0141e00ad3089bcc6bf0e6ec4142ad6505b4afe64163e"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:bcd18441a49499bf5528deaa9dee1f5c01ca491fc2791b13604e8f972877f812"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:165bbe0b376541092bf49542bd9827b048357f4623486096fc9aaa6d4e7c59a2"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3330415cd387d2b88600e8e26b510d0370db9b7eaf984354a43e19c40df2e2b"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97b850f73f8abbffb66ccbab6e55a195a0eb655e5dc74624d15cff4bfb35bd74"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee4c6917857fd6121ed84f56d1dc78eb1d0e87f845ab5a568aba73e78adf83"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-win32.whl", hash = "sha256:fbb034f565ecbe6c530dff948239377ba859420d146d5f62f0271407ffb8c580"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-win_amd64.whl", hash = "sha256:707c8f44931a4facd4149b52b75b80544a8d824162602b8cd2fe788207307f9a"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24af3dc43568f3780b7e1e57c49b41d98b2d940c1fd2e62d65d3928b6f95f021"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e60ed6ef0a35c6b76b7640fe452d0e47acc832ccbb8475de549a5cc5f90c2c06"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:413c85cd0177c23e32dee6898c67a5f49296640041d98fddb2c40888fe4daa2e"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:25691f4adfb9d5e796fd48bf1432272f95f4bbe5f89c475a788f31232ea6afba"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:526ce723265643dbc4c7efb54f56648cc30e7abe20f387d763364b3ce7506c82"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-win32.whl", hash = "sha256:13be2cc683b76977a700948411a94c67ad8faf542fa7da2a4b167f2244781cf3"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-win_amd64.whl", hash = "sha256:e54ef33ea80d464c3dcfe881eb00ad5921b60f8115ea1a30d781653edc2fd6a2"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:43f28005141165edd11fbbf1541c920bd29e167b8bbc1fb410d4fe2269c1667a"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b68094b165a9e930aedef90725a8fcfafe9ef95370cbb54abc0464062dbf808f"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1e03db964e9d32f112bae36f0cc1dcd1988d096cfd75d6a588a3c3def9ab2b"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:203d46bddeaa7982f9c3cc693e5bc93db476ab5de9d4b4640d5c99ff219bee8c"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ae92bebca3b1e6bd203494e5ef919a60fb6dfe4d9a47ed2453211d3bd451b9f5"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9661268415f450c95f72f0ac1217cc6f10256f860eed85c2ae32e75b60278ad8"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-win32.whl", hash = "sha256:895184dfef8708e15f7516bd930bda7e50ead069280d2ce09ba11781b630a434"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-win_amd64.whl", hash = "sha256:6e7cde3a2221aa89247944cafb1b26616380e30c63e37ed19ff0bba5e968688d"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dbcdf987f3aceef9763b6d7b1fd3e4ee210ddd26cac421d78b3c206d07b2700b"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:ce119fc4ce0d64124d37f66a6f2a584fddc3c5001755f8a49f1ca0a177ef9796"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a17d8fac6df9835d8e2b4c5523666e7051d0897a93756518a1fe101c7f47f2f0"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ebc11c54c6ecdd07bb4efbfa1554538982f5432dfb8456958b6d46b9f834bb7"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2e6965346fc1491a566e019a4a1d3dfc081ce7ac1a736536367ca305da6472a8"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:220574e78ad986aea8e81ac68821e47ea9202b7e44f251b7ed8c66d9ae3f4278"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-win32.whl", hash = "sha256:b75b00083e7fe6621ce13cfce9d4469c4774e55e8e9d38c305b37f13cf1e874c"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-win_amd64.whl", hash = "sha256:c29d03e0adf3cc1a8c3ec62d176824972ae29b67a66cbb18daff3062acc6faa8"}, + {file = "SQLAlchemy-2.0.34-py3-none-any.whl", hash = "sha256:7286c353ee6475613d8beff83167374006c6b3e3f0e6491bfe8ca610eb1dec0f"}, + {file = "sqlalchemy-2.0.34.tar.gz", hash = "sha256:10d8f36990dd929690666679b0f42235c159a7051534adb135728ee52828dd22"}, ] [package.dependencies] @@ -8352,13 +8647,13 @@ doc = ["sphinx"] [[package]] name = "starlette" -version = "0.38.2" +version = "0.38.4" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.38.2-py3-none-any.whl", hash = "sha256:4ec6a59df6bbafdab5f567754481657f7ed90dc9d69b0c9ff017907dd54faeff"}, - {file = "starlette-0.38.2.tar.gz", hash = "sha256:c7c0441065252160993a1a37cf2a73bb64d271b17303e0b0c1eb7191cfb12d75"}, + {file = "starlette-0.38.4-py3-none-any.whl", hash = "sha256:526f53a77f0e43b85f583438aee1a940fd84f8fd610353e8b0c1a77ad8a87e76"}, + {file = "starlette-0.38.4.tar.gz", hash = "sha256:53a7439060304a208fea17ed407e998f46da5e5d9b1addfea3040094512a6379"}, ] [package.dependencies] @@ -8412,6 +8707,17 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tblib" +version = "3.0.0" +description = "Traceback serialization library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"}, + {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"}, +] + [[package]] name = "tcvectordb" version = "1.3.2" @@ -8444,13 +8750,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1216" +version = "3.0.1226" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1216.tar.gz", hash = "sha256:7ad83b100574068fe25439fe47fd27253ff1c730348a309567a7ff88eda63cf8"}, - {file = "tencentcloud_sdk_python_common-3.0.1216-py2.py3-none-any.whl", hash = "sha256:5e1cf9b685923d567d379f96a7008084006ad68793cfa0a0524e65dc59fd09d7"}, + {file = "tencentcloud-sdk-python-common-3.0.1226.tar.gz", hash = "sha256:8e126cdce6adffce6fa5a3b464f0a6e483af7c7f78939883823393c2c5e8fc62"}, + {file = "tencentcloud_sdk_python_common-3.0.1226-py2.py3-none-any.whl", hash = "sha256:6165481280147afa226c6bb91df4cd0c43c5230f566be3d3f9c45a826b1105c5"}, ] [package.dependencies] @@ -8458,17 +8764,17 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1216" +version = "3.0.1226" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1216.tar.gz", hash = "sha256:b295d67f97dba52ed358a1d9e061f94b1a4a87e45714efbf0987edab12642206"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1216-py2.py3-none-any.whl", hash = "sha256:62d925b41424017929b532389061a076dca72dde455e85ec089947645010e691"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1226.tar.gz", hash = "sha256:c9b9c3a373d967b691444bd590e3be1424aaab9f1ab30c57d98777113e2b7882"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1226-py2.py3-none-any.whl", hash = "sha256:87a1d63f85c25b5ec6c07f16d813091411ea6f296a1bf7fb608a529852b38bbe"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1216" +tencentcloud-sdk-python-common = "3.0.1226" [[package]] name = "threadpoolctl" @@ -8730,6 +9036,23 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "tos" +version = "2.7.1" +description = "Volc TOS (Tinder Object Storage) SDK" +optional = false +python-versions = "*" +files = [ + {file = "tos-2.7.1.tar.gz", hash = "sha256:4bccdbff3cfd63eb44648bb44862903708c4b3e790f0dd55c96305baaeece805"}, +] + +[package.dependencies] +crcmod = ">=1.7" +Deprecated = ">=1.2.13,<2.0.0" +pytz = "*" +requests = ">=2.19.1,<3.dev0" +six = "*" + [[package]] name = "tqdm" version = "4.66.5" @@ -8837,13 +9160,13 @@ requests = ">=2.0.0" [[package]] name = "typer" -version = "0.12.4" +version = "0.12.5" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." 
optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.12.4-py3-none-any.whl", hash = "sha256:819aa03699f438397e876aa12b0d63766864ecba1b579092cc9fe35d886e34b6"}, - {file = "typer-0.12.4.tar.gz", hash = "sha256:c9c1613ed6a166162705b3347b8d10b661ccc5d95692654d0fb628118f2c34e6"}, + {file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"}, + {file = "typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722"}, ] [package.dependencies] @@ -8854,13 +9177,13 @@ typing-extensions = ">=3.7.4.3" [[package]] name = "types-requests" -version = "2.32.0.20240712" +version = "2.32.0.20240905" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.32.0.20240712.tar.gz", hash = "sha256:90c079ff05e549f6bf50e02e910210b98b8ff1ebdd18e19c873cd237737c1358"}, - {file = "types_requests-2.32.0.20240712-py3-none-any.whl", hash = "sha256:f754283e152c752e46e70942fa2a146b5bc70393522257bb85bd1ef7e019dcc3"}, + {file = "types-requests-2.32.0.20240905.tar.gz", hash = "sha256:e97fd015a5ed982c9ddcd14cc4afba9d111e0e06b797c8f776d14602735e9bd6"}, + {file = "types_requests-2.32.0.20240905-py3-none-any.whl", hash = "sha256:f46ecb55f5e1a37a58be684cf3f013f166da27552732ef2469a0cc8e62a72881"}, ] [package.dependencies] @@ -9263,12 +9586,12 @@ files = [ [[package]] name = "volcengine-python-sdk" -version = "1.0.98" +version = "1.0.100" description = "Volcengine SDK for Python" optional = false python-versions = "*" files = [ - {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"}, + {file = "volcengine-python-sdk-1.0.100.tar.gz", hash = "sha256:cdc194fe3ce51adda6892d2ca1c43edba3300699321dc6c69119c59fc3b28932"}, ] [package.dependencies] @@ -9285,98 +9608,94 @@ ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic [[package]] name = "watchfiles" -version = "0.23.0" +version = "0.24.0" description = "Simple, modern and high performance file watching and code reload in python." 
optional = false python-versions = ">=3.8" files = [ - {file = "watchfiles-0.23.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bee8ce357a05c20db04f46c22be2d1a2c6a8ed365b325d08af94358e0688eeb4"}, - {file = "watchfiles-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4ccd3011cc7ee2f789af9ebe04745436371d36afe610028921cab9f24bb2987b"}, - {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb02d41c33be667e6135e6686f1bb76104c88a312a18faa0ef0262b5bf7f1a0f"}, - {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7cf12ac34c444362f3261fb3ff548f0037ddd4c5bb85f66c4be30d2936beb3c5"}, - {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0b2c25040a3c0ce0e66c7779cc045fdfbbb8d59e5aabfe033000b42fe44b53e"}, - {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf2be4b9eece4f3da8ba5f244b9e51932ebc441c0867bd6af46a3d97eb068d6"}, - {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40cb8fa00028908211eb9f8d47744dca21a4be6766672e1ff3280bee320436f1"}, - {file = "watchfiles-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f48c917ffd36ff9a5212614c2d0d585fa8b064ca7e66206fb5c095015bc8207"}, - {file = "watchfiles-0.23.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9d183e3888ada88185ab17064079c0db8c17e32023f5c278d7bf8014713b1b5b"}, - {file = "watchfiles-0.23.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9837edf328b2805346f91209b7e660f65fb0e9ca18b7459d075d58db082bf981"}, - {file = "watchfiles-0.23.0-cp310-none-win32.whl", hash = "sha256:296e0b29ab0276ca59d82d2da22cbbdb39a23eed94cca69aed274595fb3dfe42"}, - {file = "watchfiles-0.23.0-cp310-none-win_amd64.whl", hash = "sha256:4ea756e425ab2dfc8ef2a0cb87af8aa7ef7dfc6fc46c6f89bcf382121d4fff75"}, - {file = "watchfiles-0.23.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e397b64f7aaf26915bf2ad0f1190f75c855d11eb111cc00f12f97430153c2eab"}, - {file = "watchfiles-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b4ac73b02ca1824ec0a7351588241fd3953748d3774694aa7ddb5e8e46aef3e3"}, - {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130a896d53b48a1cecccfa903f37a1d87dbb74295305f865a3e816452f6e49e4"}, - {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c5e7803a65eb2d563c73230e9d693c6539e3c975ccfe62526cadde69f3fda0cf"}, - {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1aa4cc85202956d1a65c88d18c7b687b8319dbe6b1aec8969784ef7a10e7d1a"}, - {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87f889f6e58849ddb7c5d2cb19e2e074917ed1c6e3ceca50405775166492cca8"}, - {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37fd826dac84c6441615aa3f04077adcc5cac7194a021c9f0d69af20fb9fa788"}, - {file = "watchfiles-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7db6e36e7a2c15923072e41ea24d9a0cf39658cb0637ecc9307b09d28827e1"}, - {file = "watchfiles-0.23.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2368c5371c17fdcb5a2ea71c5c9d49f9b128821bfee69503cc38eae00feb3220"}, - {file = "watchfiles-0.23.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:857af85d445b9ba9178db95658c219dbd77b71b8264e66836a6eba4fbf49c320"}, - {file = "watchfiles-0.23.0-cp311-none-win32.whl", hash = "sha256:1d636c8aeb28cdd04a4aa89030c4b48f8b2954d8483e5f989774fa441c0ed57b"}, - {file = "watchfiles-0.23.0-cp311-none-win_amd64.whl", hash = "sha256:46f1d8069a95885ca529645cdbb05aea5837d799965676e1b2b1f95a4206313e"}, - {file = "watchfiles-0.23.0-cp311-none-win_arm64.whl", hash = "sha256:e495ed2a7943503766c5d1ff05ae9212dc2ce1c0e30a80d4f0d84889298fa304"}, - {file = "watchfiles-0.23.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1db691bad0243aed27c8354b12d60e8e266b75216ae99d33e927ff5238d270b5"}, - {file = "watchfiles-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:62d2b18cb1edaba311fbbfe83fb5e53a858ba37cacb01e69bc20553bb70911b8"}, - {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e087e8fdf1270d000913c12e6eca44edd02aad3559b3e6b8ef00f0ce76e0636f"}, - {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dd41d5c72417b87c00b1b635738f3c283e737d75c5fa5c3e1c60cd03eac3af77"}, - {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e5f3ca0ff47940ce0a389457b35d6df601c317c1e1a9615981c474452f98de1"}, - {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6991e3a78f642368b8b1b669327eb6751439f9f7eaaa625fae67dd6070ecfa0b"}, - {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f7252f52a09f8fa5435dc82b6af79483118ce6bd51eb74e6269f05ee22a7b9f"}, - {file = "watchfiles-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e01bcb8d767c58865207a6c2f2792ad763a0fe1119fb0a430f444f5b02a5ea0"}, - {file = "watchfiles-0.23.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e56fbcdd27fce061854ddec99e015dd779cae186eb36b14471fc9ae713b118c"}, - {file = "watchfiles-0.23.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bd3e2d64500a6cad28bcd710ee6269fbeb2e5320525acd0cfab5f269ade68581"}, - {file = "watchfiles-0.23.0-cp312-none-win32.whl", hash = "sha256:eb99c954291b2fad0eff98b490aa641e128fbc4a03b11c8a0086de8b7077fb75"}, - {file = "watchfiles-0.23.0-cp312-none-win_amd64.whl", hash = "sha256:dccc858372a56080332ea89b78cfb18efb945da858fabeb67f5a44fa0bcb4ebb"}, - {file = "watchfiles-0.23.0-cp312-none-win_arm64.whl", hash = "sha256:6c21a5467f35c61eafb4e394303720893066897fca937bade5b4f5877d350ff8"}, - {file = "watchfiles-0.23.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ba31c32f6b4dceeb2be04f717811565159617e28d61a60bb616b6442027fd4b9"}, - {file = "watchfiles-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:85042ab91814fca99cec4678fc063fb46df4cbb57b4835a1cc2cb7a51e10250e"}, - {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24655e8c1c9c114005c3868a3d432c8aa595a786b8493500071e6a52f3d09217"}, - {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b1a950ab299a4a78fd6369a97b8763732bfb154fdb433356ec55a5bce9515c1"}, - {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8d3c5cd327dd6ce0edfc94374fb5883d254fe78a5e9d9dfc237a1897dc73cd1"}, - {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ff785af8bacdf0be863ec0c428e3288b817e82f3d0c1d652cd9c6d509020dd0"}, - {file = 
"watchfiles-0.23.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:02b7ba9d4557149410747353e7325010d48edcfe9d609a85cb450f17fd50dc3d"}, - {file = "watchfiles-0.23.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48a1b05c0afb2cd2f48c1ed2ae5487b116e34b93b13074ed3c22ad5c743109f0"}, - {file = "watchfiles-0.23.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:109a61763e7318d9f821b878589e71229f97366fa6a5c7720687d367f3ab9eef"}, - {file = "watchfiles-0.23.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:9f8e6bb5ac007d4a4027b25f09827ed78cbbd5b9700fd6c54429278dacce05d1"}, - {file = "watchfiles-0.23.0-cp313-none-win32.whl", hash = "sha256:f46c6f0aec8d02a52d97a583782d9af38c19a29900747eb048af358a9c1d8e5b"}, - {file = "watchfiles-0.23.0-cp313-none-win_amd64.whl", hash = "sha256:f449afbb971df5c6faeb0a27bca0427d7b600dd8f4a068492faec18023f0dcff"}, - {file = "watchfiles-0.23.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:2dddc2487d33e92f8b6222b5fb74ae2cfde5e8e6c44e0248d24ec23befdc5366"}, - {file = "watchfiles-0.23.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e75695cc952e825fa3e0684a7f4a302f9128721f13eedd8dbd3af2ba450932b8"}, - {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2537ef60596511df79b91613a5bb499b63f46f01a11a81b0a2b0dedf645d0a9c"}, - {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20b423b58f5fdde704a226b598a2d78165fe29eb5621358fe57ea63f16f165c4"}, - {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b98732ec893975455708d6fc9a6daab527fc8bbe65be354a3861f8c450a632a4"}, - {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee1f5fcbf5bc33acc0be9dd31130bcba35d6d2302e4eceafafd7d9018c7755ab"}, - {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8f195338a5a7b50a058522b39517c50238358d9ad8284fd92943643144c0c03"}, - {file = "watchfiles-0.23.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524fcb8d59b0dbee2c9b32207084b67b2420f6431ed02c18bd191e6c575f5c48"}, - {file = "watchfiles-0.23.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0eff099a4df36afaa0eea7a913aa64dcf2cbd4e7a4f319a73012210af4d23810"}, - {file = "watchfiles-0.23.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a8323daae27ea290ba3350c70c836c0d2b0fb47897fa3b0ca6a5375b952b90d3"}, - {file = "watchfiles-0.23.0-cp38-none-win32.whl", hash = "sha256:aafea64a3ae698695975251f4254df2225e2624185a69534e7fe70581066bc1b"}, - {file = "watchfiles-0.23.0-cp38-none-win_amd64.whl", hash = "sha256:c846884b2e690ba62a51048a097acb6b5cd263d8bd91062cd6137e2880578472"}, - {file = "watchfiles-0.23.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a753993635eccf1ecb185dedcc69d220dab41804272f45e4aef0a67e790c3eb3"}, - {file = "watchfiles-0.23.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6bb91fa4d0b392f0f7e27c40981e46dda9eb0fbc84162c7fb478fe115944f491"}, - {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1f67312efa3902a8e8496bfa9824d3bec096ff83c4669ea555c6bdd213aa516"}, - {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ca6b71dcc50d320c88fb2d88ecd63924934a8abc1673683a242a7ca7d39e781"}, - {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:2aec5c29915caf08771d2507da3ac08e8de24a50f746eb1ed295584ba1820330"}, - {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1733b9bc2c8098c6bdb0ff7a3d7cb211753fecb7bd99bdd6df995621ee1a574b"}, - {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:02ff5d7bd066c6a7673b17c8879cd8ee903078d184802a7ee851449c43521bdd"}, - {file = "watchfiles-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e2de19801b0eaa4c5292a223effb7cfb43904cb742c5317a0ac686ed604765"}, - {file = "watchfiles-0.23.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8ada449e22198c31fb013ae7e9add887e8d2bd2335401abd3cbc55f8c5083647"}, - {file = "watchfiles-0.23.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3af1b05361e1cc497bf1be654a664750ae61f5739e4bb094a2be86ec8c6db9b6"}, - {file = "watchfiles-0.23.0-cp39-none-win32.whl", hash = "sha256:486bda18be5d25ab5d932699ceed918f68eb91f45d018b0343e3502e52866e5e"}, - {file = "watchfiles-0.23.0-cp39-none-win_amd64.whl", hash = "sha256:d2d42254b189a346249424fb9bb39182a19289a2409051ee432fb2926bad966a"}, - {file = "watchfiles-0.23.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9265cf87a5b70147bfb2fec14770ed5b11a5bb83353f0eee1c25a81af5abfe"}, - {file = "watchfiles-0.23.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9f02a259fcbbb5fcfe7a0805b1097ead5ba7a043e318eef1db59f93067f0b49b"}, - {file = "watchfiles-0.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebaebb53b34690da0936c256c1cdb0914f24fb0e03da76d185806df9328abed"}, - {file = "watchfiles-0.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd257f98cff9c6cb39eee1a83c7c3183970d8a8d23e8cf4f47d9a21329285cee"}, - {file = "watchfiles-0.23.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:aba037c1310dd108411d27b3d5815998ef0e83573e47d4219f45753c710f969f"}, - {file = "watchfiles-0.23.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:a96ac14e184aa86dc43b8a22bb53854760a58b2966c2b41580de938e9bf26ed0"}, - {file = "watchfiles-0.23.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11698bb2ea5e991d10f1f4f83a39a02f91e44e4bd05f01b5c1ec04c9342bf63c"}, - {file = "watchfiles-0.23.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efadd40fca3a04063d40c4448c9303ce24dd6151dc162cfae4a2a060232ebdcb"}, - {file = "watchfiles-0.23.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:556347b0abb4224c5ec688fc58214162e92a500323f50182f994f3ad33385dcb"}, - {file = "watchfiles-0.23.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1cf7f486169986c4b9d34087f08ce56a35126600b6fef3028f19ca16d5889071"}, - {file = "watchfiles-0.23.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f18de0f82c62c4197bea5ecf4389288ac755896aac734bd2cc44004c56e4ac47"}, - {file = "watchfiles-0.23.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:532e1f2c491274d1333a814e4c5c2e8b92345d41b12dc806cf07aaff786beb66"}, - {file = "watchfiles-0.23.0.tar.gz", hash = "sha256:9338ade39ff24f8086bb005d16c29f8e9f19e55b18dcb04dfa26fcbc09da497b"}, + {file = "watchfiles-0.24.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:083dc77dbdeef09fa44bb0f4d1df571d2e12d8a8f985dccde71ac3ac9ac067a0"}, + {file = "watchfiles-0.24.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:e94e98c7cb94cfa6e071d401ea3342767f28eb5a06a58fafdc0d2a4974f4f35c"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82ae557a8c037c42a6ef26c494d0631cacca040934b101d001100ed93d43f361"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:acbfa31e315a8f14fe33e3542cbcafc55703b8f5dcbb7c1eecd30f141df50db3"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b74fdffce9dfcf2dc296dec8743e5b0332d15df19ae464f0e249aa871fc1c571"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:449f43f49c8ddca87c6b3980c9284cab6bd1f5c9d9a2b00012adaaccd5e7decd"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4abf4ad269856618f82dee296ac66b0cd1d71450fc3c98532d93798e73399b7a"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f895d785eb6164678ff4bb5cc60c5996b3ee6df3edb28dcdeba86a13ea0465e"}, + {file = "watchfiles-0.24.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7ae3e208b31be8ce7f4c2c0034f33406dd24fbce3467f77223d10cd86778471c"}, + {file = "watchfiles-0.24.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2efec17819b0046dde35d13fb8ac7a3ad877af41ae4640f4109d9154ed30a188"}, + {file = "watchfiles-0.24.0-cp310-none-win32.whl", hash = "sha256:6bdcfa3cd6fdbdd1a068a52820f46a815401cbc2cb187dd006cb076675e7b735"}, + {file = "watchfiles-0.24.0-cp310-none-win_amd64.whl", hash = "sha256:54ca90a9ae6597ae6dc00e7ed0a040ef723f84ec517d3e7ce13e63e4bc82fa04"}, + {file = "watchfiles-0.24.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bdcd5538e27f188dd3c804b4a8d5f52a7fc7f87e7fd6b374b8e36a4ca03db428"}, + {file = "watchfiles-0.24.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2dadf8a8014fde6addfd3c379e6ed1a981c8f0a48292d662e27cabfe4239c83c"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6509ed3f467b79d95fc62a98229f79b1a60d1b93f101e1c61d10c95a46a84f43"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8360f7314a070c30e4c976b183d1d8d1585a4a50c5cb603f431cebcbb4f66327"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:316449aefacf40147a9efaf3bd7c9bdd35aaba9ac5d708bd1eb5763c9a02bef5"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73bde715f940bea845a95247ea3e5eb17769ba1010efdc938ffcb967c634fa61"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3770e260b18e7f4e576edca4c0a639f704088602e0bc921c5c2e721e3acb8d15"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0fd7248cf533c259e59dc593a60973a73e881162b1a2f73360547132742823"}, + {file = "watchfiles-0.24.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d7a2e3b7f5703ffbd500dabdefcbc9eafeff4b9444bbdd5d83d79eedf8428fab"}, + {file = "watchfiles-0.24.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d831ee0a50946d24a53821819b2327d5751b0c938b12c0653ea5be7dea9c82ec"}, + {file = "watchfiles-0.24.0-cp311-none-win32.whl", hash = "sha256:49d617df841a63b4445790a254013aea2120357ccacbed00253f9c2b5dc24e2d"}, + {file = "watchfiles-0.24.0-cp311-none-win_amd64.whl", hash = 
"sha256:d3dcb774e3568477275cc76554b5a565024b8ba3a0322f77c246bc7111c5bb9c"}, + {file = "watchfiles-0.24.0-cp311-none-win_arm64.whl", hash = "sha256:9301c689051a4857d5b10777da23fafb8e8e921bcf3abe6448a058d27fb67633"}, + {file = "watchfiles-0.24.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7211b463695d1e995ca3feb38b69227e46dbd03947172585ecb0588f19b0d87a"}, + {file = "watchfiles-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b8693502d1967b00f2fb82fc1e744df128ba22f530e15b763c8d82baee15370"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdab9555053399318b953a1fe1f586e945bc8d635ce9d05e617fd9fe3a4687d6"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34e19e56d68b0dad5cff62273107cf5d9fbaf9d75c46277aa5d803b3ef8a9e9b"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:41face41f036fee09eba33a5b53a73e9a43d5cb2c53dad8e61fa6c9f91b5a51e"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5148c2f1ea043db13ce9b0c28456e18ecc8f14f41325aa624314095b6aa2e9ea"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e4bd963a935aaf40b625c2499f3f4f6bbd0c3776f6d3bc7c853d04824ff1c9f"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c79d7719d027b7a42817c5d96461a99b6a49979c143839fc37aa5748c322f234"}, + {file = "watchfiles-0.24.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:32aa53a9a63b7f01ed32e316e354e81e9da0e6267435c7243bf8ae0f10b428ef"}, + {file = "watchfiles-0.24.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce72dba6a20e39a0c628258b5c308779b8697f7676c254a845715e2a1039b968"}, + {file = "watchfiles-0.24.0-cp312-none-win32.whl", hash = "sha256:d9018153cf57fc302a2a34cb7564870b859ed9a732d16b41a9b5cb2ebed2d444"}, + {file = "watchfiles-0.24.0-cp312-none-win_amd64.whl", hash = "sha256:551ec3ee2a3ac9cbcf48a4ec76e42c2ef938a7e905a35b42a1267fa4b1645896"}, + {file = "watchfiles-0.24.0-cp312-none-win_arm64.whl", hash = "sha256:b52a65e4ea43c6d149c5f8ddb0bef8d4a1e779b77591a458a893eb416624a418"}, + {file = "watchfiles-0.24.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:3d2e3ab79a1771c530233cadfd277fcc762656d50836c77abb2e5e72b88e3a48"}, + {file = "watchfiles-0.24.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327763da824817b38ad125dcd97595f942d720d32d879f6c4ddf843e3da3fe90"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd82010f8ab451dabe36054a1622870166a67cf3fce894f68895db6f74bbdc94"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d64ba08db72e5dfd5c33be1e1e687d5e4fcce09219e8aee893a4862034081d4e"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1cf1f6dd7825053f3d98f6d33f6464ebdd9ee95acd74ba2c34e183086900a827"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43e3e37c15a8b6fe00c1bce2473cfa8eb3484bbeecf3aefbf259227e487a03df"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88bcd4d0fe1d8ff43675360a72def210ebad3f3f72cabfeac08d825d2639b4ab"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:999928c6434372fde16c8f27143d3e97201160b48a614071261701615a2a156f"}, + {file = "watchfiles-0.24.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:30bbd525c3262fd9f4b1865cb8d88e21161366561cd7c9e1194819e0a33ea86b"}, + {file = "watchfiles-0.24.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:edf71b01dec9f766fb285b73930f95f730bb0943500ba0566ae234b5c1618c18"}, + {file = "watchfiles-0.24.0-cp313-none-win32.whl", hash = "sha256:f4c96283fca3ee09fb044f02156d9570d156698bc3734252175a38f0e8975f07"}, + {file = "watchfiles-0.24.0-cp313-none-win_amd64.whl", hash = "sha256:a974231b4fdd1bb7f62064a0565a6b107d27d21d9acb50c484d2cdba515b9366"}, + {file = "watchfiles-0.24.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:ee82c98bed9d97cd2f53bdb035e619309a098ea53ce525833e26b93f673bc318"}, + {file = "watchfiles-0.24.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fd92bbaa2ecdb7864b7600dcdb6f2f1db6e0346ed425fbd01085be04c63f0b05"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f83df90191d67af5a831da3a33dd7628b02a95450e168785586ed51e6d28943c"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fca9433a45f18b7c779d2bae7beeec4f740d28b788b117a48368d95a3233ed83"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b995bfa6bf01a9e09b884077a6d37070464b529d8682d7691c2d3b540d357a0c"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed9aba6e01ff6f2e8285e5aa4154e2970068fe0fc0998c4380d0e6278222269b"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5171ef898299c657685306d8e1478a45e9303ddcd8ac5fed5bd52ad4ae0b69b"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4933a508d2f78099162da473841c652ad0de892719043d3f07cc83b33dfd9d91"}, + {file = "watchfiles-0.24.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:95cf3b95ea665ab03f5a54765fa41abf0529dbaf372c3b83d91ad2cfa695779b"}, + {file = "watchfiles-0.24.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:01def80eb62bd5db99a798d5e1f5f940ca0a05986dcfae21d833af7a46f7ee22"}, + {file = "watchfiles-0.24.0-cp38-none-win32.whl", hash = "sha256:4d28cea3c976499475f5b7a2fec6b3a36208656963c1a856d328aeae056fc5c1"}, + {file = "watchfiles-0.24.0-cp38-none-win_amd64.whl", hash = "sha256:21ab23fdc1208086d99ad3f69c231ba265628014d4aed31d4e8746bd59e88cd1"}, + {file = "watchfiles-0.24.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b665caeeda58625c3946ad7308fbd88a086ee51ccb706307e5b1fa91556ac886"}, + {file = "watchfiles-0.24.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5c51749f3e4e269231510da426ce4a44beb98db2dce9097225c338f815b05d4f"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b2509f08761f29a0fdad35f7e1638b8ab1adfa2666d41b794090361fb8b855"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a60e2bf9dc6afe7f743e7c9b149d1fdd6dbf35153c78fe3a14ae1a9aee3d98b"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7d9b87c4c55e3ea8881dfcbf6d61ea6775fffed1fedffaa60bd047d3c08c430"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:78470906a6be5199524641f538bd2c56bb809cd4bf29a566a75051610bc982c3"}, + {file = 
"watchfiles-0.24.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:07cdef0c84c03375f4e24642ef8d8178e533596b229d32d2bbd69e5128ede02a"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d337193bbf3e45171c8025e291530fb7548a93c45253897cd764a6a71c937ed9"}, + {file = "watchfiles-0.24.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ec39698c45b11d9694a1b635a70946a5bad066b593af863460a8e600f0dff1ca"}, + {file = "watchfiles-0.24.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e28d91ef48eab0afb939fa446d8ebe77e2f7593f5f463fd2bb2b14132f95b6e"}, + {file = "watchfiles-0.24.0-cp39-none-win32.whl", hash = "sha256:7138eff8baa883aeaa074359daabb8b6c1e73ffe69d5accdc907d62e50b1c0da"}, + {file = "watchfiles-0.24.0-cp39-none-win_amd64.whl", hash = "sha256:b3ef2c69c655db63deb96b3c3e587084612f9b1fa983df5e0c3379d41307467f"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:632676574429bee8c26be8af52af20e0c718cc7f5f67f3fb658c71928ccd4f7f"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a2a9891723a735d3e2540651184be6fd5b96880c08ffe1a98bae5017e65b544b"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7fa2bc0efef3e209a8199fd111b8969fe9db9c711acc46636686331eda7dd4"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01550ccf1d0aed6ea375ef259706af76ad009ef5b0203a3a4cce0f6024f9b68a"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:96619302d4374de5e2345b2b622dc481257a99431277662c30f606f3e22f42be"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:85d5f0c7771dcc7a26c7a27145059b6bb0ce06e4e751ed76cdf123d7039b60b5"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:951088d12d339690a92cef2ec5d3cfd957692834c72ffd570ea76a6790222777"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49fb58bcaa343fedc6a9e91f90195b20ccb3135447dc9e4e2570c3a39565853e"}, + {file = "watchfiles-0.24.0.tar.gz", hash = "sha256:afb72325b74fa7a428c009c1b8be4b4d7c2afedafb2982827ef2156646df2fe1"}, ] [package.dependencies] @@ -9442,97 +9761,97 @@ test = ["websockets"] [[package]] name = "websockets" -version = "13.0" +version = "13.0.1" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false python-versions = ">=3.8" files = [ - {file = "websockets-13.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ad4fa707ff9e2ffee019e946257b5300a45137a58f41fbd9a4db8e684ab61528"}, - {file = "websockets-13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6fd757f313c13c34dae9f126d3ba4cf97175859c719e57c6a614b781c86b617e"}, - {file = "websockets-13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cbac2eb7ce0fac755fb983c9247c4a60c4019bcde4c0e4d167aeb17520cc7ef1"}, - {file = "websockets-13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4b83cf7354cbbc058e97b3e545dceb75b8d9cf17fd5a19db419c319ddbaaf7a"}, - {file = "websockets-13.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9202c0010c78fad1041e1c5285232b6508d3633f92825687549540a70e9e5901"}, - {file = 
"websockets-13.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6566e79c8c7cbea75ec450f6e1828945fc5c9a4769ceb1c7b6e22470539712"}, - {file = "websockets-13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e7fcad070dcd9ad37a09d89a4cbc2a5e3e45080b88977c0da87b3090f9f55ead"}, - {file = "websockets-13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a8f7d65358a25172db00c69bcc7df834155ee24229f560d035758fd6613111a"}, - {file = "websockets-13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:63b702fb31e3f058f946ccdfa551f4d57a06f7729c369e8815eb18643099db37"}, - {file = "websockets-13.0-cp310-cp310-win32.whl", hash = "sha256:3a20cf14ba7b482c4a1924b5e061729afb89c890ca9ed44ac4127c6c5986e424"}, - {file = "websockets-13.0-cp310-cp310-win_amd64.whl", hash = "sha256:587245f0704d0bb675f919898d7473e8827a6d578e5a122a21756ca44b811ec8"}, - {file = "websockets-13.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:06df8306c241c235075d2ae77367038e701e53bc8c1bb4f6644f4f53aa6dedd0"}, - {file = "websockets-13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:85a1f92a02f0b8c1bf02699731a70a8a74402bb3f82bee36e7768b19a8ed9709"}, - {file = "websockets-13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9ed02c604349068d46d87ef4c2012c112c791f2bec08671903a6bb2bd9c06784"}, - {file = "websockets-13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b89849171b590107f6724a7b0790736daead40926ddf47eadf998b4ff51d6414"}, - {file = "websockets-13.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:939a16849d71203628157a5e4a495da63967c744e1e32018e9b9e2689aca64d4"}, - {file = "websockets-13.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad818cdac37c0ad4c58e51cb4964eae4f18b43c4a83cb37170b0d90c31bd80cf"}, - {file = "websockets-13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cbfe82a07596a044de78bb7a62519e71690c5812c26c5f1d4b877e64e4f46309"}, - {file = "websockets-13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e07e76c49f39c5b45cbd7362b94f001ae209a3ea4905ae9a09cfd53b3c76373d"}, - {file = "websockets-13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:372f46a0096cfda23c88f7e42349a33f8375e10912f712e6b496d3a9a557290f"}, - {file = "websockets-13.0-cp311-cp311-win32.whl", hash = "sha256:376a43a4fd96725f13450d3d2e98f4f36c3525c562ab53d9a98dd2950dca9a8a"}, - {file = "websockets-13.0-cp311-cp311-win_amd64.whl", hash = "sha256:2be1382a4daa61e2f3e2be3b3c86932a8db9d1f85297feb6e9df22f391f94452"}, - {file = "websockets-13.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b5407c34776b9b77bd89a5f95eb0a34aaf91889e3f911c63f13035220eb50107"}, - {file = "websockets-13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4782ec789f059f888c1e8fdf94383d0e64b531cffebbf26dd55afd53ab487ca4"}, - {file = "websockets-13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c8feb8e19ef65c9994e652c5b0324abd657bedd0abeb946fb4f5163012c1e730"}, - {file = "websockets-13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3f3d2e20c442b58dbac593cb1e02bc02d149a86056cc4126d977ad902472e3b"}, - {file = "websockets-13.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e39d393e0ab5b8bd01717cc26f2922026050188947ff54fe6a49dc489f7750b7"}, - {file = 
"websockets-13.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f661a4205741bdc88ac9c2b2ec003c72cee97e4acd156eb733662ff004ba429"}, - {file = "websockets-13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:384129ad0490e06bab2b98c1da9b488acb35bb11e2464c728376c6f55f0d45f3"}, - {file = "websockets-13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:df5c0eff91f61b8205a6c9f7b255ff390cdb77b61c7b41f79ca10afcbb22b6cb"}, - {file = "websockets-13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:02cc9bb1a887dac0e08bf657c5d00aa3fac0d03215d35a599130c2034ae6663a"}, - {file = "websockets-13.0-cp312-cp312-win32.whl", hash = "sha256:d9726d2c9bd6aed8cb994d89b3910ca0079406edce3670886ec828a73e7bdd53"}, - {file = "websockets-13.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0839f35322f7b038d8adcf679e2698c3a483688cc92e3bd15ee4fb06669e9a"}, - {file = "websockets-13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:da7e501e59857e8e3e9d10586139dc196b80445a591451ca9998aafba1af5278"}, - {file = "websockets-13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a00e1e587c655749afb5b135d8d3edcfe84ec6db864201e40a882e64168610b3"}, - {file = "websockets-13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a7fbf2a8fe7556a8f4e68cb3e736884af7bf93653e79f6219f17ebb75e97d8f0"}, - {file = "websockets-13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ea9c9c7443a97ea4d84d3e4d42d0e8c4235834edae652993abcd2aff94affd7"}, - {file = "websockets-13.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35c2221b539b360203f3f9ad168e527bf16d903e385068ae842c186efb13d0ea"}, - {file = "websockets-13.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:358d37c5c431dd050ffb06b4b075505aae3f4f795d7fff9794e5ed96ce99b998"}, - {file = "websockets-13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:038e7a0f1bfafc7bf52915ab3506b7a03d1e06381e9f60440c856e8918138151"}, - {file = "websockets-13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fd038bc9e2c134847f1e0ce3191797fad110756e690c2fdd9702ed34e7a43abb"}, - {file = "websockets-13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93b8c2008f372379fb6e5d2b3f7c9ec32f7b80316543fd3a5ace6610c5cde1b0"}, - {file = "websockets-13.0-cp313-cp313-win32.whl", hash = "sha256:851fd0afb3bc0b73f7c5b5858975d42769a5fdde5314f4ef2c106aec63100687"}, - {file = "websockets-13.0-cp313-cp313-win_amd64.whl", hash = "sha256:7d14901fdcf212804970c30ab9ee8f3f0212e620c7ea93079d6534863444fb4e"}, - {file = "websockets-13.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae7a519a56a714f64c3445cabde9fc2fc927e7eae44f413eae187cddd9e54178"}, - {file = "websockets-13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5575031472ca87302aeb2ce2c2349f4c6ea978c86a9d1289bc5d16058ad4c10a"}, - {file = "websockets-13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9895df6cd0bfe79d09bcd1dbdc03862846f26fbd93797153de954306620c1d00"}, - {file = "websockets-13.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4de299c947a54fca9ce1c5fd4a08eb92ffce91961becb13bd9195f7c6e71b47"}, - {file = "websockets-13.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:05c25f7b849702950b6fd0e233989bb73a0d2bc83faa3b7233313ca395205f6d"}, - {file = 
"websockets-13.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ede95125a30602b1691a4b1da88946bf27dae283cf30f22cd2cb8ca4b2e0d119"}, - {file = "websockets-13.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:addf0a16e4983280efed272d8cb3b2e05f0051755372461e7d966b80a6554e16"}, - {file = "websockets-13.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:06b3186e97bf9a33921fa60734d5ed90f2a9b407cce8d23c7333a0984049ef61"}, - {file = "websockets-13.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:eae368cac85adc4c7dc3b0d5f84ffcca609d658db6447387300478e44db70796"}, - {file = "websockets-13.0-cp38-cp38-win32.whl", hash = "sha256:337837ac788d955728b1ab01876d72b73da59819a3388e1c5e8e05c3999f1afa"}, - {file = "websockets-13.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66e00e42f25ca7e91076366303e11c82572ca87cc5aae51e6e9c094f315ab41"}, - {file = "websockets-13.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:94c1c02721139fe9940b38d28fb15b4b782981d800d5f40f9966264fbf23dcc8"}, - {file = "websockets-13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bd4ba86513430513e2aa25a441bb538f6f83734dc368a2c5d18afdd39097aa33"}, - {file = "websockets-13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a1ab8f0e0cadc5be5f3f9fa11a663957fecbf483d434762c8dfb8aa44948944a"}, - {file = "websockets-13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3670def5d3dfd5af6f6e2b3b243ea8f1f72d8da1ef927322f0703f85c90d9603"}, - {file = "websockets-13.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6058b6be92743358885ad6dcdecb378fde4a4c74d4dd16a089d07580c75a0e80"}, - {file = "websockets-13.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:516062a0a8ef5ecbfa4acbaec14b199fc070577834f9fe3d40800a99f92523ca"}, - {file = "websockets-13.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:da7e918d82e7bdfc6f66d31febe1b2e28a1ca3387315f918de26f5e367f61572"}, - {file = "websockets-13.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:9cc7f35dcb49a4e32db82a849fcc0714c4d4acc9d2273aded2d61f87d7f660b7"}, - {file = "websockets-13.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f5737c53eb2c8ed8f64b50d3dafd3c1dae739f78aa495a288421ac1b3de82717"}, - {file = "websockets-13.0-cp39-cp39-win32.whl", hash = "sha256:265e1f0d3f788ce8ef99dca591a1aec5263b26083ca0934467ad9a1d1181067c"}, - {file = "websockets-13.0-cp39-cp39-win_amd64.whl", hash = "sha256:4d70c89e3d3b347a7c4d3c33f8d323f0584c9ceb69b82c2ef8a174ca84ea3d4a"}, - {file = "websockets-13.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:602cbd010d8c21c8475f1798b705bb18567eb189c533ab5ef568bc3033fdf417"}, - {file = "websockets-13.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:bf8eb5dca4f484a60f5327b044e842e0d7f7cdbf02ea6dc4a4f811259f1f1f0b"}, - {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89d795c1802d99a643bf689b277e8604c14b5af1bc0a31dade2cd7a678087212"}, - {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:788bc841d250beccff67a20a5a53a15657a60111ef9c0c0a97fbdd614fae0fe2"}, - {file = "websockets-13.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7334752052532c156d28b8eaf3558137e115c7871ea82adff69b6d94a7bee273"}, - {file = 
"websockets-13.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e7a1963302947332c3039e3f66209ec73b1626f8a0191649e0713c391e9f5b0d"}, - {file = "websockets-13.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2e1cf4e1eb84b4fd74a47688e8b0940c89a04ad9f6937afa43d468e71128cd68"}, - {file = "websockets-13.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:c026ee729c4ce55708a14b839ba35086dfae265fc12813b62d34ce33f4980c1c"}, - {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5f9d23fbbf96eefde836d9692670bfc89e2d159f456d499c5efcf6a6281c1af"}, - {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ad684cb7efce227d756bae3e8484f2e56aa128398753b54245efdfbd1108f2c"}, - {file = "websockets-13.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e10b3fbed7be4a59831d3a939900e50fcd34d93716e433d4193a4d0d1d335d"}, - {file = "websockets-13.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d42a818e634f789350cd8fb413a3f5eec1cf0400a53d02062534c41519f5125c"}, - {file = "websockets-13.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e5ba5e9b332267d0f2c33ede390061850f1ac3ee6cd1bdcf4c5ea33ead971966"}, - {file = "websockets-13.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9af457ed593e35f467140d8b61d425495b127744a9d65d45a366f8678449a23"}, - {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcea3eb58c09c3a31cc83b45c06d5907f02ddaf10920aaa6443975310f699b95"}, - {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c210d1460dc8d326ffdef9703c2f83269b7539a1690ad11ae04162bc1878d33d"}, - {file = "websockets-13.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b32f38bc81170fd56d0482d505b556e52bf9078b36819a8ba52624bd6667e39e"}, - {file = "websockets-13.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:81a11a1ddd5320429db47c04d35119c3e674d215173d87aaeb06ae80f6e9031f"}, - {file = "websockets-13.0-py3-none-any.whl", hash = "sha256:dbbac01e80aee253d44c4f098ab3cc17c822518519e869b284cfbb8cd16cc9de"}, - {file = "websockets-13.0.tar.gz", hash = "sha256:b7bf950234a482b7461afdb2ec99eee3548ec4d53f418c7990bb79c620476602"}, + {file = "websockets-13.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1841c9082a3ba4a05ea824cf6d99570a6a2d8849ef0db16e9c826acb28089e8f"}, + {file = "websockets-13.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c5870b4a11b77e4caa3937142b650fbbc0914a3e07a0cf3131f35c0587489c1c"}, + {file = "websockets-13.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f1d3d1f2eb79fe7b0fb02e599b2bf76a7619c79300fc55f0b5e2d382881d4f7f"}, + {file = "websockets-13.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15c7d62ee071fa94a2fc52c2b472fed4af258d43f9030479d9c4a2de885fd543"}, + {file = "websockets-13.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6724b554b70d6195ba19650fef5759ef11346f946c07dbbe390e039bcaa7cc3d"}, + {file = "websockets-13.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a952fa2ae57a42ba7951e6b2605e08a24801a4931b5644dfc68939e041bc7f"}, + {file = 
"websockets-13.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:17118647c0ea14796364299e942c330d72acc4b248e07e639d34b75067b3cdd8"}, + {file = "websockets-13.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:64a11aae1de4c178fa653b07d90f2fb1a2ed31919a5ea2361a38760192e1858b"}, + {file = "websockets-13.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0617fd0b1d14309c7eab6ba5deae8a7179959861846cbc5cb528a7531c249448"}, + {file = "websockets-13.0.1-cp310-cp310-win32.whl", hash = "sha256:11f9976ecbc530248cf162e359a92f37b7b282de88d1d194f2167b5e7ad80ce3"}, + {file = "websockets-13.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:c3c493d0e5141ec055a7d6809a28ac2b88d5b878bb22df8c621ebe79a61123d0"}, + {file = "websockets-13.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:699ba9dd6a926f82a277063603fc8d586b89f4cb128efc353b749b641fcddda7"}, + {file = "websockets-13.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf2fae6d85e5dc384bf846f8243ddaa9197f3a1a70044f59399af001fd1f51d4"}, + {file = "websockets-13.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:52aed6ef21a0f1a2a5e310fb5c42d7555e9c5855476bbd7173c3aa3d8a0302f2"}, + {file = "websockets-13.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8eb2b9a318542153674c6e377eb8cb9ca0fc011c04475110d3477862f15d29f0"}, + {file = "websockets-13.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5df891c86fe68b2c38da55b7aea7095beca105933c697d719f3f45f4220a5e0e"}, + {file = "websockets-13.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac2d146ff30d9dd2fcf917e5d147db037a5c573f0446c564f16f1f94cf87462"}, + {file = "websockets-13.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b8ac5b46fd798bbbf2ac6620e0437c36a202b08e1f827832c4bf050da081b501"}, + {file = "websockets-13.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:46af561eba6f9b0848b2c9d2427086cabadf14e0abdd9fde9d72d447df268418"}, + {file = "websockets-13.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b5a06d7f60bc2fc378a333978470dfc4e1415ee52f5f0fce4f7853eb10c1e9df"}, + {file = "websockets-13.0.1-cp311-cp311-win32.whl", hash = "sha256:556e70e4f69be1082e6ef26dcb70efcd08d1850f5d6c5f4f2bcb4e397e68f01f"}, + {file = "websockets-13.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:67494e95d6565bf395476e9d040037ff69c8b3fa356a886b21d8422ad86ae075"}, + {file = "websockets-13.0.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f9c9e258e3d5efe199ec23903f5da0eeaad58cf6fccb3547b74fd4750e5ac47a"}, + {file = "websockets-13.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6b41a1b3b561f1cba8321fb32987552a024a8f67f0d05f06fcf29f0090a1b956"}, + {file = "websockets-13.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f73e676a46b0fe9426612ce8caeca54c9073191a77c3e9d5c94697aef99296af"}, + {file = "websockets-13.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f613289f4a94142f914aafad6c6c87903de78eae1e140fa769a7385fb232fdf"}, + {file = "websockets-13.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f52504023b1480d458adf496dc1c9e9811df4ba4752f0bc1f89ae92f4f07d0c"}, + {file = "websockets-13.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:139add0f98206cb74109faf3611b7783ceafc928529c62b389917a037d4cfdf4"}, + {file = 
"websockets-13.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:47236c13be337ef36546004ce8c5580f4b1150d9538b27bf8a5ad8edf23ccfab"}, + {file = "websockets-13.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c44ca9ade59b2e376612df34e837013e2b273e6c92d7ed6636d0556b6f4db93d"}, + {file = "websockets-13.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9bbc525f4be3e51b89b2a700f5746c2a6907d2e2ef4513a8daafc98198b92237"}, + {file = "websockets-13.0.1-cp312-cp312-win32.whl", hash = "sha256:3624fd8664f2577cf8de996db3250662e259bfbc870dd8ebdcf5d7c6ac0b5185"}, + {file = "websockets-13.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0513c727fb8adffa6d9bf4a4463b2bade0186cbd8c3604ae5540fae18a90cb99"}, + {file = "websockets-13.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:1ee4cc030a4bdab482a37462dbf3ffb7e09334d01dd37d1063be1136a0d825fa"}, + {file = "websockets-13.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dbb0b697cc0655719522406c059eae233abaa3243821cfdfab1215d02ac10231"}, + {file = "websockets-13.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:acbebec8cb3d4df6e2488fbf34702cbc37fc39ac7abf9449392cefb3305562e9"}, + {file = "websockets-13.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63848cdb6fcc0bf09d4a155464c46c64ffdb5807ede4fb251da2c2692559ce75"}, + {file = "websockets-13.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:872afa52a9f4c414d6955c365b6588bc4401272c629ff8321a55f44e3f62b553"}, + {file = "websockets-13.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05e70fec7c54aad4d71eae8e8cab50525e899791fc389ec6f77b95312e4e9920"}, + {file = "websockets-13.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e82db3756ccb66266504f5a3de05ac6b32f287faacff72462612120074103329"}, + {file = "websockets-13.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4e85f46ce287f5c52438bb3703d86162263afccf034a5ef13dbe4318e98d86e7"}, + {file = "websockets-13.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f3fea72e4e6edb983908f0db373ae0732b275628901d909c382aae3b592589f2"}, + {file = "websockets-13.0.1-cp313-cp313-win32.whl", hash = "sha256:254ecf35572fca01a9f789a1d0f543898e222f7b69ecd7d5381d8d8047627bdb"}, + {file = "websockets-13.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:ca48914cdd9f2ccd94deab5bcb5ac98025a5ddce98881e5cce762854a5de330b"}, + {file = "websockets-13.0.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b74593e9acf18ea5469c3edaa6b27fa7ecf97b30e9dabd5a94c4c940637ab96e"}, + {file = "websockets-13.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:132511bfd42e77d152c919147078460c88a795af16b50e42a0bd14f0ad71ddd2"}, + {file = "websockets-13.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:165bedf13556f985a2aa064309baa01462aa79bf6112fbd068ae38993a0e1f1b"}, + {file = "websockets-13.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e801ca2f448850685417d723ec70298feff3ce4ff687c6f20922c7474b4746ae"}, + {file = "websockets-13.0.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30d3a1f041360f029765d8704eae606781e673e8918e6b2c792e0775de51352f"}, + {file = "websockets-13.0.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67648f5e50231b5a7f6d83b32f9c525e319f0ddc841be0de64f24928cd75a603"}, + {file = 
"websockets-13.0.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:4f0426d51c8f0926a4879390f53c7f5a855e42d68df95fff6032c82c888b5f36"}, + {file = "websockets-13.0.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ef48e4137e8799998a343706531e656fdec6797b80efd029117edacb74b0a10a"}, + {file = "websockets-13.0.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:249aab278810bee585cd0d4de2f08cfd67eed4fc75bde623be163798ed4db2eb"}, + {file = "websockets-13.0.1-cp38-cp38-win32.whl", hash = "sha256:06c0a667e466fcb56a0886d924b5f29a7f0886199102f0a0e1c60a02a3751cb4"}, + {file = "websockets-13.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1f3cf6d6ec1142412d4535adabc6bd72a63f5f148c43fe559f06298bc21953c9"}, + {file = "websockets-13.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1fa082ea38d5de51dd409434edc27c0dcbd5fed2b09b9be982deb6f0508d25bc"}, + {file = "websockets-13.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4a365bcb7be554e6e1f9f3ed64016e67e2fa03d7b027a33e436aecf194febb63"}, + {file = "websockets-13.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:10a0dc7242215d794fb1918f69c6bb235f1f627aaf19e77f05336d147fce7c37"}, + {file = "websockets-13.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59197afd478545b1f73367620407b0083303569c5f2d043afe5363676f2697c9"}, + {file = "websockets-13.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d20516990d8ad557b5abeb48127b8b779b0b7e6771a265fa3e91767596d7d97"}, + {file = "websockets-13.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1a2e272d067030048e1fe41aa1ec8cfbbaabce733b3d634304fa2b19e5c897f"}, + {file = "websockets-13.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ad327ac80ba7ee61da85383ca8822ff808ab5ada0e4a030d66703cc025b021c4"}, + {file = "websockets-13.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:518f90e6dd089d34eaade01101fd8a990921c3ba18ebbe9b0165b46ebff947f0"}, + {file = "websockets-13.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:68264802399aed6fe9652e89761031acc734fc4c653137a5911c2bfa995d6d6d"}, + {file = "websockets-13.0.1-cp39-cp39-win32.whl", hash = "sha256:a5dc0c42ded1557cc7c3f0240b24129aefbad88af4f09346164349391dea8e58"}, + {file = "websockets-13.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:b448a0690ef43db5ef31b3a0d9aea79043882b4632cfc3eaab20105edecf6097"}, + {file = "websockets-13.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:faef9ec6354fe4f9a2c0bbb52fb1ff852effc897e2a4501e25eb3a47cb0a4f89"}, + {file = "websockets-13.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:03d3f9ba172e0a53e37fa4e636b86cc60c3ab2cfee4935e66ed1d7acaa4625ad"}, + {file = "websockets-13.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d450f5a7a35662a9b91a64aefa852f0c0308ee256122f5218a42f1d13577d71e"}, + {file = "websockets-13.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f55b36d17ac50aa8a171b771e15fbe1561217510c8768af3d546f56c7576cdc"}, + {file = "websockets-13.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14b9c006cac63772b31abbcd3e3abb6228233eec966bf062e89e7fa7ae0b7333"}, + {file = "websockets-13.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:b79915a1179a91f6c5f04ece1e592e2e8a6bd245a0e45d12fd56b2b59e559a32"}, + {file = 
"websockets-13.0.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f40de079779acbcdbb6ed4c65af9f018f8b77c5ec4e17a4b737c05c2db554491"}, + {file = "websockets-13.0.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80e4ba642fc87fa532bac07e5ed7e19d56940b6af6a8c61d4429be48718a380f"}, + {file = "websockets-13.0.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a02b0161c43cc9e0232711eff846569fad6ec836a7acab16b3cf97b2344c060"}, + {file = "websockets-13.0.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6aa74a45d4cdc028561a7d6ab3272c8b3018e23723100b12e58be9dfa5a24491"}, + {file = "websockets-13.0.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00fd961943b6c10ee6f0b1130753e50ac5dcd906130dcd77b0003c3ab797d026"}, + {file = "websockets-13.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d93572720d781331fb10d3da9ca1067817d84ad1e7c31466e9f5e59965618096"}, + {file = "websockets-13.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:71e6e5a3a3728886caee9ab8752e8113670936a193284be9d6ad2176a137f376"}, + {file = "websockets-13.0.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c4a6343e3b0714e80da0b0893543bf9a5b5fa71b846ae640e56e9abc6fbc4c83"}, + {file = "websockets-13.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a678532018e435396e37422a95e3ab87f75028ac79570ad11f5bf23cd2a7d8c"}, + {file = "websockets-13.0.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6716c087e4aa0b9260c4e579bb82e068f84faddb9bfba9906cb87726fa2e870"}, + {file = "websockets-13.0.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e33505534f3f673270dd67f81e73550b11de5b538c56fe04435d63c02c3f26b5"}, + {file = "websockets-13.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:acab3539a027a85d568c2573291e864333ec9d912675107d6efceb7e2be5d980"}, + {file = "websockets-13.0.1-py3-none-any.whl", hash = "sha256:b80f0c51681c517604152eb6a572f5a9378f877763231fddb883ba2f968e8817"}, + {file = "websockets-13.0.1.tar.gz", hash = "sha256:4d6ece65099411cfd9a48d13701d7438d9c34f479046b34c50ff60bb8834e43e"}, ] [[package]] @@ -9718,101 +10037,103 @@ files = [ [[package]] name = "yarl" -version = "1.9.4" +version = "1.9.11" description = "Yet another URL library" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, - {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, - {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, - {file = 
"yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, - {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, - {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, - {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, - {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, - {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, - {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, - {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = 
"sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, - {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, - {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, - {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, - {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, - {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, - {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash 
= "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, - {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, - {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, - {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, - {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, - {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, - {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, - {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, - {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, - {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, - {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, - {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, - {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, - {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, - {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, + {file = "yarl-1.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:79e08c691deae6fcac2fdde2e0515ac561dd3630d7c8adf7b1e786e22f1e193b"}, + {file = "yarl-1.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:752f4b5cf93268dc73c2ae994cc6d684b0dad5118bc87fbd965fd5d6dca20f45"}, + {file = "yarl-1.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:441049d3a449fb8756b0535be72c6a1a532938a33e1cf03523076700a5f87a01"}, + {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3dfe17b4aed832c627319da22a33f27f282bd32633d6b145c726d519c89fbaf"}, + {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67abcb7df27952864440c9c85f1c549a4ad94afe44e2655f77d74b0d25895454"}, + {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6de3fa29e76fd1518a80e6af4902c44f3b1b4d7fed28eb06913bba4727443de3"}, + {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fee45b3bd4d8d5786472e056aa1359cc4dc9da68aded95a10cd7929a0ec661fe"}, + {file = "yarl-1.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c59b23886234abeba62087fd97d10fb6b905d9e36e2f3465d1886ce5c0ca30df"}, + {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d93c612b2024ac25a3dc01341fd98fdd19c8c5e2011f3dcd084b3743cba8d756"}, + {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = 
"sha256:4d368e3b9ecd50fa22017a20c49e356471af6ae91c4d788c6e9297e25ddf5a62"}, + {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5b593acd45cdd4cf6664d342ceacedf25cd95263b83b964fddd6c78930ea5211"}, + {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:224f8186c220ff00079e64bf193909829144d4e5174bb58665ef0da8bf6955c4"}, + {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:91c478741d7563a12162f7a2db96c0d23d93b0521563f1f1f0ece46ea1702d33"}, + {file = "yarl-1.9.11-cp310-cp310-win32.whl", hash = "sha256:1cdb8f5bb0534986776a43df84031da7ff04ac0cf87cb22ae8a6368231949c40"}, + {file = "yarl-1.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:498439af143b43a2b2314451ffd0295410aa0dcbdac5ee18fc8633da4670b605"}, + {file = "yarl-1.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9e290de5db4fd4859b4ed57cddfe793fcb218504e65781854a8ac283ab8d5518"}, + {file = "yarl-1.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e5f50a2e26cc2b89186f04c97e0ec0ba107ae41f1262ad16832d46849864f914"}, + {file = "yarl-1.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b4a0e724a28d7447e4d549c8f40779f90e20147e94bf949d490402eee09845c6"}, + {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85333d38a4fa5997fa2ff6fd169be66626d814b34fa35ec669e8c914ca50a097"}, + {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ff184002ee72e4b247240e35d5dce4c2d9a0e81fdbef715dde79ab4718aa541"}, + {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:675004040f847c0284827f44a1fa92d8baf425632cc93e7e0aa38408774b07c1"}, + {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30703a7ade2b53f02e09a30685b70cd54f65ed314a8d9af08670c9a5391af1b"}, + {file = "yarl-1.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7230007ab67d43cf19200ec15bc6b654e6b85c402f545a6fc565d254d34ff754"}, + {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8c2cf0c7ad745e1c6530fe6521dfb19ca43338239dfcc7da165d0ef2332c0882"}, + {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4567cc08f479ad80fb07ed0c9e1bcb363a4f6e3483a490a39d57d1419bf1c4c7"}, + {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:95adc179a02949c4560ef40f8f650a008380766eb253d74232eb9c024747c111"}, + {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:755ae9cff06c429632d750aa8206f08df2e3d422ca67be79567aadbe74ae64cc"}, + {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:94f71d54c5faf715e92c8434b4a0b968c4d1043469954d228fc031d51086f143"}, + {file = "yarl-1.9.11-cp311-cp311-win32.whl", hash = "sha256:4ae079573efeaa54e5978ce86b77f4175cd32f42afcaf9bfb8a0677e91f84e4e"}, + {file = "yarl-1.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:9fae7ec5c9a4fe22abb995804e6ce87067dfaf7e940272b79328ce37c8f22097"}, + {file = "yarl-1.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:614fa50fd0db41b79f426939a413d216cdc7bab8d8c8a25844798d286a999c5a"}, + {file = "yarl-1.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ff64f575d71eacb5a4d6f0696bfe991993d979423ea2241f23ab19ff63f0f9d1"}, + {file = "yarl-1.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c23f6dc3d7126b4c64b80aa186ac2bb65ab104a8372c4454e462fb074197bc6"}, + {file = 
"yarl-1.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8f847cc092c2b85d22e527f91ea83a6cf51533e727e2461557a47a859f96734"}, + {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63a5dc2866791236779d99d7a422611d22bb3a3d50935bafa4e017ea13e51469"}, + {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c335342d482e66254ae94b1231b1532790afb754f89e2e0c646f7f19d09740aa"}, + {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4a8c3dedd081cca134a21179aebe58b6e426e8d1e0202da9d1cafa56e01af3c"}, + {file = "yarl-1.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:504d19320c92532cabc3495fb7ed6bb599f3c2bfb45fed432049bf4693dbd6d0"}, + {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b2a8e5eb18181060197e3d5db7e78f818432725c0759bc1e5a9d603d9246389"}, + {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f568d70b7187f4002b6b500c0996c37674a25ce44b20716faebe5fdb8bd356e7"}, + {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:735b285ea46ca7e86ad261a462a071d0968aade44e1a3ea2b7d4f3d63b5aab12"}, + {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2d1c81c3b92bef0c1c180048e43a5a85754a61b4f69d6f84df8e4bd615bef25d"}, + {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8d6e1c1562b53bd26efd38e886fc13863b8d904d559426777990171020c478a9"}, + {file = "yarl-1.9.11-cp312-cp312-win32.whl", hash = "sha256:aeba4aaa59cb709edb824fa88a27cbbff4e0095aaf77212b652989276c493c00"}, + {file = "yarl-1.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:569309a3efb8369ff5d32edb2a0520ebaf810c3059f11d34477418c90aa878fd"}, + {file = "yarl-1.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4915818ac850c3b0413e953af34398775b7a337babe1e4d15f68c8f5c4872553"}, + {file = "yarl-1.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef9610b2f5a73707d4d8bac040f0115ca848e510e3b1f45ca53e97f609b54130"}, + {file = "yarl-1.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:47c0a3dc8076a8dd159de10628dea04215bc7ddaa46c5775bf96066a0a18f82b"}, + {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:545f2fbfa0c723b446e9298b5beba0999ff82ce2c126110759e8dac29b5deaf4"}, + {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9137975a4ccc163ad5d7a75aad966e6e4e95dedee08d7995eab896a639a0bce2"}, + {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b0c70c451d2a86f8408abced5b7498423e2487543acf6fcf618b03f6e669b0a"}, + {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce2bd986b1e44528677c237b74d59f215c8bfcdf2d69442aa10f62fd6ab2951c"}, + {file = "yarl-1.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d7b717f77846a9631046899c6cc730ea469c0e2fb252ccff1cc119950dbc296"}, + {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3a26a24bbd19241283d601173cea1e5b93dec361a223394e18a1e8e5b0ef20bd"}, + {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c189bf01af155ac9882e128d9f3b3ad68a1f2c2f51404afad7201305df4e12b1"}, + {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = 
"sha256:0cbcc2c54084b2bda4109415631db017cf2960f74f9e8fd1698e1400e4f8aae2"}, + {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:30f201bc65941a4aa59c1236783efe89049ec5549dafc8cd2b63cc179d3767b0"}, + {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:922ba3b74f0958a0b5b9c14ff1ef12714a381760c08018f2b9827632783a590c"}, + {file = "yarl-1.9.11-cp313-cp313-win32.whl", hash = "sha256:17107b4b8c43e66befdcbe543fff2f9c93f7a3a9f8e3a9c9ac42bffeba0e8828"}, + {file = "yarl-1.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:0324506afab4f2e176a93cb08b8abcb8b009e1f324e6cbced999a8f5dd9ddb76"}, + {file = "yarl-1.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4e4f820fde9437bb47297194f43d29086433e6467fa28fe9876366ad357bd7bb"}, + {file = "yarl-1.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dfa9b9d5c9c0dbe69670f5695264452f5e40947590ec3a38cfddc9640ae8ff89"}, + {file = "yarl-1.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e700eb26635ce665c018c8cfea058baff9b843ed0cc77aa61849d807bb82a64c"}, + {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c305c1bdf10869b5e51facf50bd5b15892884aeae81962ae4ba061fc11217103"}, + {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5b7b307140231ea4f7aad5b69355aba2a67f2d7bc34271cffa3c9c324d35b27"}, + {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a744bdeda6c86cf3025c94eb0e01ccabe949cf385cd75b6576a3ac9669404b68"}, + {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e8ed183c7a8f75e40068333fc185566472a8f6c77a750cf7541e11810576ea5"}, + {file = "yarl-1.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1db9a4384694b5d20bdd9cb53f033b0831ac816416ab176c8d0997835015d22"}, + {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:70194da6e99713250aa3f335a7fa246b36adf53672a2bcd0ddaa375d04e53dc0"}, + {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ddad5cfcda729e22422bb1c85520bdf2770ce6d975600573ac9017fe882f4b7e"}, + {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ca35996e0a4bed28fa0640d9512d37952f6b50dea583bcc167d4f0b1e112ac7f"}, + {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:61ec0e80970b21a8f3c4b97fa6c6d181c6c6a135dbc7b4a601a78add3feeb209"}, + {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9636e4519f6c7558fdccf8f91e6e3b98df2340dc505c4cc3286986d33f2096c2"}, + {file = "yarl-1.9.11-cp38-cp38-win32.whl", hash = "sha256:58081cea14b8feda57c7ce447520e9d0a96c4d010cce54373d789c13242d7083"}, + {file = "yarl-1.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:7d2dee7d6485807c0f64dd5eab9262b7c0b34f760e502243dd83ec09d647d5e1"}, + {file = "yarl-1.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d65ad67f981e93ea11f87815f67d086c4f33da4800cf2106d650dd8a0b79dda4"}, + {file = "yarl-1.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:752c0d33b4aacdb147871d0754b88f53922c6dc2aff033096516b3d5f0c02a0f"}, + {file = "yarl-1.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:54cc24be98d7f4ff355ca2e725a577e19909788c0db6beead67a0dda70bd3f82"}, + {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c82126817492bb2ebc946e74af1ffa10aacaca81bee360858477f96124be39a"}, + {file = 
"yarl-1.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8503989860d7ac10c85cb5b607fec003a45049cf7a5b4b72451e87893c6bb990"}, + {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:475e09a67f8b09720192a170ad9021b7abf7827ffd4f3a83826317a705be06b7"}, + {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afcac5bda602b74ff701e1f683feccd8cce0d5a21dbc68db81bf9bd8fd93ba56"}, + {file = "yarl-1.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaeffcb84faceb2923a94a8a9aaa972745d3c728ab54dd011530cc30a3d5d0c1"}, + {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:51a6f770ac86477cd5c553f88a77a06fe1f6f3b643b053fcc7902ab55d6cbe14"}, + {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3fcd056cb7dff3aea5b1ee1b425b0fbaa2fbf6a1c6003e88caf524f01de5f395"}, + {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:21e56c30e39a1833e4e3fd0112dde98c2abcbc4c39b077e6105c76bb63d2aa04"}, + {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:0a205ec6349879f5e75dddfb63e069a24f726df5330b92ce76c4752a436aac01"}, + {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a5706821e1cf3c70dfea223e4e0958ea354f4e2af9420a1bd45c6b547297fb97"}, + {file = "yarl-1.9.11-cp39-cp39-win32.whl", hash = "sha256:cc295969f8c2172b5d013c0871dccfec7a0e1186cf961e7ea575d47b4d5cbd32"}, + {file = "yarl-1.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:55a67dd29367ce7c08a0541bb602ec0a2c10d46c86b94830a1a665f7fd093dfa"}, + {file = "yarl-1.9.11-py3-none-any.whl", hash = "sha256:c6f6c87665a9e18a635f0545ea541d9640617832af2317d4f5ad389686b4ed3d"}, + {file = "yarl-1.9.11.tar.gz", hash = "sha256:c7548a90cb72b67652e2cd6ae80e2683ee08fde663104528ac7df12d8ef271d2"}, ] [package.dependencies] @@ -9821,13 +10142,13 @@ multidict = ">=4.0" [[package]] name = "yfinance" -version = "0.2.41" +version = "0.2.43" description = "Download market data from Yahoo! 
Finance API" optional = false python-versions = "*" files = [ - {file = "yfinance-0.2.41-py2.py3-none-any.whl", hash = "sha256:2ed7b453cb8568773eb2dbb4d87cc37ff02e5d133f7723ec3e219ab0b86b56d8"}, - {file = "yfinance-0.2.41.tar.gz", hash = "sha256:f94409a1ed4d596b9da8d2dbb498faaabfcf593d5870e1412e17669a212bb345"}, + {file = "yfinance-0.2.43-py2.py3-none-any.whl", hash = "sha256:11b4f5515b17450bd3bdcdc26b299aeeaea7ff9cb63d0fa0a865f460c0c7618f"}, + {file = "yfinance-0.2.43.tar.gz", hash = "sha256:32404597f325a2a2c2708aceb8d552088dd26891ac0e6018f6c5f3f2f61055f0"}, ] [package.dependencies] @@ -9866,18 +10187,22 @@ requests = "*" [[package]] name = "zipp" -version = "3.20.0" +version = "3.20.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.20.0-py3-none-any.whl", hash = "sha256:58da6168be89f0be59beb194da1250516fdaa062ccebd30127ac65d30045e10d"}, - {file = "zipp-3.20.0.tar.gz", hash = "sha256:0145e43d89664cfe1a2e533adc75adafed82fe2da404b4bbb6b026c0157bdb31"}, + {file = "zipp-3.20.1-py3-none-any.whl", hash = "sha256:9960cd8967c8f85a56f920d5d507274e74f9ff813a0ab8889a5b5be2daf44064"}, + {file = "zipp-3.20.1.tar.gz", hash = "sha256:c22b14cc4763c5a5b04134207736c107db42e9d3ef2d9779d465f5f1bcba572b"}, ] [package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] [[package]] name = "zope-event" @@ -9899,45 +10224,45 @@ test = ["zope.testrunner"] [[package]] name = "zope-interface" -version = "7.0.1" +version = "7.0.3" description = "Interfaces for Python" optional = false python-versions = ">=3.8" files = [ - {file = "zope.interface-7.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ec4e87e6fdc511a535254daa122c20e11959ce043b4e3425494b237692a34f1c"}, - {file = "zope.interface-7.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:51d5713e8e38f2d3ec26e0dfdca398ed0c20abda2eb49ffc15a15a23eb8e5f6d"}, - {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8d51e5eb29e57d34744369cd08267637aa5a0fefc9b5d33775ab7ff2ebf2e3"}, - {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:55bbcc74dc0c7ab489c315c28b61d7a1d03cf938cc99cc58092eb065f120c3a5"}, - {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10ebac566dd0cec66f942dc759d46a994a2b3ba7179420f0e2130f88f8a5f400"}, - {file = "zope.interface-7.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7039e624bcb820f77cc2ff3d1adcce531932990eee16121077eb51d9c76b6c14"}, - {file = "zope.interface-7.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03bd5c0db82237bbc47833a8b25f1cc090646e212f86b601903d79d7e6b37031"}, - {file = "zope.interface-7.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:3f52050c6a10d4a039ec6f2c58e5b3ade5cc570d16cf9d102711e6b8413c90e6"}, - {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af0b33f04677b57843d529b9257a475d2865403300b48c67654c40abac2f9f24"}, - {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:696c2a381fc7876b3056711717dba5eddd07c2c9e5ccd50da54029a1293b6e43"}, - {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f89a420cf5a6f2aa7849dd59e1ff0e477f562d97cf8d6a1ee03461e1eec39887"}, - {file = "zope.interface-7.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:b59deb0ddc7b431e41d720c00f99d68b52cb9bd1d5605a085dc18f502fe9c47f"}, - {file = "zope.interface-7.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52f5253cca1b35eaeefa51abd366b87f48f8714097c99b131ba61f3fdbbb58e7"}, - {file = "zope.interface-7.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88d108d004e0df25224de77ce349a7e73494ea2cb194031f7c9687e68a88ec9b"}, - {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c203d82069ba31e1f3bc7ba530b2461ec86366cd4bfc9b95ec6ce58b1b559c34"}, - {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f3495462bc0438b76536a0e10d765b168ae636092082531b88340dc40dcd118"}, - {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192b7a792e3145ed880ff6b1a206fdb783697cfdb4915083bfca7065ec845e60"}, - {file = "zope.interface-7.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:400d06c9ec8dbcc96f56e79376297e7be07a315605c9a2208720da263d44d76f"}, - {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c1dff87b30fd150c61367d0e2cdc49bb55f8b9fd2a303560bbc24b951573ae1"}, - {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f749ca804648d00eda62fe1098f229b082dfca930d8bad8386e572a6eafa7525"}, - {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ec212037becf6d2f705b7ed4538d56980b1e7bba237df0d8995cbbed29961dc"}, - {file = "zope.interface-7.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d33cb526efdc235a2531433fc1287fcb80d807d5b401f9b801b78bf22df560dd"}, - {file = "zope.interface-7.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b419f2144e1762ab845f20316f1df36b15431f2622ebae8a6d5f7e8e712b413c"}, - {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f1452d5d1f279184d5bdb663a3dc39902d9320eceb63276240791e849054b6"}, - {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ba4b3638d014918b918aa90a9c8370bd74a03abf8fcf9deb353b3a461a59a84"}, - {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc0615351221926a36a0fbcb2520fb52e0b23e8c22a43754d9cb8f21358c33c0"}, - {file = "zope.interface-7.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:ce6cbb852fb8f2f9bb7b9cdca44e2e37bce783b5f4c167ff82cb5f5128163c8f"}, - {file = 
"zope.interface-7.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5566fd9271c89ad03d81b0831c37d46ae5e2ed211122c998637130159a120cf1"}, - {file = "zope.interface-7.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da0cef4d7e3f19c3bd1d71658d6900321af0492fee36ec01b550a10924cffb9c"}, - {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32ca483e6ade23c7caaee9d5ee5d550cf4146e9b68d2fb6c68bac183aa41c37"}, - {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da21e7eec49252df34d426c2ee9cf0361c923026d37c24728b0fa4cc0599fd03"}, - {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a8195b99e650e6f329ce4e5eb22d448bdfef0406404080812bc96e2a05674cb"}, - {file = "zope.interface-7.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:19c829d52e921b9fe0b2c0c6a8f9a2508c49678ee1be598f87d143335b6a35dc"}, - {file = "zope.interface-7.0.1.tar.gz", hash = "sha256:f0f5fda7cbf890371a59ab1d06512da4f2c89a6ea194e595808123c863c38eff"}, + {file = "zope.interface-7.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9b9369671a20b8d039b8e5a1a33abd12e089e319a3383b4cc0bf5c67bd05fe7b"}, + {file = "zope.interface-7.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db6237e8fa91ea4f34d7e2d16d74741187e9105a63bbb5686c61fea04cdbacca"}, + {file = "zope.interface-7.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53d678bb1c3b784edbfb0adeebfeea6bf479f54da082854406a8f295d36f8386"}, + {file = "zope.interface-7.0.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3aa8fcbb0d3c2be1bfd013a0f0acd636f6ed570c287743ae2bbd467ee967154d"}, + {file = "zope.interface-7.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6195c3c03fef9f87c0dbee0b3b6451df6e056322463cf35bca9a088e564a3c58"}, + {file = "zope.interface-7.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:11fa1382c3efb34abf16becff8cb214b0b2e3144057c90611621f2d186b7e1b7"}, + {file = "zope.interface-7.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:af94e429f9d57b36e71ef4e6865182090648aada0cb2d397ae2b3f7fc478493a"}, + {file = "zope.interface-7.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dd647fcd765030638577fe6984284e0ebba1a1008244c8a38824be096e37fe3"}, + {file = "zope.interface-7.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bee1b722077d08721005e8da493ef3adf0b7908e0cd85cc7dc836ac117d6f32"}, + {file = "zope.interface-7.0.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2545d6d7aac425d528cd9bf0d9e55fcd47ab7fd15f41a64b1c4bf4c6b24946dc"}, + {file = "zope.interface-7.0.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d04b11ea47c9c369d66340dbe51e9031df2a0de97d68f442305ed7625ad6493"}, + {file = "zope.interface-7.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:064ade95cb54c840647205987c7b557f75d2b2f7d1a84bfab4cf81822ef6e7d1"}, + {file = "zope.interface-7.0.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3fcdc76d0cde1c09c37b7c6b0f8beba2d857d8417b055d4f47df9c34ec518bdd"}, + {file = "zope.interface-7.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3d4b91821305c8d8f6e6207639abcbdaf186db682e521af7855d0bea3047c8ca"}, + {file = 
"zope.interface-7.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35062d93bc49bd9b191331c897a96155ffdad10744ab812485b6bad5b588d7e4"}, + {file = "zope.interface-7.0.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c96b3e6b0d4f6ddfec4e947130ec30bd2c7b19db6aa633777e46c8eecf1d6afd"}, + {file = "zope.interface-7.0.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0c151a6c204f3830237c59ee4770cc346868a7a1af6925e5e38650141a7f05"}, + {file = "zope.interface-7.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:3de1d553ce72868b77a7e9d598c9bff6d3816ad2b4cc81c04f9d8914603814f3"}, + {file = "zope.interface-7.0.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab985c566a99cc5f73bc2741d93f1ed24a2cc9da3890144d37b9582965aff996"}, + {file = "zope.interface-7.0.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d976fa7b5faf5396eb18ce6c132c98e05504b52b60784e3401f4ef0b2e66709b"}, + {file = "zope.interface-7.0.3-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a207c6b2c58def5011768140861a73f5240f4f39800625072ba84e76c9da0b"}, + {file = "zope.interface-7.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:382d31d1e68877061daaa6499468e9eb38eb7625d4369b1615ac08d3860fe896"}, + {file = "zope.interface-7.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c4316a30e216f51acbd9fb318aa5af2e362b716596d82cbb92f9101c8f8d2e7"}, + {file = "zope.interface-7.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01e6e58078ad2799130c14a1d34ec89044ada0e1495329d72ee0407b9ae5100d"}, + {file = "zope.interface-7.0.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:799ef7a444aebbad5a145c3b34bff012b54453cddbde3332d47ca07225792ea4"}, + {file = "zope.interface-7.0.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3b7ce6d46fb0e60897d62d1ff370790ce50a57d40a651db91a3dde74f73b738"}, + {file = "zope.interface-7.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:f418c88f09c3ba159b95a9d1cfcdbe58f208443abb1f3109f4b9b12fd60b187c"}, + {file = "zope.interface-7.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:84f8794bd59ca7d09d8fce43ae1b571be22f52748169d01a13d3ece8394d8b5b"}, + {file = "zope.interface-7.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7d92920416f31786bc1b2f34cc4fc4263a35a407425319572cbf96b51e835cd3"}, + {file = "zope.interface-7.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e5913ec718010dc0e7c215d79a9683b4990e7026828eedfda5268e74e73e11"}, + {file = "zope.interface-7.0.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1eeeb92cb7d95c45e726e3c1afe7707919370addae7ed14f614e22217a536958"}, + {file = "zope.interface-7.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecd32f30f40bfd8511b17666895831a51b532e93fc106bfa97f366589d3e4e0e"}, + {file = "zope.interface-7.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:5112c530fa8aa2108a3196b9c2f078f5738c1c37cfc716970edc0df0414acda8"}, + {file = "zope.interface-7.0.3.tar.gz", hash = "sha256:cd2690d4b08ec9eaf47a85914fe513062b20da78d10d6d789a792c0b20307fb1"}, ] [package.dependencies] @@ 
-10063,4 +10388,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "e4c00268514d26bd07c6b72925e0e3b4558ec972895d252e60e9571e3ac38895" +content-hash = "2dbff415c3c9ca95c8dcfb59fc088ce2c0d00037c44f386a34c87c98e1d8b942" diff --git a/api/pyproject.toml b/api/pyproject.toml index f1d5e213ae..23e2b5c549 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,8 +6,6 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] -exclude = [ -] line-length = 120 [tool.ruff.lint] @@ -15,17 +13,11 @@ preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions + "E", # pycodestyle E rules "F", # pyflakes rules "I", # isort rules + "N", # pep8-naming "UP", # pyupgrade rules - "B035", # static-key-dict-comprehension - "E101", # mixed-spaces-and-tabs - "E111", # indentation-with-invalid-multiple - "E112", # no-indented-block - "E113", # unexpected-indentation - "E115", # no-indented-block-comment - "E116", # unexpected-indentation-comment - "E117", # over-indented "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa @@ -35,10 +27,15 @@ select = [ "SIM910", # dict-get-with-none-default "W191", # tab-indentation "W605", # invalid-escape-sequence - "F601", # multi-value-repeated-key-literal - "F602", # multi-value-repeated-key-variable ] ignore = [ + "E501", # line-too-long + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "E731", # lambda-assignment "F403", # undefined-local-with-import-star "F405", # undefined-local-with-import-star-usage "F821", # undefined-name @@ -49,9 +46,10 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg -# "B901", # return-in-generator "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope ] [tool.ruff.lint.per-file-ignores] @@ -67,11 +65,15 @@ ignore = [ "F401", # unused-import "F811", # redefined-while-unused ] +"configs/*" = [ + "N802", # invalid-function-name +] +"libs/gmpy2_pkcs10aep_cipher.py" = [ + "N803", # invalid-argument-name +] [tool.ruff.format] exclude = [ - "core/**/*.py", - "models/**/*.py", "migrations/**/*", ] @@ -115,12 +117,14 @@ azure-identity = "1.16.1" azure-storage-blob = "12.13.0" beautifulsoup4 = "4.12.2" boto3 = "1.34.148" +sagemaker = "2.231.0" bs4 = "~0.0.1" cachetools = "~5.3.0" celery = "~5.3.6" chardet = "~5.1.0" cohere = "~5.2.4" cos-python-sdk-v5 = "1.9.30" +esdk-obs-python = "3.24.6.1" dashscope = { version = "~1.17.0", extras = ["tokenizer"] } flask = "~3.0.1" flask-compress = "~1.14" @@ -191,6 +195,8 @@ zhipuai = "1.0.7" azure-ai-ml = "^1.19.0" azure-ai-inference = "^1.0.0b3" volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"} +oci = "^2.133.0" +tos = "^2.7.1" [tool.poetry.group.indriect.dependencies] kaleido = "0.2.1" rank-bm25 = "~0.2.2" diff --git a/api/services/account_service.py b/api/services/account_service.py index d291e4a6c7..db9e2ed1ad 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -35,7 +35,7 @@ from services.errors.account import ( NoPermissionError, RateLimitExceededError, RoleAlreadyAssignedError, - TenantNotFound, + TenantNotFoundError, ) from services.errors.workspace import WorkSpaceNotAllowedCreateError from 
tasks.mail_email_code_login import send_email_code_login_mail_task @@ -345,7 +345,10 @@ class TenantService: if available_ta: return - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + if name: + tenant = TenantService.create_tenant(name) + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() @@ -379,13 +382,13 @@ class TenantService: """Get tenant by account and add the role""" tenant = account.current_tenant if not tenant: - raise TenantNotFound("Tenant not found.") + raise TenantNotFoundError("Tenant not found.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role else: - raise TenantNotFound("Tenant not found for the account.") + raise TenantNotFoundError("Tenant not found for the account.") return tenant @staticmethod @@ -686,8 +689,8 @@ class RegisterService: "email": account.email, "workspace_id": tenant.id, } - expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) + expiry_hours = dify_config.INVITE_EXPIRY_HOURS + redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) return token @classmethod diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index a2aa15ed4b..73c446b83b 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -13,8 +13,9 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) -current_dsl_version = "0.1.1" +current_dsl_version = "0.1.2" dsl_to_dify_version_mapping: dict[str, str] = { + "0.1.2": "0.8.0", "0.1.1": "0.6.0", # dsl version -> from dify version } @@ -87,6 +88,7 @@ class AppDslService: icon_background = ( args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") ) + use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) @@ -101,6 +103,7 @@ class AppDslService: icon_type=icon_type, icon=icon, icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: app = cls._import_and_create_new_model_config_based_app( @@ -113,6 +116,7 @@ class AppDslService: icon_type=icon_type, icon=icon, icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) else: raise ValueError("Invalid app mode") @@ -171,6 +175,7 @@ class AppDslService: "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, "description": app_model.description, + "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, }, } @@ -218,6 +223,7 @@ class AppDslService: icon_type: str, icon: str, icon_background: str, + use_icon_as_answer_icon: bool, ) -> App: """ Import app dsl and create new workflow based app @@ -231,6 +237,7 @@ class AppDslService: :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background + :param use_icon_as_answer_icon: use app icon as answer icon """ if not workflow_data: raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") @@ -244,6 +251,7 @@ 
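
The renamed expiry_hours variable above drives the Redis TTL for invitation tokens via setex, which takes its lifetime in seconds. A small sketch of the same pattern, assuming a redis-py client and a hypothetical key prefix (the real key format comes from _get_invitation_token_key, not shown here):

```python
import json

import redis

r = redis.Redis()  # assumes a local Redis instance

def store_invitation(token: str, data: dict, expiry_hours: int) -> None:
    # setex takes the TTL in seconds, hence the hours-to-seconds conversion.
    r.setex(f"member_invite:token:{token}", expiry_hours * 60 * 60, json.dumps(data))
```
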
class AppDslService: icon_type=icon_type, icon=icon, icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) # init draft workflow @@ -316,6 +324,7 @@ class AppDslService: icon_type: str, icon: str, icon_background: str, + use_icon_as_answer_icon: bool, ) -> App: """ Import app dsl and create new model config based app @@ -341,6 +350,7 @@ class AppDslService: icon_type=icon_type, icon=icon, icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, ) app_model_config = AppModelConfig() @@ -369,6 +379,7 @@ class AppDslService: icon_type: str, icon: str, icon_background: str, + use_icon_as_answer_icon: bool, ) -> App: """ Create new app @@ -381,6 +392,7 @@ class AppDslService: :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background + :param use_icon_as_answer_icon: use app icon as answer icon """ app = App( tenant_id=tenant_id, @@ -392,6 +404,7 @@ class AppDslService: icon_background=icon_background, enable_site=True, enable_api=True, + use_icon_as_answer_icon=use_icon_as_answer_icon, created_by=account.id, updated_by=account.id, ) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 747505977f..26517a05fb 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -12,6 +12,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService @@ -103,9 +104,7 @@ class AppGenerateService: return max_active_requests @classmethod - def generate_single_iteration( - cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True - ): + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( @@ -142,7 +141,7 @@ class AppGenerateService: ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: """ Get workflow :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index 462613fb7d..1dacfea246 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -221,6 +221,7 @@ class AppService: app.icon_type = args.get("icon_type", "emoji") app.icon = args.get("icon") app.icon_background = args.get("icon_background") + app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8649d0fea5..cce0874cf4 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -15,7 +15,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.models.document import Document as RAGDocument -from 
core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -136,7 +136,9 @@ class DatasetService: return datasets.items, datasets.total @staticmethod - def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account): + def create_empty_dataset( + tenant_id: str, name: str, indexing_technique: Optional[str], account: Account, permission: Optional[str] = None + ): # check if dataset name already exists if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") @@ -153,6 +155,7 @@ class DatasetService: dataset.tenant_id = tenant_id dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME db.session.add(dataset) db.session.commit() return dataset @@ -1051,16 +1054,11 @@ class DocumentService: DocumentService.check_documents_upload_quota(count, features) - embedding_model = None dataset_collection_binding_id = None retrieval_model = None if document_data["indexing_technique"] == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + document_data["embedding_model_provider"], document_data["embedding_model"] ) dataset_collection_binding_id = dataset_collection_binding.id if document_data.get("retrieval_model"): @@ -1079,10 +1077,10 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data["indexing_technique"], + indexing_technique=document_data.get("indexing_technique", "high_quality"), created_by=account.id, - embedding_model=embedding_model.model if embedding_model else None, - embedding_model_provider=embedding_model.provider if embedding_model else None, + embedding_model=document_data.get("embedding_model"), + embedding_model_provider=document_data.get("embedding_model_provider"), collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model, ) diff --git a/api/services/errors/account.py b/api/services/errors/account.py index ac1551716d..5aca12ffeb 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -1,7 +1,7 @@ from services.errors.base import BaseServiceError -class AccountNotFound(BaseServiceError): +class AccountNotFoundError(BaseServiceError): pass @@ -29,7 +29,7 @@ class LinkAccountIntegrateError(BaseServiceError): pass -class TenantNotFound(BaseServiceError): +class TenantNotFoundError(BaseServiceError): pass diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index db99064814..2f911f5036 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,7 +3,7 @@ import time from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document -from core.rag.retrieval.retrival_methods import RetrievalMethod +from 
core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment @@ -36,7 +36,7 @@ class HitTestingService: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model all_documents = RetrievalService.retrieve( - retrival_method=retrieval_model.get("search_method", "semantic_search"), + retrieval_method=retrieval_model.get("search_method", "semantic_search"), dataset_id=dataset.id, query=cls.escape_query_for_search(query), top_k=retrieval_model.get("top_k", 2), diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 35aa6817e1..1e7935d299 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -32,7 +32,15 @@ class OpsService: "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") ): project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) - new_decrypt_tracing_config.update({"project_key": project_key}) + new_decrypt_tracing_config.update( + {"project_url": "{host}/project/{key}".format(host=decrypt_tracing_config.get("host"), key=project_key)} + ) + + if tracing_provider == "langsmith" and ( + "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") + ): + project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @@ -62,8 +70,14 @@ class OpsService: if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider): return {"error": "Invalid Credentials"} - # get project key - project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) + # get project url + if tracing_provider == "langfuse": + project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) + project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key) + elif tracing_provider == "langsmith": + project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) + else: + project_url = None # check if trace config already exists trace_config_data: TraceAppConfig = ( @@ -78,8 +92,8 @@ class OpsService: # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) - if tracing_provider == "langfuse" and project_key: - tracing_config["project_key"] = project_key + if project_url: + tracing_config["project_url"] = project_url trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4c3ded14ad..357ffd41c1 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -8,9 +8,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunResult, NodeType from 
core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.node_mapping import node_classes +from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account @@ -172,8 +174,13 @@ class WorkflowService: Get default block configs """ # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_configs() + default_block_configs = [] + for node_type, node_class in node_classes.items(): + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: """ @@ -182,11 +189,18 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType.value_of(node_type) # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_config(node_type, filters) + node_class = node_classes.get(node_type_enum) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config def run_draft_workflow_node( self, app_model: App, node_id: str, user_inputs: dict, account: Account @@ -200,82 +214,68 @@ class WorkflowService: raise ValueError("Workflow not initialized") # run draft workflow node - workflow_engine_manager = WorkflowEngineManager() start_at = time.perf_counter() try: - node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + node_instance, generator = WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, ) + + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") + + run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=e.node_id, - node_type=e.node_type.value, - title=e.node_title, - status=WorkflowNodeExecutionStatus.FAILED.value, - error=e.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - ) - db.session.add(workflow_node_execution) - db.session.commit() + node_instance = e.node_instance + run_succeeded = False + node_run_result = None + error = e.error - return workflow_node_execution + workflow_node_execution = 
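
The loop above replaces the old single-call API: WorkflowEntry.single_step_run now yields a stream of events, and the caller pulls the terminal result out of that stream by breaking on the completion event. The shape of the pattern as a standalone sketch, with stand-in event types (the names below are illustrative, not the real classes):

```python
from collections.abc import Iterator
from dataclasses import dataclass

@dataclass
class RunCompleted:  # stand-in for RunCompletedEvent
    run_result: dict

def first_completion(events: Iterator[object]) -> dict | None:
    """Scan an event stream and return the payload of the first completion event."""
    for event in events:
        if isinstance(event, RunCompleted):
            return event.run_result
    return None

events = iter(["progress", RunCompleted(run_result={"status": "succeeded"})])
assert first_completion(events) == {"status": "succeeded"}
```
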
WorkflowNodeExecution() + workflow_node_execution.tenant_id = app_model.tenant_id + workflow_node_execution.app_id = app_model.id + workflow_node_execution.workflow_id = draft_workflow.id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value + workflow_node_execution.index = 1 + workflow_node_execution.node_id = node_id + workflow_node_execution.node_type = node_instance.node_type.value + workflow_node_execution.title = node_instance.node_data.title + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by = account.id + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_succeeded and node_run_result: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=( - json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None - ), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), + workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None + workflow_node_execution.process_data = ( + json.dumps(node_run_result.process_data) if node_run_result.process_data else None ) + workflow_node_execution.outputs = ( + json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None + ) + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value else: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - status=node_run_result.status.value, - error=node_run_result.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error db.session.add(workflow_node_execution) db.session.commit() @@ -321,25 +321,3 @@ class WorkflowService: ) else: raise 
ValueError(f"Invalid app mode: {app_model.mode}") - - @classmethod - def get_elapsed_time(cls, workflow_run_id: str) -> float: - """ - Get elapsed time - """ - elapsed_time = 0.0 - - # fetch workflow node execution by workflow_run_id - workflow_nodes = ( - db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id) - .order_by(WorkflowNodeExecution.created_at.asc()) - .all() - ) - if not workflow_nodes: - return elapsed_time - - for node in workflow_nodes: - elapsed_time += node.elapsed_time - - return elapsed_time diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 9ea4c99649..6dd755ab03 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logging.info( click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e0da5f9ed0..72c4674e0f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 6e681bcf4f..cb38bc668d 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except 
DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 0a7568c385..f4c3dbd2e2 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 4ef6e29994..c7dfb9bf60 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -19,7 +19,7 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam :param inviter_name :param workspace_name - Usage: send_invite_member_mail_task.delay(langauge, to, token, inviter_name, workspace_name) + Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name) """ if not mail.is_inited(): return diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 18bae14ffa..21ea11d4dd 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Document @@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logging.info( click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 99fb66e1f3..1d2a338c83 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -20,7 +20,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): :param dataset_id: :param document_id: - Usage: sunc_website_document_indexing_task.delay(dataset_id, document_id) + Usage: sync_website_document_indexing_task.delay(dataset_id, document_id) """ start_at = time.perf_counter() diff --git a/api/tests/integration_tests/model_runtime/__mock/fishaudio.py b/api/tests/integration_tests/model_runtime/__mock/fishaudio.py new file mode 100644 index 0000000000..bec3babeaf --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/fishaudio.py @@ -0,0 +1,82 @@ 
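
The corrected Usage docstrings above show how these Celery tasks are enqueued. For reference, .delay is shorthand for apply_async with positional arguments; a minimal sketch assuming a configured Celery app and an assumed broker URL:

```python
from celery import Celery

app = Celery("tasks", broker="redis://localhost:6379/0")  # assumed broker URL

@app.task
def sync_website_document_indexing_task(dataset_id: str, document_id: str) -> None:
    ...

# The two calls below are equivalent ways to enqueue the task.
sync_website_document_indexing_task.delay("dataset-1", "doc-1")
sync_website_document_indexing_task.apply_async(args=("dataset-1", "doc-1"))
```
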
+import os +from collections.abc import Callable +from typing import Literal + +import httpx +import pytest +from _pytest.monkeypatch import MonkeyPatch + + +def mock_get(*args, **kwargs): + if kwargs.get("headers", {}).get("Authorization") != "Bearer test": + raise httpx.HTTPStatusError( + "Invalid API key", + request=httpx.Request("GET", ""), + response=httpx.Response(401), + ) + + return httpx.Response( + 200, + json={ + "items": [ + {"title": "Model 1", "_id": "model1"}, + {"title": "Model 2", "_id": "model2"}, + ] + }, + request=httpx.Request("GET", ""), + ) + + +def mock_stream(*args, **kwargs): + class MockStreamResponse: + def __init__(self): + self.status_code = 200 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def iter_bytes(self): + yield b"Mocked audio data" + + return MockStreamResponse() + + +def mock_fishaudio( + monkeypatch: MonkeyPatch, + methods: list[Literal["list-models", "tts"]], +) -> Callable[[], None]: + """ + mock fishaudio module + + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function + """ + + def unpatch() -> None: + monkeypatch.undo() + + if "list-models" in methods: + monkeypatch.setattr(httpx, "get", mock_get) + + if "tts" in methods: + monkeypatch.setattr(httpx, "stream", mock_stream) + + return unpatch + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_fishaudio_mock(request, monkeypatch): + methods = request.param if hasattr(request, "param") else [] + if MOCK: + unpatch = mock_fishaudio(monkeypatch, methods=methods) + + yield + + if MOCK: + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index b37b109eba..83317e59de 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -70,6 +70,7 @@ class MockTEIClass: }, } + @staticmethod def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: # Example response: # [ diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index ad58610287..fe7fe96891 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -8,11 +8,11 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLarguageModel +from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLanguageModel def test_predefined_models(): - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() model_schemas = model.predefined_models() assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) @@ -20,7 +20,7 @@ def test_predefined_models(): def test_validate_credentials_for_chat_model(): sleep(3) - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( @@ -38,7 +38,7 @@ def test_validate_credentials_for_chat_model(): def test_invoke_model(): 
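
The new fishaudio mock above swaps httpx.get and httpx.stream at module level through pytest's monkeypatch, so any code that imports httpx transparently hits the fakes, and monkeypatch.undo restores the originals. The core of that technique in isolation, as a sketch:

```python
import httpx

def mock_get(*args, **kwargs):
    # Fabricated response; no network access happens.
    return httpx.Response(200, json={"items": []}, request=httpx.Request("GET", "http://mocked"))

def test_listing(monkeypatch):
    monkeypatch.setattr(httpx, "get", mock_get)
    resp = httpx.get("https://api.fish.audio/model")  # served by the fake above
    assert resp.json() == {"items": []}
```
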
sleep(3) - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() response = model.invoke( model="baichuan2-turbo", @@ -64,7 +64,7 @@ def test_invoke_model(): def test_invoke_model_with_system_message(): sleep(3) - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() response = model.invoke( model="baichuan2-turbo", @@ -93,7 +93,7 @@ def test_invoke_model_with_system_message(): def test_invoke_stream_model(): sleep(3) - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() response = model.invoke( model="baichuan2-turbo", @@ -122,7 +122,7 @@ def test_invoke_stream_model(): def test_invoke_with_search(): sleep(3) - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() response = model.invoke( model="baichuan2-turbo", @@ -156,7 +156,7 @@ def test_invoke_with_search(): def test_get_num_tokens(): sleep(3) - model = BaichuanLarguageModel() + model = BaichuanLanguageModel() response = model.get_num_tokens( model="baichuan2-turbo", diff --git a/api/tests/integration_tests/model_runtime/fishaudio/__init__.py b/api/tests/integration_tests/model_runtime/fishaudio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py b/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py new file mode 100644 index 0000000000..3526574b61 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py @@ -0,0 +1,33 @@ +import os + +import httpx +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider +from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock + + +@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True) +def test_validate_provider_credentials(setup_fishaudio_mock): + print("-----", httpx.get) + provider = FishAudioProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={ + "api_key": "bad_api_key", + "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"), + "use_public_models": "false", + "latency": "normal", + } + ) + + provider.validate_provider_credentials( + credentials={ + "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"), + "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"), + "use_public_models": "false", + "latency": "normal", + } + ) diff --git a/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py b/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py new file mode 100644 index 0000000000..f61fee28b9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py @@ -0,0 +1,32 @@ +import os + +import pytest + +from core.model_runtime.model_providers.fishaudio.tts.tts import ( + FishAudioText2SpeechModel, +) +from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock + + +@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True) +def test_invoke_model(setup_fishaudio_mock): + model = FishAudioText2SpeechModel() + + result = model.invoke( + model="tts-default", + tenant_id="test", + credentials={ + "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"), + "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"), + "use_public_models": "false", + "latency": "normal", + }, + content_text="Hello, world!", + 
voice="03397b4c4be74759b72533b663fbd001", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b"" diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 4d9d490a87..34d08f270a 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -155,7 +155,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): user="abc-123", ) - print(f"resultz: {result.message.content}") + print(f"result: {result.message.content}") assert isinstance(result, LLMResult) assert len(result.message.content) > 0 diff --git a/api/tests/integration_tests/model_runtime/oci/__init__.py b/api/tests/integration_tests/model_runtime/oci/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/oci/test_llm.py b/api/tests/integration_tests/model_runtime/oci/test_llm.py new file mode 100644 index 0000000000..531f26a32e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/oci/test_llm.py @@ -0,0 +1,130 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel + + +def test_validate_credentials(): + model = OCILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="cohere.command-r-plus", + credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"}, + ) + + model.validate_credentials( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + ) + + +def test_invoke_model(): + model = OCILargeLanguageModel() + + response = model.invoke( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OCILargeLanguageModel() + + response = model.invoke( + model="meta.llama-3-70b-instruct", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_model_with_function(): + model = OCILargeLanguageModel() + + response = 
model.invoke( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=False, + user="abc-123", + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_get_num_tokens(): + model = OCILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/oci/test_provider.py b/api/tests/integration_tests/model_runtime/oci/test_provider.py new file mode 100644 index 0000000000..2c7107c7cc --- /dev/null +++ b/api/tests/integration_tests/model_runtime/oci/test_provider.py @@ -0,0 +1,20 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.oci.oci import OCIGENAIProvider + + +def test_validate_provider_credentials(): + provider = OCIGENAIProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py b/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py new file mode 100644 index 0000000000..032c5c681a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py @@ -0,0 +1,58 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.oci.text_embedding.text_embedding import OCITextEmbeddingModel + + +def test_validate_credentials(): + model = OCITextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="cohere.embed-multilingual-v3.0", + credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"}, + ) + + model.validate_credentials( + model="cohere.embed-multilingual-v3.0", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + ) + + +def test_invoke_model(): + model = OCITextEmbeddingModel() + + result = model.invoke( + model="cohere.embed-multilingual-v3.0", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + texts=["hello", "world", " 
".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + # assert result.usage.total_tokens == 811 + + +def test_get_num_tokens(): + model = OCITextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="cohere.embed-multilingual-v3.0", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 48d1ae323d..7db59fddef 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -109,7 +109,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): """ - Funtion calling of xinference does not support stream mode currently + Function calling of xinference does not support stream mode currently """ # def test_invoke_stream_chat_model_with_functions(): # model = XinferenceAILargeLanguageModel() diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 4dfc530010..d3c1f3101c 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 571c1e3d44..53c9b3cae3 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,7 +13,7 @@ from xinference_client.types import Embedding class MockTcvectordbClass: - def VectorDBClient( + def mock_vector_db_client( self, url=None, username="", @@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 7b5f19ea62..c99739a863 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -12,8 +12,7 @@ class MilvusVectorTest(AbstractVectorTest): self.vector = MilvusVector( collection_name=self.collection_name, config=MilvusConfig( - host="localhost", - port=19530, + uri="http://localhost:19530", user="root", password="Milvus", ), diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py index 6b33217d15..6497f47deb 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ 
b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -31,5 +31,5 @@ class PGVectoRSVectorTest(AbstractVectorTest): assert len(ids) == 1 -def test_pgvecot_rs(setup_mock_redis): +def test_pgvecto_rs(setup_mock_redis): PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index cfc47bcad4..f1ab23b002 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index 44dcf9a10f..487178ff58 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -1,11 +1,11 @@ import pytest -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionException) as e: + with pytest.raises(CodeExecutionError) as e: CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 6f5421e108..952c90674d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,17 +1,72 @@ +import time +import uuid from os import getenv +from typing import cast import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode -from models.workflow import WorkflowNodeExecutionStatus +from core.workflow.nodes.code.entities import CodeNodeData +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) +def init_code_node(code_config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["code", "123", "args1"], 1) + variable_pool.add(["code", "123", "args2"], 2) + + node = CodeNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=code_config, + ) + + return node + + @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """ @@ -22,44 +77,36 @@ def test_execute_code(setup_code_executor_mock): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "result": { - "type": "number", - }, - }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, - }, - }, - ) - # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 2) + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) # execute node - result = node.run(pool) + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["result"] == 3 assert result.error is None @@ -74,44 +121,34 @@ def test_execute_code_output_validator(setup_code_executor_mock): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "result": { - "type": "string", - }, - }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, - }, - }, - ) - # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 2) + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "string", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + 
node = init_code_node(code_config) # execute node - result = node.run(pool) - + result = node._run() + assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error == "Output variable `result` must be a string" @@ -127,65 +164,60 @@ def test_execute_code_output_validator_depth(): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - config={ - "id": "1", - "data": { - "outputs": { - "string_validator": { - "type": "string", - }, - "number_validator": { - "type": "number", - }, - "number_array_validator": { - "type": "array[number]", - }, - "string_array_validator": { - "type": "array[string]", - }, - "object_validator": { - "type": "object", - "children": { - "result": { - "type": "number", - }, - "depth": { - "type": "object", - "children": { - "depth": { - "type": "object", - "children": { - "depth": { - "type": "number", - } - }, - } - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + }, + } }, }, }, }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } + + node = init_code_node(code_config) # construct result result = { @@ -196,6 +228,8 @@ def test_execute_code_output_validator_depth(): "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } + node.node_data = cast(CodeNodeData, node.node_data) + # validate node._transform_result(result, node.node_data.outputs) @@ -250,35 +284,30 @@ def test_execute_code_output_object_list(): """ # trim first 4 spaces at the beginning of each line code = "\n".join([line[4:] for line in code.split("\n")]) - node = CodeNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, - config={ - "id": "1", - "data": { - "outputs": { - "object_list": { - "type": "array[object]", - }, + + code_config = { + "id": "code", + "data": { + "outputs": { + "object_list": { + "type": "array[object]", }, - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, - ], - "answer": "123", - "code_language": "python3", - "code": code, }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, }, - ) + } + + 
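
The refactored tests above seed a VariablePool with selector paths such as ["code", "123", "args1"], which the node later resolves through its value_selector entries. The underlying contract is a mapping keyed by selector segments; a toy sketch of that idea, not Dify's implementation:

```python
class ToyVariablePool:
    """Toy model of selector-addressed variables: a dict keyed by the selector tuple."""

    def __init__(self) -> None:
        self._store: dict[tuple[str, ...], object] = {}

    def add(self, selector: list[str], value: object) -> None:
        self._store[tuple(selector)] = value

    def get(self, selector: list[str]) -> object | None:
        return self._store.get(tuple(selector))

pool = ToyVariablePool()
pool.add(["code", "123", "args1"], 1)
assert pool.get(["code", "123", "args1"]) == 1
```
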
node = init_code_node(code_config) # construct result result = { @@ -295,6 +324,8 @@ def test_execute_code_output_object_list(): ] } + node.node_data = cast(CodeNodeData, node.node_data) + # validate node._transform_result(result, node.node_data.outputs) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index acb616b325..65aaa0bddd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -1,31 +1,69 @@ +import time +import uuid from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock -BASIC_NODE_DATA = { - "tenant_id": "1", - "app_id": "1", - "workflow_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.WEB_APP, -} -# construct variable pool -pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(["a", "b123", "args1"], 1) -pool.add(["a", "b123", "args2"], 2) +def init_http_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return HttpRequestNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -45,12 +83,11 @@ def test_get(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -59,7 +96,7 @@ def test_get(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_no_auth(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock): "params": "A:b", "body": 
None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock): "params": "A:b", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) - + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -143,11 +178,11 @@ def test_template(setup_http_mock): "params": "A:b\nTemplate:{{#a.b123.args2#}}", "body": None, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "?A=b" in data @@ -158,7 +193,7 @@ def test_template(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -178,11 +213,11 @@ def test_json(setup_http_mock): "params": "A:b", "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert '{"a": "1"}' in data @@ -190,7 +225,7 @@ def test_json(setup_http_mock): def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock): "params": "A:b", "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "a=1&b=2" in data @@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock): def test_form_data(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -242,11 +277,11 @@ def test_form_data(setup_http_mock): "params": "A:b", "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert 'form-data; name="a"' in data @@ -257,7 +292,7 @@ def test_form_data(setup_http_mock): def test_none_data(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -277,11 +312,11 @@ def test_none_data(setup_http_mock): "params": "A:b", "body": {"type": "none", "data": "123123123"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None data = result.process_data.get("request", "") assert "X-Header: 123" in data @@ -289,7 +324,7 @@ 
def test_none_data(setup_http_mock): def test_mock_404(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock): "params": "", "headers": "X-Header:123", }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.outputs is not None resp = result.outputs assert 404 == resp.get("status_code") - assert "Not Found" in resp.get("body") + assert "Not Found" in resp.get("body", "") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode( + node = init_http_node( config={ "id": "1", "data": { @@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock): "headers": "Referer:http://example3.com\nRedirect:http://example4.com", "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - }, - **BASIC_NODE_DATA, + } ) - result = node.run(pool) + result = node._run() + assert result.process_data is not None + assert result.outputs is not None resp = result.outputs - assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") - assert "http://example3.com" == resp.get("headers").get("referer") + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "") + assert "http://example3.com" == resp.get("headers", {}).get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 6bab83a019..dfb43650d2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -1,5 +1,8 @@ import json import os +import time +import uuid +from collections.abc import Generator from unittest.mock import MagicMock import pytest @@ -10,28 +13,77 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db from models.provider import ProviderType -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_execute_llm(setup_openai_mock): - node = LLMNode( +def init_llm_node(config: dict) -> LLMNode: + graph_config = { + "edges": [ + { + "id": 
"start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["abc", "output"], "sunny") + + node = LLMNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + return node + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_execute_llm(setup_openai_mock): + node = init_llm_node( config={ "id": "llm", "data": { @@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock): }, ) - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - pool.add(["abc", "output"], "sunny") - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} provider_instance = ModelProviderFactory().get_provider_instance("openai") @@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock): model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model="gpt-3.5-turbo", provider="openai", mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock): node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node - result = node.run(pool) + result = node._run() + assert isinstance(result, Generator) - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["text"] is not None - assert result.outputs["usage"]["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) @@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ - node = LLMNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node 
= init_llm_node( config={ "id": "llm", "data": { @@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): }, ) - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - pool.add(["abc", "output"], "sunny") - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} provider_instance = ModelProviderFactory().get_provider_instance("openai") @@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") - + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model="gpt-3.5-turbo", provider="openai", mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node - result = node.run(pool) + result = node._run() - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "sunny" in json.dumps(result.process_data) - assert "what's the weather today?" in json.dumps(result.process_data) + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" 
in json.dumps(item.run_result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index ca2bae5c53..cbe9c5914f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,5 +1,7 @@ import json import os +import time +import uuid from typing import Optional from unittest.mock import MagicMock @@ -8,19 +10,21 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration -from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db from models.provider import ProviderType """FOR MOCK FIXTURES, DO NOT REMOVE""" -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock @@ -47,13 +51,15 @@ def get_mocked_fetch_model_config( model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) + model_schema = model_type_instance.get_model_schema(model) + assert model_schema is not None model_config = ModelConfigWithCredentialsEntity( model=model, provider=provider, mode=mode, credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema(model), + model_schema=model_schema, provider_model_bundle=provider_model_bundle, ) @@ -74,18 +80,62 @@ def get_mocked_fetch_memory(memory_text: str): return MagicMock(return_value=MemoryMock()) +def init_parameter_extractor_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + 
user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return ParameterExtractorNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ Test function calling for parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock): "reasoning_mode": "function_call", "memory": None, }, - }, + } ) node._fetch_model_config = get_mocked_fetch_model_config( @@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock): environment_variables=[], ) - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "kawaii" assert result.outputs.get("__reason") == None @@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "kawaii" assert result.outputs.get("__reason") == None process_data = result.process_data + assert process_data is not None process_data.get("prompts") - for prompt in process_data.get("prompts"): + for prompt in process_data.get("prompts", []): if prompt.get("role") == "system": assert "what's the weather in SF" in prompt.get("text") @@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. 
""" - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - prompts = result.process_data.get("prompts") + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) for prompt in prompts: if prompt.get("role") == "user": @@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock): ) db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - assert len(result.process_data.get("prompts")) == 1 - assert "SF" in result.process_data.get("prompts")[0].get("text") + assert result.process_data is not None + assert len(result.process_data.get("prompts", [])) == 1 + assert "SF" in result.process_data.get("prompts", [])[0].get("text") def test_extract_json_response(): @@ -322,13 +325,7 @@ def test_extract_json_response(): Test extract json response. """ - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -357,6 +354,7 @@ def test_extract_json_response(): hello world. """) + assert result is not None assert result["location"] == "kawaii" @@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. 
""" - node = ParameterExtractorNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_parameter_extractor_node( config={ "id": "llm", "data": { @@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() - # construct variable pool - pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={}, - environment_variables=[], - ) - - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs.get("location") == "" assert ( result.outputs.get("__reason") == "Failed to extract result from function call or text response, using empty result." ) - prompts = result.process_data.get("prompts") + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) latest_role = None for prompt in prompts: diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 617b6370c9..073c4bb799 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,46 +1,84 @@ +import time +import uuid + import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """{{args2}}""" - node = TemplateTransformNode( + config = { + "id": "1", + "data": { + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "template": code, + }, + } + + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.END_USER, - config={ - "id": "1", - "data": { - "title": "123", - "variables": [ - { - "variable": "args1", - "value_selector": ["1", "123", "args1"], - }, - {"variable": "args2", "value_selector": 
["1", "123", "args2"]}, - ], - "template": code, - }, - }, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, ) # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], 1) - pool.add(["1", "123", "args2"], 3) + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["1", "123", "args1"], 1) + variable_pool.add(["1", "123", "args2"], 3) + + node = TemplateTransformNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) # execute node - result = node.run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 29c1efa8e7..4d94cdb28a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,21 +1,62 @@ +import time +import uuid + from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.tool.tool_node import ToolNode -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def init_tool_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + + return ToolNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) def test_tool_variable_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(["1", "123", "args1"], "1+1") - - node = ToolNode( - tenant_id="1", - app_id="1", - workflow_id="1", - user_id="1", - invoke_from=InvokeFrom.WEB_APP, - user_from=UserFrom.ACCOUNT, + node = init_tool_node( config={ "id": "1", "data": { @@ -34,28 +75,22 @@ def test_tool_variable_invoke(): } }, }, - }, + } ) - # execute node - result = node.run(pool) + 
+    node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
 
+    # execute node
+    result = node._run()
+    assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert "2" in result.outputs["text"]
     assert result.outputs["files"] == []
 
 
 def test_tool_mixed_invoke():
-    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-    pool.add(["1", "args1"], "1+1")
-
-    node = ToolNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_tool_node(
         config={
             "id": "1",
             "data": {
@@ -74,12 +109,15 @@ def test_tool_mixed_invoke():
                 }
             },
         },
-        },
+        }
     )
 
-    # execute node
-    result = node.run(pool)
+    node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")
 
+    # execute node
+    result = node._run()
+    assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert "2" in result.outputs["text"]
     assert result.outputs["files"] == []
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index fb415483dd..3f639ccacc 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -19,6 +19,7 @@ def example_env_file(tmp_path, monkeypatch) -> str:
             """
             CONSOLE_API_URL=https://example.com
             CONSOLE_WEB_URL=https://example.com
+            HTTP_REQUEST_MAX_WRITE_TIMEOUT=30
             """
         )
     )
@@ -48,6 +49,12 @@ def test_dify_config(example_env_file):
     assert config.API_COMPRESSION_ENABLED is False
     assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
 
+    # annotated field with default value
+    assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60
+
+    # annotated field with configured value
+    assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
+
     # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
     # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
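The two assertions added above lean on standard pydantic-settings precedence: an annotated field keeps its declared default unless the environment (here, the dotenv file written by the example_env_file fixture) supplies a value. A minimal standalone sketch of that behavior, outside the patch, with a hypothetical TimeoutConfig model standing in for DifyConfig and illustrative default values:

    # Illustrative sketch only: TimeoutConfig is a stand-in, not Dify's actual config class.
    from pydantic_settings import BaseSettings


    class TimeoutConfig(BaseSettings):
        # annotated field with a default; stays 60 when the env file does not mention it
        HTTP_REQUEST_MAX_READ_TIMEOUT: int = 60
        # declared default; replaced when the env file sets HTTP_REQUEST_MAX_WRITE_TIMEOUT=30
        HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600


    # assumes a dotenv file containing the single line: HTTP_REQUEST_MAX_WRITE_TIMEOUT=30
    config = TimeoutConfig(_env_file=".env")
    assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60  # untouched default
    assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30  # value read from the env file

Under that precedence, adding a single line to the fixture's dotenv payload is enough to exercise both the default path and the override path, which is all the hunk above does.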
diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py
index afc9802cf1..ca3082953a 100644
--- a/api/tests/unit_tests/conftest.py
+++ b/api/tests/unit_tests/conftest.py
@@ -1,7 +1,24 @@
 import os
 
+import pytest
+from flask import Flask
+
 # Getting the absolute path of the current file's directory
 ABS_PATH = os.path.dirname(os.path.abspath(__file__))
 
 # Getting the absolute path of the project's root directory
 PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
+
+CACHED_APP = Flask(__name__)
+CACHED_APP.config.update({"TESTING": True})
+
+
+@pytest.fixture()
+def app() -> Flask:
+    return CACHED_APP
+
+
+@pytest.fixture(autouse=True)
+def _provide_app_context(app: Flask):
+    with app.app_context():
+        yield
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py
index 8d735cae86..bd414c88f4 100644
--- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py
+++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py
@@ -5,7 +5,7 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
 
 
 def test_default_value():
-    valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"}
+    valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}
 
     for key in valid_config:
         config = valid_config.copy()
@@ -15,5 +15,4 @@ def test_default_value():
         assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
 
     config = MilvusConfig(**valid_config)
-    assert config.secure is False
     assert config.database == "default"
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
new file mode 100644
index 0000000000..13ba11016a
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
@@ -0,0 +1,791 @@
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.utils.condition.entities import Condition
+
+
+def test_init():
+    graph_config = {
+        "edges": [
+            {
+                "id": "llm-source-answer-target",
+                "source": "llm",
+                "target": "answer",
+            },
+            {
+                "id": "start-source-qc-target",
+                "source": "start",
+                "target": "qc",
+            },
+            {
+                "id": "qc-1-llm-target",
+                "source": "qc",
+                "sourceHandle": "1",
+                "target": "llm",
+            },
+            {
+                "id": "qc-2-http-target",
+                "source": "qc",
+                "sourceHandle": "2",
+                "target": "http",
+            },
+            {
+                "id": "http-source-answer2-target",
+                "source": "http",
+                "target": "answer2",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+            {
+                "data": {"type": "question-classifier"},
+                "id": "qc",
+            },
+            {
+                "data": {
+                    "type": "http-request",
+                },
+                "id": "http",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer2",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    start_node_id = "start"
+
+    assert graph.root_node_id == start_node_id
+    assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
+    assert {"llm", "http"} ==
{node.target_node_id for node in graph.edge_mapping.get("qc")} + + +def test__init_iteration_graph(): + graph_config = { + "edges": [ + { + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + }, + { + "id": "iteration-source-llm-target", + "source": "iteration", + "sourceHandle": "source", + "target": "llm", + }, + { + "id": "template-transform-in-iteration-source-llm-in-iteration-target", + "source": "template-transform-in-iteration", + "sourceHandle": "source", + "target": "llm-in-iteration", + }, + { + "id": "llm-in-iteration-source-answer-in-iteration-target", + "source": "llm-in-iteration", + "sourceHandle": "source", + "target": "answer-in-iteration", + }, + { + "id": "start-source-code-target", + "source": "start", + "sourceHandle": "source", + "target": "code", + }, + { + "id": "code-source-iteration-target", + "source": "code", + "sourceHandle": "source", + "target": "iteration", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + }, + "id": "start", + }, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "iteration"}, + "id": "iteration", + }, + { + "data": { + "type": "template-transform", + }, + "id": "template-transform-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "llm", + }, + "id": "llm-in-iteration", + "parentId": "iteration", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "code", + }, + "id": "code", + }, + ], + } + + graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration") + graph.add_extra_edge( + source_node_id="answer-in-iteration", + target_node_id="template-transform-in-iteration", + run_condition=RunCondition( + type="condition", + conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")], + ), + ) + + # iteration: + # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + + assert graph.root_node_id == "template-transform-in-iteration" + assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" + assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" + assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" + + +def test_parallels_graph(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in 
range(3): + start_edges = graph.edge_mapping.get("start") + assert start_edges is not None + assert start_edges[i].target_node_id == f"llm{i + 1}" + + llm_edges = graph.edge_mapping.get(f"llm{i + 1}") + assert llm_edges is not None + assert llm_edges[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph2(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + if i < 2: + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph3(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph4(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "code2", + }, + { + "id": "llm3-source-code3-target", + "source": "llm3", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": 
"code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" + assert graph.edge_mapping.get(f"code{i + 1}") is not None + assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph5(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm4", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm5", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-code1-target", + "source": "llm2", + "target": "code1", + }, + { + "id": "llm3-source-code2-target", + "source": "llm3", + "target": "code2", + }, + { + "id": "llm4-source-code2-target", + "source": "llm4", + "target": "code2", + }, + { + "id": "llm5-source-code3-target", + "source": "llm5", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(5): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" + 
assert graph.edge_mapping.get("llm3") is not None + assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm4") is not None + assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm5") is not None + assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 8 + + for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph6(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm1-source-code2-target", + "source": "llm1", + "target": "code2", + }, + { + "id": "llm2-source-code3-target", + "source": "llm2", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code3") is not None + assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 2 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + parent_parallel = None + child_parallel = None + for p_id, parallel in graph.parallel_mapping.items(): + if 
parallel.parent_parallel_id is None: + parent_parallel = parallel + else: + child_parallel = parallel + + for node_id in ["llm1", "llm2", "llm3", "code3"]: + assert graph.node_parallel_mapping[node_id] == parent_parallel.id + + for node_id in ["code1", "code2"]: + assert graph.node_parallel_mapping[node_id] == child_parallel.id diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py new file mode 100644 index 0000000000..a2d71d61fc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -0,0 +1,505 @@ +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + BaseNodeEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.llm_node import LLMNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_workflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "llm1", + }, + { + "id": "2", + "source": "llm1", + "target": "llm2", + }, + { + "id": "3", + "source": "llm1", + "target": "llm3", + }, + { + "id": "4", + "source": "llm2", + "target": "end1", + }, + { + "id": "5", + "source": "llm3", + "target": "end2", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + "title": "start", + "variables": [ + { + "label": "query", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "query", + } + ], + }, + "id": "start", + }, + { + "data": { + "type": "llm", + "title": "llm1", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say hi"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + "title": "llm2", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say bye"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + "title": "llm3", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say good 
morning"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high"}, "enabled": False}, + }, + "id": "llm3", + }, + { + "data": { + "type": "end", + "title": "end1", + "outputs": [ + {"value_selector": ["llm2", "text"], "variable": "result2"}, + {"value_selector": ["start", "query"], "variable": "query"}, + ], + }, + "id": "end1", + }, + { + "data": { + "type": "end", + "title": "end2", + "outputs": [ + {"value_selector": ["llm1", "text"], "variable": "result1"}, + {"value_selector": ["llm3", "text"], "variable": "result3"}, + ], + }, + "id": "end2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + def llm_generator(self): + contents = ["hi", "bye", "good morning"] + + yield RunStreamChunkEvent( + chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"] + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: "USD", + }, + ) + ) + + # print("") + + with patch.object(LLMNode, "_run", new=llm_generator): + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + assert item.parallel_id is not None + + assert len(items) == 18 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_chatflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "answer1", + }, + { + "id": "2", + "source": "answer1", + "target": "answer2", + }, + { + "id": "3", + "source": "answer1", + "target": "answer3", + }, + { + "id": "4", + "source": "answer2", + "target": "answer4", + }, + { + "id": "5", + "source": "answer3", + "target": "answer5", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "start"}, "id": "start"}, + {"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"}, + { + "data": {"type": "answer", "title": "answer2", "answer": "2"}, + "id": "answer2", + }, + { + "data": {"type": "answer", "title": "answer3", "answer": "3"}, + "id": "answer3", + }, + { + "data": {"type": "answer", "title": "answer4", "answer": "4"}, + "id": "answer4", + }, + { + "data": {"type": "answer", "title": 
"answer5", "answer": "5"}, + "id": "answer5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + "answer2", + "answer3", + "answer4", + "answer5", + ]: + assert item.parallel_id is not None + + assert len(items) == 23 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_branch(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "if-else-1", + }, + { + "id": "2", + "source": "if-else-1", + "sourceHandle": "true", + "target": "answer-1", + }, + { + "id": "3", + "source": "if-else-1", + "sourceHandle": "false", + "target": "if-else-2", + }, + { + "id": "4", + "source": "if-else-2", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "5", + "source": "if-else-2", + "sourceHandle": "false", + "target": "answer-3", + }, + ], + "nodes": [ + { + "data": { + "title": "Start", + "type": "start", + "variables": [ + { + "label": "uid", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "uid", + } + ], + }, + "id": "start", + }, + { + "data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []}, + "id": "answer-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8", + "value": "hi", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "desc": "", + "title": "IF/ELSE", + "type": "if-else", + }, + "id": "if-else-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "ae895199-5608-433b-b5f0-0997ae1431e4", + "value": "takatost", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "title": "IF/ELSE 2", + "type": "if-else", + }, + "id": "if-else-2", + }, + { + "data": { + "answer": "2", + "title": "Answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "answer": "3", + "title": "Answer 3", + "type": "answer", + }, + "id": "answer-3", + }, + 
], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "hi", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "takato"}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + + assert len(items) == 10 + assert items[3].route_node_state.node_id == "if-else-1" + assert items[4].route_node_state.node_id == "if-else-1" + assert isinstance(items[5], NodeRunStreamChunkEvent) + assert items[5].chunk_content == "1 " + assert isinstance(items[6], NodeRunStreamChunkEvent) + assert items[6].chunk_content == "takato" + assert items[7].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato" + assert isinstance(items[9], GraphRunSucceededEvent) + + # print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py b/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py new file mode 100644 index 0000000000..fe4ede6335 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -0,0 +1,82 @@ +import time +import uuid +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.answer.answer_node import AnswerNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_execute_answer(): + graph_config = { + "edges": [ + { + "id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + 
graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py new file mode 100644 index 0000000000..bce87536d8 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py @@ -0,0 +1,109 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + answer_stream_generate_route = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping + ) + + assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"] + assert answer_stream_generate_route.answer_dependencies["answer2"] == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py new file mode 100644 index 0000000000..6b1d1e9070 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -0,0 +1,216 @@ +import uuid +from collections.abc import Generator +from datetime import datetime, timezone + +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import 
VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.start.entities import StartNodeData + + +def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + if next_node_id == "start": + yield from _publish_events(graph, next_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _publish_events(graph, edge.target_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _recursive_process(graph, edge.target_node_id) + + +def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + + parallel_id = graph.node_parallel_mapping.get(next_node_id) + parallel_start_node_id = None + if parallel_id: + parallel = graph.parallel_mapping.get(parallel_id) + parallel_start_node_id = parallel.start_from_node_id if parallel else None + + node_execution_id = str(uuid.uuid4()) + node_config = graph.node_id_config_mapping[next_node_id] + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) + mock_node_data = StartNodeData(**{"title": "demo", "variables": []}) + + yield NodeRunStartedEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_start_node_id=parallel_start_node_id, + ) + + if "llm" in next_node_id: + length = int(next_node_id[-1]) + for i in range(0, length): + yield NodeRunStreamChunkEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + chunk_content=str(i), + route_node_state=route_node_state, + from_variable_selector=[next_node_id, "text"], + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + route_node_state.status = RouteNodeState.Status.SUCCESS + route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + yield NodeRunSucceededEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + +def test_process(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": 
"answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool) + + def graph_generator() -> Generator[GraphEngineEvent, None, None]: + # print("") + for event in _recursive_process(graph, "start"): + # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunSucceededEvent): + if "llm" in event.route_node_state.node_id: + variable_pool.add( + [event.route_node_state.node_id, "text"], + "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))), + ) + yield event + + result_generator = answer_stream_processor.process(graph_generator()) + stream_contents = "" + for event in result_generator: + # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunStreamChunkEvent): + stream_contents += event.chunk_content + pass + + assert stream_contents == "c012da01b" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py new file mode 100644 index 0000000000..b3a89061b2 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -0,0 +1,420 @@ +import time +import uuid +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_run(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": 
"pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + 
outputs={"output": "dify 123"}, + ) + + # print("") + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 20 + + +def test_run_parallel(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", 
+ }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + # print("") + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 32 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 8020674ee6..cb2e99a854 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,22 +1,70 @@ +import time +import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_answer(): - node = AnswerNode( + graph_config = { + "edges": [ + { + "id": "start-source-answer-target", + "source": "start", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["start", "weather"], "sunny") + variable_pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "answer", "data": { @@ -27,20 +75,11 @@ def test_execute_answer(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "weather"], "sunny") - pool.add(["llm", "text"], "You are a helpful AI.") - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 9535bc2186..0795f134d0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,22 +1,63 @@ +import time +import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_if_else_result_true(): - node = IfElseNode( + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={} + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + 
pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), config={ "id": "if-else", "data": { @@ -63,48 +104,64 @@ def test_execute_if_else_result_true(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "array_contains"], ["ab", "def"]) - pool.add(["start", "array_not_contains"], ["ac", "def"]) - pool.add(["start", "contains"], "cabcde") - pool.add(["start", "not_contains"], "zacde") - pool.add(["start", "start_with"], "abc") - pool.add(["start", "end_with"], "zzab") - pool.add(["start", "is"], "ab") - pool.add(["start", "is_not"], "aab") - pool.add(["start", "empty"], "") - pool.add(["start", "not_empty"], "aaa") - pool.add(["start", "equals"], 22) - pool.add(["start", "not_equals"], 23) - pool.add(["start", "greater_than"], 23) - pool.add(["start", "less_than"], 21) - pool.add(["start", "greater_than_or_equal"], 22) - pool.add(["start", "less_than_or_equal"], 21) - pool.add(["start", "not_null"], "1212") - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"] is True def test_execute_if_else_result_false(): - node = IfElseNode( + graph_config = { + "edges": [ + { + "id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( tenant_id="1", app_id="1", + workflow_type=WorkflowType.WORKFLOW, workflow_id="1", + graph_config=graph_config, user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), config={ "id": "if-else", "data": { @@ -127,20 +184,11 @@ def test_execute_if_else_result_false(): }, ) - # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, - user_inputs={}, - environment_variables=[], - ) - pool.add(["start", "array_contains"], ["1ab", "def"]) - pool.add(["start", "array_not_contains"], ["ab", "def"]) - # Mock db.session.close() db.session.close = MagicMock() # execute node - result = node._run(pool) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index e26c7df642..f45a93f1be 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py 
+++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -1,17 +1,56 @@ +import time +import uuid from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable +from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base_node import UserFrom +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode +from models.workflow import WorkflowType DEFAULT_NODE_ID = "node_id" def test_overwrite_string_variable(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = StringVariable( id=str(uuid4()), name="test_conversation_variable", @@ -24,13 +63,24 @@ def test_overwrite_string_variable(): value="the second value", ) + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "node_id", "data": { @@ -41,19 +91,8 @@ def test_overwrite_string_variable(): }, ) - variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: - node.run(variable_pool) + list(node.run()) mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) @@ -63,6 +102,39 @@ def test_overwrite_string_variable(): def test_append_variable_to_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + 
graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = ArrayStringVariable( id=str(uuid4()), name="test_conversation_variable", @@ -75,23 +147,6 @@ def test_append_variable_to_array(): value="the second value", ) - node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config={ - "id": "node_id", - "data": { - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, - ) - variable_pool = VariablePool( system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, @@ -103,8 +158,23 @@ def test_append_variable_to_array(): input_variable, ) + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: - node.run(variable_pool) + list(node.run()) mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) @@ -113,19 +183,57 @@ def test_append_variable_to_array(): def test_clear_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + conversation_variable = ArrayStringVariable( id=str(uuid4()), name="test_conversation_variable", value=["the first value"], ) + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + node = VariableAssignerNode( - tenant_id="tenant_id", - app_id="app_id", - workflow_id="workflow_id", - user_id="user_id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), config={ "id": "node_id", "data": { @@ -136,14 +244,9 @@ def test_clear_array(): }, ) - variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - node.run(variable_pool) + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + list(node.run()) + mock_run.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not 
None diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 1f23dc84e0..7075a31f2b 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.7.2 + image: langgenius/dify-api:0.8.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -128,16 +128,14 @@ services: # The Qdrant server gRPC mode PORT. QDRANT_GRPC_PORT: 6334 # Milvus configuration Only available when VECTOR_STORE is `milvus`. - # The milvus host. - MILVUS_HOST: 127.0.0.1 - # The milvus host. - MILVUS_PORT: 19530 + # The milvus uri. + MILVUS_URI: http://127.0.0.1:19530 + # The milvus token. + MILVUS_TOKEN: '' # The milvus username. MILVUS_USER: root # The milvus password. MILVUS_PASSWORD: Milvus - # The milvus tls switch. - MILVUS_SECURE: 'false' # relyt configurations RELYT_HOST: db RELYT_PORT: 5432 @@ -229,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.2 + image: langgenius/dify-api:0.8.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -308,16 +306,14 @@ services: # The Qdrant server gRPC mode PORT. QDRANT_GRPC_PORT: 6334 # Milvus configuration Only available when VECTOR_STORE is `milvus`. - # The milvus host. - MILVUS_HOST: 127.0.0.1 - # The milvus host. - MILVUS_PORT: 19530 + # The milvus uri. + MILVUS_URI: http://127.0.0.1:19530 + # The milvus token. + MILVUS_TOKEN: '' # The milvus username. MILVUS_USER: root # The milvus password. MILVUS_PASSWORD: Milvus - # The milvus tls switch. - MILVUS_SECURE: 'false' # Mail configuration, support: resend MAIL_TYPE: '' # default send from email address, if not specified @@ -400,7 +396,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.2 + image: langgenius/dify-web:0.8.0 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 9138360d8b..c5c3cca9a6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -75,7 +75,7 @@ INIT_PASSWORD= DEPLOY_ENV=PRODUCTION # Whether to enable the version check policy. -# If set to empty, https://updates.dify.ai will not be called for version check. +# If set to empty, https://updates.dify.ai will be called for version check. CHECK_UPDATE_URL=https://updates.dify.ai # Used to change the OpenAI base address, default is https://api.openai.com/v1. @@ -214,6 +214,18 @@ REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# Whether to use Redis Sentinel mode. +# If set to true, the application will automatically discover and connect to the master node through Sentinel. +REDIS_USE_SENTINEL=false + +# List of Redis Sentinel nodes. If Sentinel mode is enabled, provide at least one Sentinel IP and port. +# Format: `<sentinel1_ip>:<sentinel1_port>,<sentinel2_ip>:<sentinel2_port>,<sentinel3_ip>:<sentinel3_port>` +REDIS_SENTINELS= +REDIS_SENTINEL_SERVICE_NAME= +REDIS_SENTINEL_USERNAME= +REDIS_SENTINEL_PASSWORD= +REDIS_SENTINEL_SOCKET_TIMEOUT=0.1 +
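The Sentinel settings above only take effect when `REDIS_USE_SENTINEL=true`. A minimal sketch of how such variables are typically consumed, using redis-py's `Sentinel` API; the environment parsing here is assumed for illustration and is not Dify's actual wiring:

```python
# Illustrative sketch only (not Dify's actual code): consuming the
# REDIS_SENTINEL_* variables above with redis-py's Sentinel API.
import os

from redis.sentinel import Sentinel

# REDIS_SENTINELS is a comma-separated list of host:port pairs,
# e.g. "10.0.0.1:26379,10.0.0.2:26379,10.0.0.3:26379".
sentinel_nodes = [
    (host, int(port))
    for host, port in (pair.split(":") for pair in os.environ["REDIS_SENTINELS"].split(","))
]

sentinel = Sentinel(
    sentinel_nodes,
    socket_timeout=float(os.environ.get("REDIS_SENTINEL_SOCKET_TIMEOUT", "0.1")),
    sentinel_kwargs={
        "username": os.environ.get("REDIS_SENTINEL_USERNAME") or None,
        "password": os.environ.get("REDIS_SENTINEL_PASSWORD") or None,
    },
)

# Ask Sentinel for the current master of the named service; the returned
# object is a regular Redis client that follows failover transparently.
master = sentinel.master_for(
    os.environ["REDIS_SENTINEL_SERVICE_NAME"],
    password=os.environ.get("REDIS_PASSWORD"),
    db=0,
)
master.ping()
```

Because `master_for()` hands back an ordinary Redis client, the rest of the application can stay unchanged whether Sentinel mode is on or off.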
# ------------------------------ # Celery Configuration # ------------------------------ @@ -221,9 +233,16 @@ REDIS_USE_SSL=false # Use redis as the broker, and redis db 1 for celery broker. # Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>` # Example: redis://:difyai123456@redis:6379/1 +# If using Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>` +# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1 CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1 BROKER_USE_SSL=false +# If you are using Redis Sentinel for high availability, configure the following settings. +CELERY_USE_SENTINEL=false +CELERY_SENTINEL_MASTER_NAME= +CELERY_SENTINEL_SOCKET_TIMEOUT=0.1 + # ------------------------------ # CORS Configuration # Used to set the front-end cross-domain access policy. @@ -242,7 +261,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=* # ------------------------------ # The type of storage to use for storing user files. -# Supported values are `local` and `s3` and `azure-blob` and `google-storage` and `tencent-cos`, +# Supported values are `local` and `s3` and `azure-blob` and `google-storage` and `tencent-cos` and `huawei-obs` # Default: `local` STORAGE_TYPE=local @@ -285,6 +304,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key ALIYUN_OSS_ENDPOINT=https://oss-ap-southeast-1-internal.aliyuncs.com ALIYUN_OSS_REGION=ap-southeast-1 ALIYUN_OSS_AUTH_VERSION=v4 +# Don't start with '/'. OSS doesn't support leading slash in object names. +ALIYUN_OSS_PATH=your-path # Tencent COS Configuration # The name of the Tencent COS bucket to use for storing files. @@ -298,6 +319,28 @@ TENCENT_COS_REGION=your-region # The scheme of the Tencent COS service. TENCENT_COS_SCHEME=your-scheme +# Huawei OBS Configuration +# The name of the Huawei OBS bucket to use for storing files. +HUAWEI_OBS_BUCKET_NAME=your-bucket-name +# The secret key to use for authenticating with the Huawei OBS service. +HUAWEI_OBS_SECRET_KEY=your-secret-key +# The access key to use for authenticating with the Huawei OBS service. +HUAWEI_OBS_ACCESS_KEY=your-access-key +# The server url of the HUAWEI OBS service. +HUAWEI_OBS_SERVER=your-server-url + +# Volcengine TOS Configuration +# The name of the Volcengine TOS bucket to use for storing files. +VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name +# The secret key to use for authenticating with the Volcengine TOS service. +VOLCENGINE_TOS_SECRET_KEY=your-secret-key +# The access key to use for authenticating with the Volcengine TOS service. +VOLCENGINE_TOS_ACCESS_KEY=your-access-key +# The endpoint of the Volcengine TOS service. +VOLCENGINE_TOS_ENDPOINT=your-server-url +# The region of the Volcengine TOS service. +VOLCENGINE_TOS_REGION=your-region + # ------------------------------ # Vector Database Configuration # ------------------------------ @@ -323,16 +366,14 @@ QDRANT_GRPC_ENABLED=false QDRANT_GRPC_PORT=6334 # Milvus configuration Only available when VECTOR_STORE is `milvus`. -# The milvus host. -MILVUS_HOST=127.0.0.1 -# The milvus host. -MILVUS_PORT=19530 +# The milvus uri. +MILVUS_URI=http://127.0.0.1:19530 +# The milvus token. +MILVUS_TOKEN= # The milvus username. MILVUS_USER=root # The milvus password. MILVUS_PASSWORD=Milvus -# The milvus tls switch. -MILVUS_SECURE=false # MyScale configuration, only available when VECTOR_STORE is `myscale` # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: diff --git a/docker/README.md b/docker/README.md index 1223a58024..7ce3f9bd75 100644 --- a/docker/README.md +++ b/docker/README.md @@ -83,7 +83,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w 7. **Vector Database Configuration**: - `VECTOR_STORE`: Type of vector database (e.g., `weaviate`, `milvus`).
- - Specific settings for each vector store like `WEAVIATE_ENDPOINT`, `MILVUS_HOST`. + - Specific settings for each vector store like `WEAVIATE_ENDPOINT`, `MILVUS_URI`. 8. **CORS Configuration**: - `WEB_API_CORS_ALLOW_ORIGINS`, `CONSOLE_CORS_ALLOW_ORIGINS`: Settings for cross-origin resource sharing. diff --git a/docker/certbot/README.md b/docker/certbot/README.md index c6f73ae699..21be34b33a 100644 --- a/docker/certbot/README.md +++ b/docker/certbot/README.md @@ -2,8 +2,8 @@ ## Short description -Docker-compose certbot configurations with Backward compatibility (without certbot container). -Use `docker-compose --profile certbot up` to use this features. +docker compose certbot configurations with Backward compatibility (without certbot container). +Use `docker compose --profile certbot up` to use these features. ## The simplest way for launching new servers with SSL certificates @@ -18,21 +18,21 @@ Use `docker-compose --profile certbot up` to use this features. ``` execute command: ```shell - sudo docker network prune - sudo docker-compose --profile certbot up --force-recreate -d + docker network prune + docker compose --profile certbot up --force-recreate -d ``` then after the containers launched: ```shell - sudo docker-compose exec -it certbot /bin/sh /update-cert.sh + docker compose exec -it certbot /bin/sh /update-cert.sh ``` -2. Edit `.env` file and `sudo docker-compose --profile certbot up` again. +2. Edit `.env` file and `docker compose --profile certbot up` again. set `.env` value additionally ```properties NGINX_HTTPS_ENABLED=true ``` execute command: ```shell - sudo docker-compose --profile certbot up -d --no-deps --force-recreate nginx + docker compose --profile certbot up -d --no-deps --force-recreate nginx ``` Then you can access your serve with HTTPS. [https://your_domain.com](https://your_domain.com) @@ -42,8 +42,8 @@ Use `docker-compose --profile certbot up` to use this features. For SSL certificates renewal, execute commands below: ```shell -sudo docker-compose exec -it certbot /bin/sh /update-cert.sh -sudo docker-compose exec nginx nginx -s reload +docker compose exec -it certbot /bin/sh /update-cert.sh +docker compose exec nginx nginx -s reload ``` ## Options for certbot @@ -57,14 +57,14 @@ CERTBOT_OPTIONS=--dry-run To apply changes to `CERTBOT_OPTIONS`, regenerate the certbot container before updating the certificates. ```shell -sudo docker-compose --profile certbot up -d --no-deps --force-recreate certbot -sudo docker-compose exec -it certbot /bin/sh /update-cert.sh +docker compose --profile certbot up -d --no-deps --force-recreate certbot +docker compose exec -it certbot /bin/sh /update-cert.sh ``` Then, reload the nginx container if necessary. ```shell -sudo docker-compose exec nginx nginx -s reload +docker compose exec nginx nginx -s reload ``` ## For legacy servers @@ -72,5 +72,5 @@ sudo docker-compose exec nginx nginx -s reload To use cert files dir `nginx/ssl` as before, simply launch containers WITHOUT `--profile certbot` option. ```shell -sudo docker-compose up -d +docker compose up -d ``` diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 9a4c40448b..dbfc1ea531 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -19,6 +19,11 @@ services: - ./volumes/db/data:/var/lib/postgresql/data ports: - "${EXPOSE_POSTGRES_PORT:-5432}:5432" + healthcheck: + test: [ "CMD", "pg_isready" ] + interval: 1s + timeout: 3s + retries: 30 # The redis cache.
redis: @@ -31,10 +36,12 @@ services: command: redis-server --requirepass difyai123456 ports: - "${EXPOSE_REDIS_PORT:-6379}:6379" + healthcheck: + test: [ "CMD", "redis-cli", "ping" ] # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.6 + image: langgenius/dify-sandbox:0.2.7 restart: always environment: # The DifySandbox configurations @@ -49,6 +56,8 @@ services: SANDBOX_PORT: ${SANDBOX_PORT:-8194} volumes: - ./volumes/sandbox/dependencies:/dependencies + healthcheck: + test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] networks: - ssrf_proxy_network diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c1ed2ce5f8..5afb876a1c 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -42,8 +42,17 @@ x-shared-env: &shared-api-worker-env REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} REDIS_USE_SSL: ${REDIS_USE_SSL:-false} REDIS_DB: 0 + REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} + REDIS_SENTINELS: ${REDIS_SENTINELS:-} + REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} + REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} + REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} + REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} + CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} + CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} + CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} STORAGE_TYPE: ${STORAGE_TYPE:-local} @@ -66,16 +75,26 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} + ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-} + HUAWEI_OBS_BUCKET_NAME: ${HUAWEI_OBS_BUCKET_NAME:-} + HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-} + HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-} + HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-} OCI_ENDPOINT: ${OCI_ENDPOINT:-} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-} OCI_SECRET_KEY: ${OCI_SECRET_KEY:-} OCI_REGION: ${OCI_REGION:-} + VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-} + VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_ENDPOINT: ${VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_REGION: ${VOLCENGINE_TOS_REGION:-} VECTOR_STORE: ${VECTOR_STORE:-weaviate} WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} @@ -84,11 +103,10 @@ x-shared-env: &shared-api-worker-env QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - MILVUS_HOST: ${MILVUS_HOST:-127.0.0.1} - MILVUS_PORT: ${MILVUS_PORT:-19530} + MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} + MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} - MILVUS_SECURE: ${MILVUS_SECURE:-false} MYSCALE_HOST:
${MYSCALE_HOST:-myscale} MYSCALE_PORT: ${MYSCALE_PORT:-8123} MYSCALE_USER: ${MYSCALE_USER:-default} @@ -190,7 +208,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.7.2 + image: langgenius/dify-api:0.8.0 restart: always environment: # Use the shared environment variables. @@ -210,7 +228,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.2 + image: langgenius/dify-api:0.8.0 restart: always environment: # Use the shared environment variables. @@ -229,7 +247,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.2 + image: langgenius/dify-web:0.8.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -274,7 +292,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.6 + image: langgenius/dify-sandbox:0.2.7 restart: always environment: # The DifySandbox configurations @@ -289,12 +307,14 @@ services: SANDBOX_PORT: ${SANDBOX_PORT:-8194} volumes: - ./volumes/sandbox/dependencies:/dependencies + healthcheck: + test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] networks: - ssrf_proxy_network # ssrf_proxy server # for more information, please refer to - # https://docs.dify.ai/learn-more/faq/self-host-faq#id-18.-why-is-ssrf_proxy-needed + # https://docs.dify.ai/learn-more/faq/install-faq#id-18.-why-is-ssrf_proxy-needed ssrf_proxy: image: ubuntu/squid:latest restart: always @@ -331,7 +351,7 @@ services: - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} entrypoint: [ "/docker-entrypoint.sh" ] - command: ["tail", "-f", "/dev/null"] + command: [ "tail", "-f", "/dev/null" ] # The nginx reverse proxy. # used for reverse proxying the API service and Web service. 
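A recurring change across these compose files is the Milvus connection: the MILVUS_HOST / MILVUS_PORT / MILVUS_SECURE trio collapses into a single MILVUS_URI plus MILVUS_TOKEN. A minimal sketch of the new connection style, assuming pymilvus's MilvusClient rather than Dify's own vector-store wrapper:

```python
# Illustrative sketch only (not Dify's actual client code): the URI scheme
# now carries what MILVUS_SECURE used to express (http vs https), and
# MILVUS_TOKEN covers token-based auth.
import os

from pymilvus import MilvusClient

client = MilvusClient(
    uri=os.environ.get("MILVUS_URI", "http://127.0.0.1:19530"),
    token=os.environ.get("MILVUS_TOKEN", ""),  # or "user:password" when RBAC is enabled
    user=os.environ.get("MILVUS_USER", ""),
    password=os.environ.get("MILVUS_PASSWORD", ""),
)
print(client.list_collections())
```

Folding scheme, host, and port into one URI is what lets the compose files drop the separate TLS switch without losing any configurability.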
@@ -377,7 +397,7 @@ services:
   weaviate:
     image: semitechnologies/weaviate:1.19.0
     profiles:
-      - ''
+      - ""
       - weaviate
     restart: always
     volumes:
@@ -473,13 +493,13 @@ services:
       - oracle
     restart: always
     volumes:
-      - type: volume
-        source: oradata
+      - source: oradata
+        type: volume
         target: /opt/oracle/oradata
       - ./startupscripts:/opt/oracle/scripts/startup
     environment:
-      - ORACLE_PWD=${ORACLE_PWD:-Dify123456}
-      - ORACLE_CHARACTERSET=${ORACLE_CHARACTERSET:-AL32UTF8}
+      ORACLE_PWD: ${ORACLE_PWD:-Dify123456}
+      ORACLE_CHARACTERSET: ${ORACLE_CHARACTERSET:-AL32UTF8}

   # Milvus vector database services
   etcd:
@@ -488,10 +508,10 @@ services:
     profiles:
       - milvus
     environment:
-      - ETCD_AUTO_COMPACTION_MODE=${ETCD_AUTO_COMPACTION_MODE:-revision}
-      - ETCD_AUTO_COMPACTION_RETENTION=${ETCD_AUTO_COMPACTION_RETENTION:-1000}
-      - ETCD_QUOTA_BACKEND_BYTES=${ETCD_QUOTA_BACKEND_BYTES:-4294967296}
-      - ETCD_SNAPSHOT_COUNT=${ETCD_SNAPSHOT_COUNT:-50000}
+      ETCD_AUTO_COMPACTION_MODE: ${ETCD_AUTO_COMPACTION_MODE:-revision}
+      ETCD_AUTO_COMPACTION_RETENTION: ${ETCD_AUTO_COMPACTION_RETENTION:-1000}
+      ETCD_QUOTA_BACKEND_BYTES: ${ETCD_QUOTA_BACKEND_BYTES:-4294967296}
+      ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000}
     volumes:
       - ./volumes/milvus/etcd:/etcd
     command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
@@ -541,8 +561,11 @@ services:
       timeout: 20s
       retries: 3
     depends_on:
-      - "etcd"
-      - "minio"
+      - etcd
+      - minio
+    ports:
+      - 19530:19530
+      - 9091:9091
     networks:
       - milvus

@@ -553,10 +576,10 @@ services:
     profiles:
       - opensearch
     environment:
-      - discovery.type=${OPENSEARCH_DISCOVERY_TYPE:-single-node}
-      - bootstrap.memory_lock=${OPENSEARCH_BOOTSTRAP_MEMORY_LOCK:-true}
-      - OPENSEARCH_JAVA_OPTS=-Xms${OPENSEARCH_JAVA_OPTS_MIN:-512m} -Xmx${OPENSEARCH_JAVA_OPTS_MAX:-1024m}
-      - OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-Qazwsxedc!@#123}
+      discovery.type: ${OPENSEARCH_DISCOVERY_TYPE:-single-node}
+      bootstrap.memory_lock: ${OPENSEARCH_BOOTSTRAP_MEMORY_LOCK:-true}
+      OPENSEARCH_JAVA_OPTS: -Xms${OPENSEARCH_JAVA_OPTS_MIN:-512m} -Xmx${OPENSEARCH_JAVA_OPTS_MAX:-1024m}
+      OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-Qazwsxedc!@#123}
     ulimits:
       memlock:
         soft: ${OPENSEARCH_MEMLOCK_SOFT:--1}
@@ -596,7 +619,7 @@ services:
       - ./volumes/myscale/log:/var/log/clickhouse-server
       - ./volumes/myscale/config/users.d/custom_users_config.xml:/etc/clickhouse-server/users.d/custom_users_config.xml
     ports:
-      - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}"
+      - ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}

   # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html
   # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites
@@ -609,18 +632,18 @@ services:
     volumes:
       - dify_es01_data:/usr/share/elasticsearch/data
     environment:
-      - ELASTIC_PASSWORD=${ELASTICSEARCH_PASSWORD:-elastic}
-      - cluster.name=dify-es-cluster
-      - node.name=dify-es0
-      - discovery.type=single-node
-      - xpack.license.self_generated.type=trial
-      - xpack.security.enabled=true
-      - xpack.security.enrollment.enabled=false
-      - xpack.security.http.ssl.enabled=false
+      ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
+      cluster.name: dify-es-cluster
+      node.name: dify-es0
+      discovery.type: single-node
+      xpack.license.self_generated.type: trial
+      xpack.security.enabled: "true"
+      xpack.security.enrollment.enabled: "false"
+      xpack.security.http.ssl.enabled: "false"
     ports:
       - ${ELASTICSEARCH_PORT:-9200}:9200
     healthcheck:
-      test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"]
+      test: [ "CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty" ]
       interval: 30s
       timeout: 10s
       retries: 50
@@ -636,15 +659,15 @@ services:
       - elasticsearch
     restart: always
     environment:
-      - XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY=d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa
-      - NO_PROXY=localhost,127.0.0.1,elasticsearch,kibana
-      - XPACK_SECURITY_ENABLED=true
-      - XPACK_SECURITY_ENROLLMENT_ENABLED=false
-      - XPACK_SECURITY_HTTP_SSL_ENABLED=false
-      - XPACK_FLEET_ISAIRGAPPED=true
-      - I18N_LOCALE=zh-CN
-      - SERVER_PORT=5601
-      - ELASTICSEARCH_HOSTS="http://elasticsearch:9200"
+      XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY: d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa
+      NO_PROXY: localhost,127.0.0.1,elasticsearch,kibana
+      XPACK_SECURITY_ENABLED: "true"
+      XPACK_SECURITY_ENROLLMENT_ENABLED: "false"
+      XPACK_SECURITY_HTTP_SSL_ENABLED: "false"
+      XPACK_FLEET_ISAIRGAPPED: "true"
+      I18N_LOCALE: zh-CN
+      SERVER_PORT: "5601"
+      ELASTICSEARCH_HOSTS: http://elasticsearch:9200
     ports:
       - ${KIBANA_PORT:-5601}:5601
     healthcheck:
diff --git a/docker/startupscripts/init.sh b/docker/startupscripts/init.sh
index ee7600850a..c6e6e1966f 100755
--- a/docker/startupscripts/init.sh
+++ b/docker/startupscripts/init.sh
@@ -1,13 +1,13 @@
 #!/usr/bin/env bash
-DB_INITIALISED="/opt/oracle/oradata/dbinit"
-#[ -f ${DB_INITIALISED} ] && exit
-#touch ${DB_INITIALISED}
-if [ -f ${DB_INITIALISED} ]; then
+DB_INITIALIZED="/opt/oracle/oradata/dbinit"
+#[ -f ${DB_INITIALIZED} ] && exit
+#touch ${DB_INITIALIZED}
+if [ -f ${DB_INITIALIZED} ]; then
    echo 'File exists. Standards for have been Init'
    exit
 else
-   echo 'File does not exist. Standards for first time Strart up this DB'
+   echo 'File does not exist. Standards for first time Start up this DB'
    "$ORACLE_HOME"/bin/sqlplus -s "/ as sysdba" @"/opt/oracle/scripts/startup/init_user.script";
-   touch ${DB_INITIALISED}
+   touch ${DB_INITIALIZED}
 fi
diff --git a/sdks/nodejs-client/README.md b/sdks/nodejs-client/README.md
index 50303b4867..37b5ca2d0a 100644
--- a/sdks/nodejs-client/README.md
+++ b/sdks/nodejs-client/README.md
@@ -18,7 +18,7 @@ const query = 'Please tell me a short story in 10 words or less.'
const remote_url_files = [{ type: 'image', transfer_method: 'remote_url', - url: 'your_url_addresss' + url: 'your_url_address' }] // Create a completion client diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 84ec157323..e728749b85 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -108,7 +108,7 @@ const AppDetailLayout: FC = (props) => { useEffect(() => { setAppDetail() fetchAppDetail({ url: '/apps', id: appId }).then((res) => { - // redirections + // redirection if ((res.mode === 'workflow' || res.mode === 'advanced-chat') && (pathname).endsWith('configuration')) { router.replace(`/app/${appId}/workflow`) } @@ -128,7 +128,7 @@ const AppDetailLayout: FC = (props) => { if (e.status === 404) router.replace('/apps') }) - }, [appId, isCurrentWorkspaceEditor]) + }, [appId, isCurrentWorkspaceEditor, systemFeatures]) useUnmount(() => { setAppDetail() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx index 3584e13733..8f3ee510b8 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx @@ -95,7 +95,7 @@ const CardView: FC = ({ appId }) => { if (systemFeatures.enable_web_sso_switch_component) { const [sso_err] = await asyncRunSafe( - updateAppSSO({ id: appId, enabled: params.enable_sso }) as Promise, + updateAppSSO({ id: appId, enabled: Boolean(params.enable_sso) }) as Promise, ) if (sso_err) { handleCallbackResult(sso_err) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index b908322a92..6e5046ecf8 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -48,9 +48,10 @@ const ProviderPanel: FC = ({ e.preventDefault() e.stopPropagation() - const url = `${config?.host}/project/${config?.project_key}` - window.open(url, '_blank', 'noopener,noreferrer') - }, []) + const url = config?.project_url + if (url) + window.open(url, '_blank', 'noopener,noreferrer') + }, [config?.project_url]) const handleChosen = useCallback((e: React.MouseEvent) => { e.stopPropagation() diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index bc7308a711..1ffb132cf8 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -21,7 +21,7 @@ import Divider from '@/app/components/base/divider' import { getRedirection } from '@/utils/app-redirection' import { useProviderContext } from '@/context/provider-context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import EditAppModal from '@/app/components/explore/create-app-modal' @@ -79,6 +79,7 @@ const AppCard = ({ app, 
onRefresh }: AppCardProps) => { icon, icon_background, description, + use_icon_as_answer_icon, }) => { try { await updateAppInfo({ @@ -88,6 +89,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { icon, icon_background, description, + use_icon_as_answer_icon, }) setShowEditModal(false) notify({ @@ -255,7 +257,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { e.preventDefault() getRedirection(isCurrentWorkspaceEditor, app, push) }} - className='group flex col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' + className='relative group col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' >
@@ -271,7 +273,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { )} {app.mode === 'agent-chat' && ( - + )} {app.mode === 'chat' && ( @@ -297,17 +299,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
-            {app.description}
+            {app.description}
{isCurrentWorkspaceEditor && ( @@ -371,6 +372,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { appIconBackground={app.icon_background} appIconUrl={app.icon_url} appDescription={app.description} + appMode={app.mode} + appUseIconAsAnswerIcon={app.use_icon_as_answer_icon} show={showEditModal} onConfirm={onEdit} onHide={() => setShowEditModal(false)} diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index c16512bd50..132096c6b4 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -139,7 +139,7 @@ const Apps = () => {
-          {/* desscription */}
+          {/* description */}
          {appDetail.description && (
{appDetail.description}
)} @@ -423,6 +425,8 @@ const AppInfo = ({ expand }: IAppInfoProps) => { appIconBackground={appDetail.icon_background} appIconUrl={appDetail.icon_url} appDescription={appDetail.description} + appMode={appDetail.mode} + appUseIconAsAnswerIcon={appDetail.use_icon_as_answer_icon} show={showEditModal} onConfirm={onEdit} onHide={() => setShowEditModal(false)} diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 0f54f5bfc3..c66aaef6ce 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -232,8 +232,8 @@ const Annotation: FC = ({ middlePagesSiblingCount={1} setCurrentPage={setCurrPage} totalPages={Math.ceil(total / APP_PAGE_LIMIT)} - truncableClassName="w-8 px-0.5 text-center" - truncableText="..." + truncatableClassName="w-8 px-0.5 text-center" + truncatableText="..." > = ({ middlePagesSiblingCount={1} setCurrentPage={setCurrPage} totalPages={Math.ceil(total / APP_PAGE_LIMIT)} - truncableClassName="w-8 px-0.5 text-center" - truncableText="..." + truncatableClassName="w-8 px-0.5 text-center" + truncatableText="..." > state.appDetail) + const [publishedTime, setPublishedTime] = useState(publishedAt) const { app_base_url: appBaseURL = '', access_token: accessToken = '' } = appDetail?.site ?? {} const appMode = (appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow') ? 'chat' : appDetail.mode const appURL = `${appBaseURL}/${appMode}/${accessToken}` @@ -76,6 +77,7 @@ const AppPublisher = ({ try { await onPublish?.(modelAndParameter) setPublished(true) + setPublishedTime(Date.now()) } catch (e) { setPublished(false) @@ -131,13 +133,13 @@ const AppPublisher = ({
-            {publishedAt ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')}
+            {publishedTime ? t('workflow.common.latestPublished') : t('workflow.common.currentDraftUnpublished')}
-            {publishedAt
+            {publishedTime
              ? (
-                {t('workflow.common.publishedAt')} {formatTimeFromNow(publishedAt)}
+                {t('workflow.common.publishedAt')} {formatTimeFromNow(publishedTime)}
- }>{t('workflow.common.runApp')} + }>{t('workflow.common.runApp')} {appDetail?.mode === 'workflow' ? ( } > @@ -199,16 +201,16 @@ const AppPublisher = ({ setEmbeddingModalOpen(true) handleTrigger() }} - disabled={!publishedAt} + disabled={!publishedTime} icon={} > {t('workflow.common.embedIntoSite')} )} - }>{t('workflow.common.accessAPIReference')} + }>{t('workflow.common.accessAPIReference')} {appDetail?.mode === 'workflow' && ( = ({ {isShowConfirmAddVar && ( v.name)} - onConfrim={handleAutoAdd(true)} + onConfirm={handleAutoAdd(true)} onCancel={handleAutoAdd(false)} onHide={hideConfirmAddVar} /> diff --git a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx index f08f2ffc69..922f8bb36a 100644 --- a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx +++ b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx @@ -7,7 +7,7 @@ import Button from '@/app/components/base/button' export type IConfirmAddVarProps = { varNameArr: string[] - onConfrim: () => void + onConfirm: () => void onCancel: () => void onHide: () => void } @@ -22,7 +22,7 @@ const VarIcon = ( const ConfirmAddVar: FC = ({ varNameArr, - onConfrim, + onConfirm, onCancel, // onHide, }) => { @@ -63,7 +63,7 @@ const ConfirmAddVar: FC = ({
- +
diff --git a/web/app/components/app/configuration/config-prompt/conversation-histroy/edit-modal.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx similarity index 100% rename from web/app/components/app/configuration/config-prompt/conversation-histroy/edit-modal.tsx rename to web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx diff --git a/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.tsx similarity index 100% rename from web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx rename to web/app/components/app/configuration/config-prompt/conversation-history/history-panel.tsx diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 69e01a8e22..d7bfe8534e 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -33,7 +33,7 @@ export type ISimplePromptInput = { promptTemplate: string promptVariables: PromptVariable[] readonly?: boolean - onChange?: (promp: string, promptVariables: PromptVariable[]) => void + onChange?: (prompt: string, promptVariables: PromptVariable[]) => void noTitle?: boolean gradientBorder?: boolean editorHeight?: number @@ -239,7 +239,7 @@ const Prompt: FC = ({ {isShowConfirmAddVar && ( v.name)} - onConfrim={handleAutoAdd(true)} + onConfirm={handleAutoAdd(true)} onCancel={handleAutoAdd(false)} onHide={hideConfirmAddVar} /> diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 3296c77fb2..606280653e 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -47,7 +47,7 @@ const ConfigModal: FC = ({ if (!isValid) { Toast.notify({ type: 'error', - message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConig.varName') }), + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConfig.varName') }), }) return false } @@ -101,7 +101,7 @@ const ConfigModal: FC = ({ // } if (!tempPayload.label) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.labelNameRequired') }) + Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.labelNameRequired') }) return } if (isStringInput || type === InputVarType.number) { @@ -109,7 +109,7 @@ const ConfigModal: FC = ({ } else { if (options?.length === 0) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.atLeastOneOption') }) + Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.atLeastOneOption') }) return } const obj: Record = {} @@ -122,7 +122,7 @@ const ConfigModal: FC = ({ obj[o] = true }) if (hasRepeatedItem) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.optionRepeat') }) + Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.optionRepeat') }) return } onConfirm(tempPayload, moreInfo) @@ -131,14 +131,14 @@ const ConfigModal: FC = ({ return (
- +
handlePayloadChange('type')(InputVarType.textInput)} /> handlePayloadChange('type')(InputVarType.paragraph)} /> @@ -147,39 +147,39 @@ const ConfigModal: FC = ({
- + handlePayloadChange('variable')(e.target.value)} onBlur={handleVarKeyBlur} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} /> - + handlePayloadChange('label')(e.target.value)} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} /> {isStringInput && ( - + )} {type === InputVarType.select && ( - + )} - +
diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx index e23c7330b1..449cb8b12f 100644 --- a/web/app/components/app/configuration/config-var/config-select/index.tsx +++ b/web/app/components/app/configuration/config-var/config-select/index.tsx @@ -77,7 +77,7 @@ const ConfigSelect: FC = ({ onClick={() => { onChange([...options, '']) }} className='flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100'> -
{t('appDebug.variableConig.addOption')}
+        {t('appDebug.variableConfig.addOption')}
) diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 802528e0af..fc165571c4 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -88,7 +88,6 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar } as InputVar })() const updatePromptVariableItem = (payload: InputVar) => { - console.log(payload) const newPromptVariables = produce(promptVariables, (draft) => { const { variable, label, type, ...rest } = payload draft[currIndex] = { diff --git a/web/app/components/app/configuration/config-var/select-type-item/index.tsx b/web/app/components/app/configuration/config-var/select-type-item/index.tsx index bb5e700d11..c76aed1a10 100644 --- a/web/app/components/app/configuration/config-var/select-type-item/index.tsx +++ b/web/app/components/app/configuration/config-var/select-type-item/index.tsx @@ -18,7 +18,7 @@ const SelectTypeItem: FC = ({ onClick, }) => { const { t } = useTranslation() - const typeName = t(`appDebug.variableConig.${type}`) + const typeName = t(`appDebug.variableConfig.${type}`) return (
= ({
- - - - + + + +
- +
diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx index b295a4e709..959336457f 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import ItemPanel from './item-panel' import Button from '@/app/components/base/button' -import { CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Unblur } from '@/app/components/base/icons/src/vender/solid/education' import Slider from '@/app/components/base/slider' import type { AgentConfig } from '@/models/debug' @@ -65,7 +65,7 @@ const AgentSetting: FC = ({ + } name={t('appDebug.agent.agentMode')} description={t('appDebug.agent.agentModeDes')} diff --git a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx index 6bdf678f85..336d736e3b 100644 --- a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx +++ b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx @@ -12,7 +12,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { BubbleText } from '@/app/components/base/icons/src/vender/solid/education' import Radio from '@/app/components/base/radio/ui' -import { CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' import { ArrowUpRight } from '@/app/components/base/icons/src/vender/line/arrows' import type { AgentConfig } from '@/models/debug' @@ -117,7 +117,7 @@ const AssistantTypePicker: FC = ({ > setOpen(v => !v)}>
- {isAgent ? : } + {isAgent ? : }
{t(`appDebug.assistantType.${isAgent ? 'agentAssistant' : 'chatAssistant'}.name`)}
@@ -135,7 +135,7 @@ const AssistantTypePicker: FC = ({ onClick={handleChange} /> void }) { - const [tempshowOpeningStatement, setTempShowOpeningStatement] = React.useState(!!introduction) + const [tempShowOpeningStatement, setTempShowOpeningStatement] = React.useState(!!introduction) useEffect(() => { // wait to api data back if (introduction) @@ -48,7 +48,7 @@ function useFeature({ // }, [moreLikeThis]) const featureConfig = { - openingStatement: tempshowOpeningStatement, + openingStatement: tempShowOpeningStatement, moreLikeThis, suggestedQuestionsAfterAnswer, speechToText, diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx index b8bedba20b..12551f508e 100644 --- a/web/app/components/app/configuration/config/index.tsx +++ b/web/app/components/app/configuration/config/index.tsx @@ -7,9 +7,9 @@ import { useBoolean, useScroll } from 'ahooks' import { useFormattingChangedDispatcher } from '../debug/hooks' import DatasetConfig from '../dataset-config' import ChatGroup from '../features/chat-group' -import ExperienceEnchanceGroup from '../features/experience-enchance-group' +import ExperienceEnhanceGroup from '../features/experience-enhance-group' import Toolbox from '../toolbox' -import HistoryPanel from '../config-prompt/conversation-histroy/history-panel' +import HistoryPanel from '../config-prompt/conversation-history/history-panel' import ConfigVision from '../config-vision' import useAnnotationConfig from '../toolbox/annotation/use-annotation-config' import AddFeatureBtn from './feature/add-feature-btn' @@ -254,7 +254,7 @@ const Config: FC = () => { /> )} - {/* ChatConifig */} + {/* ChatConfig */} { hasChatConfig && ( { {/* Text Generation config */}{ hasCompletionConfig && ( - diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 7f55649dab..91cae54bb8 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -1,20 +1,12 @@ 'use client' -import { memo, useMemo } from 'react' +import { memo, useEffect, useMemo } from 'react' import type { FC } from 'react' import { useTranslation } from 'react-i18next' -import { - RiAlertFill, -} from '@remixicon/react' import WeightedScore from './weighted-score' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' -import RadioCard from '@/app/components/base/radio-card/simple' import { RETRIEVE_TYPE } from '@/types/app' -import { - MultiPathRetrieval, - NTo1Retrieval, -} from '@/app/components/base/icons/src/public/common' import type { DatasetConfigs, } from '@/models/debug' @@ -31,7 +23,6 @@ import { RerankingModeEnum } from '@/models/datasets' import cn from '@/utils/classnames' import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import Switch from '@/app/components/base/switch' -import { useGetLanguage } from '@/context/i18n' type Props = { datasetConfigs: DatasetConfigs @@ -43,11 +34,6 @@ type Props = { selectedDatasets?: DataSet[] } -const LEGACY_LINK_MAP = { - en_US: 'https://docs.dify.ai/guides/knowledge-base/integrate-knowledge-within-application', - zh_Hans: 'https://docs.dify.ai/v/zh-hans/guides/knowledge-base/integrate_knowledge_within_application', -} as 
Record - const ConfigContent: FC = ({ datasetConfigs, onChange, @@ -58,15 +44,18 @@ const ConfigContent: FC = ({ selectedDatasets = [], }) => { const { t } = useTranslation() - const language = useGetLanguage() const selectedDatasetsMode = useSelectedDatasetsMode(selectedDatasets) const type = datasetConfigs.retrieval_model - const setType = (value: RETRIEVE_TYPE) => { - onChange({ - ...datasetConfigs, - retrieval_model: value, - }, true) - } + + useEffect(() => { + if (type === RETRIEVE_TYPE.oneWay) { + onChange({ + ...datasetConfigs, + retrieval_model: RETRIEVE_TYPE.multiWay, + }, isInWorkflow) + } + }, [type]) + const { modelList: rerankModelList, defaultModel: rerankDefaultModel, @@ -166,63 +155,21 @@ const ConfigContent: FC = ({ return (
{t('dataset.retrievalSettings')}
-
- } - title={( -
- {t('appDebug.datasetConfig.retrieveOneWay.title')} - - {t('dataset.nTo1RetrievalLegacy')} -
- )} - > -
legacy
- -
- )} - description={t('appDebug.datasetConfig.retrieveOneWay.description')} - isChosen={type === RETRIEVE_TYPE.oneWay} - onChosen={() => { setType(RETRIEVE_TYPE.oneWay) }} - extra={( -
- -
- {t('dataset.nTo1RetrievalLegacyLinkText')} - - {t('dataset.nTo1RetrievalLegacyLink')} - -
-
- )} - /> - } - title={t('appDebug.datasetConfig.retrieveMultiWay.title')} - description={t('appDebug.datasetConfig.retrieveMultiWay.description')} - isChosen={type === RETRIEVE_TYPE.multiWay} - onChosen={() => { setType(RETRIEVE_TYPE.multiWay) }} - /> +
+ {t('dataset.defaultRetrievalTip')}
{type === RETRIEVE_TYPE.multiWay && ( <> -
-
- {t('dataset.rerankSettings')} +
+
+ {t('dataset.rerankSettings')} +
+
{ selectedDatasetsMode.inconsistentEmbeddingModel && ( -
+
{t('dataset.inconsistentEmbeddingModelTip')}
) @@ -230,7 +177,7 @@ const ConfigContent: FC = ({ { selectedDatasetsMode.mixtureHighQualityAndEconomic && ( -
+
{t('dataset.mixtureHighQualityAndEconomicTip')}
) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 5cb76e32b2..656cbfea65 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -70,13 +70,13 @@ const ParamsConfig = ({ const { defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelVaild, + currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const isValid = () => { let errMsg = '' if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { - if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (!rerankDefaultModel && isRerankDefaultModelVaild)) + if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (!rerankDefaultModel && isRerankDefaultModelValid)) errMsg = t('appDebug.datasetConfig.rerankModelRequired') } if (errMsg) { diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index f9117a51c3..4493755ba0 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -135,7 +135,7 @@ const SelectDataSet: FC = ({
{item.name}
{!item.embedding_available && ( - {t('dataset.unavailable')} + {t('dataset.unavailable')} )}
{ diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index eec5979dd5..65858ce8cf 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -51,7 +51,7 @@ const SettingsModal: FC = ({ const { modelList: rerankModelList, defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelVaild, + currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { t } = useTranslation() const { notify } = useToastContext() @@ -83,7 +83,7 @@ const SettingsModal: FC = ({ if ( !isReRankModelSelected({ rerankDefaultModel, - isRerankDefaultModelVaild: !!isRerankDefaultModelVaild, + isRerankDefaultModelValid: !!isRerankDefaultModelValid, rerankModelList, retrievalConfig, indexMethod, diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index e79cdf4793..80dfb5c534 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -89,7 +89,7 @@ const ChatItem: FC = ({ `apps/${appId}/chat-messages`, data, { - onGetConvesationMessages: (conversationId, getAbortController) => fetchConversationMessages(appId, conversationId, getAbortController), + onGetConversationMessages: (conversationId, getAbortController) => fetchConversationMessages(appId, conversationId, getAbortController), onGetSuggestedQuestions: (responseItemId, getAbortController) => fetchSuggestedQuestions(appId, responseItemId, getAbortController), }, ) @@ -129,6 +129,7 @@ const ChatItem: FC = ({ questionIcon={} allToolIcons={allToolIcons} hideLogModal + noSpacing /> ) } diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index 892d0cfe8b..3d2f3bca59 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -128,6 +128,7 @@ const DebugWithMultipleModel = () => { onSend={handleSend} speechToTextConfig={speechToTextConfig} visionConfig={visionConfig} + noSpacing />
) diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index ea15c1a4ce..d93ad00659 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -94,7 +94,7 @@ const DebugWithSingleModel = forwardRef fetchConversationMessages(appId, conversationId, getAbortController), + onGetConversationMessages: (conversationId, getAbortController) => fetchConversationMessages(appId, conversationId, getAbortController), onGetSuggestedQuestions: (responseItemId, getAbortController) => fetchSuggestedQuestions(appId, responseItemId, getAbortController), }, ) @@ -130,6 +130,7 @@ const DebugWithSingleModel = forwardRef ) }) diff --git a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx index d007225bda..e652579cfc 100644 --- a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx @@ -227,7 +227,7 @@ const OpeningStatement: FC = ({ onClick={() => { setTempSuggestedQuestions([...tempSuggestedQuestions, '']) }} className='mt-1 flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100 hover:bg-gray-200'> -
{t('appDebug.variableConig.addOption')}
+        {t('appDebug.variableConfig.addOption')}
)}
@@ -287,7 +287,7 @@ const OpeningStatement: FC = ({ {isShowConfirmAddVar && ( diff --git a/web/app/components/app/configuration/features/experience-enchance-group/index.tsx b/web/app/components/app/configuration/features/experience-enhance-group/index.tsx similarity index 88% rename from web/app/components/app/configuration/features/experience-enchance-group/index.tsx rename to web/app/components/app/configuration/features/experience-enhance-group/index.tsx index 6902a17468..4a629a6b0e 100644 --- a/web/app/components/app/configuration/features/experience-enchance-group/index.tsx +++ b/web/app/components/app/configuration/features/experience-enhance-group/index.tsx @@ -16,7 +16,7 @@ type ExperienceGroupProps = { isShowMoreLike: boolean } -const ExperienceEnchanceGroup: FC = ({ +const ExperienceEnhanceGroup: FC = ({ isShowTextToSpeech, isShowMoreLike, }) => { @@ -40,4 +40,4 @@ const ExperienceEnchanceGroup: FC = ({ ) } -export default React.memo(ExperienceEnchanceGroup) +export default React.memo(ExperienceEnhanceGroup) diff --git a/web/app/components/app/configuration/features/experience-enchance-group/more-like-this/index.tsx b/web/app/components/app/configuration/features/experience-enhance-group/more-like-this/index.tsx similarity index 100% rename from web/app/components/app/configuration/features/experience-enchance-group/more-like-this/index.tsx rename to web/app/components/app/configuration/features/experience-enhance-group/more-like-this/index.tsx diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 432accb0d2..357dc84b7a 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -14,7 +14,7 @@ import Loading from '../../base/loading' import AppPublisher from '../app-publisher' import AgentSettingButton from './config/agent-setting-button' import useAdvancedPromptConfig from './hooks/use-advanced-prompt-config' -import EditHistoryModal from './config-prompt/conversation-histroy/edit-modal' +import EditHistoryModal from './config-prompt/conversation-history/edit-modal' import { useDebugWithSingleOrMultipleModel, useFormattingChangedDispatcher, diff --git a/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx b/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx index 64b2dd222a..589eb42ab3 100644 --- a/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx +++ b/web/app/components/app/configuration/toolbox/moderation/moderation-setting-modal.tsx @@ -64,7 +64,7 @@ const ModerationSettingModal: FC = ({ const systemOpenaiProviderQuota = systemOpenaiProviderEnabled ? 
openaiProvider?.system_configuration.quota_configurations.find(item => item.quota_type === openaiProvider.system_configuration.current_quota_type) : undefined const systemOpenaiProviderCanUse = systemOpenaiProviderQuota?.is_valid const customOpenaiProvidersCanUse = openaiProvider?.custom_configuration.status === CustomConfigurationStatusEnum.active - const openaiProviderConfiged = customOpenaiProvidersCanUse || systemOpenaiProviderCanUse + const isOpenAIProviderConfigured = customOpenaiProvidersCanUse || systemOpenaiProviderCanUse const providers: Provider[] = [ { key: 'openai_moderation', @@ -190,7 +190,7 @@ const ModerationSettingModal: FC = ({ } const handleSave = () => { - if (localeData.type === 'openai_moderation' && !openaiProviderConfiged) + if (localeData.type === 'openai_moderation' && !isOpenAIProviderConfigured) return if (!localeData.config?.inputs_config?.enabled && !localeData.config?.outputs_config?.enabled) { @@ -254,7 +254,7 @@ const ModerationSettingModal: FC = ({ className={` flex items-center px-3 py-2 rounded-lg text-sm text-gray-900 cursor-pointer ${localeData.type === provider.key ? 'bg-white border-[1.5px] border-primary-400 shadow-sm' : 'border border-gray-100 bg-gray-25'} - ${localeData.type === 'openai_moderation' && provider.key === 'openai_moderation' && !openaiProviderConfiged && 'opacity-50'} + ${localeData.type === 'openai_moderation' && provider.key === 'openai_moderation' && !isOpenAIProviderConfigured && 'opacity-50'} `} onClick={() => handleDataTypeChange(provider.key)} > @@ -267,7 +267,7 @@ const ModerationSettingModal: FC = ({ } { - !isLoading && !openaiProviderConfiged && localeData.type === 'openai_moderation' && ( + !isLoading && !isOpenAIProviderConfigured && localeData.type === 'openai_moderation' && (
@@ -361,7 +361,7 @@ const ModerationSettingModal: FC = ({ diff --git a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx index 3187990609..2785f435e4 100644 --- a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx +++ b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx @@ -172,12 +172,12 @@ const ExternalDataToolModal: FC = ({ } } - const formatedData = formatData(localeData) + const formattedData = formatData(localeData) - if (onValidateBeforeSave && !onValidateBeforeSave(formatedData)) + if (onValidateBeforeSave && !onValidateBeforeSave(formattedData)) return - onSave(formatData(formatedData)) + onSave(formatData(formattedData)) } const action = data.type ? t('common.operation.edit') : t('common.operation.add') @@ -189,7 +189,7 @@ const ExternalDataToolModal: FC = ({ className='!p-8 !pb-6 !max-w-none !w-[640px]' >
-        {`${action} ${t('appDebug.variableConig.apiBasedVar')}`}
+        {`${action} ${t('appDebug.variableConfig.apiBasedVar')}`}
diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 5fb0d7ef3f..d503e71918 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -21,7 +21,7 @@ import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import AppIcon from '@/app/components/base/app-icon' import AppsFull from '@/app/components/billing/apps-full-in-dialog' -import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import Tooltip from '@/app/components/base/tooltip' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' @@ -158,7 +158,7 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { setShowChatBotType(false) }} > - +
{t('app.types.agent')}
diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index dd6ebd08f0..77fcb6fcf9 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -119,8 +119,8 @@ const Logs: FC = ({ appDetail }) => { middlePagesSiblingCount={1} setCurrentPage={setCurrPage} totalPages={Math.ceil(total / APP_PAGE_LIMIT)} - truncableClassName="w-8 px-0.5 text-center" - truncableText="..." + truncatableClassName="w-8 px-0.5 text-center" + truncatableText="..." > => { + const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { try { await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) conversationDetailMutate() @@ -586,7 +586,7 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: Feedbacktype): Promise => { + const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { try { await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx index 0d0b95c0a0..f9f5c1fbff 100644 --- a/web/app/components/app/overview/appCard.tsx +++ b/web/app/components/app/overview/appCard.tsx @@ -134,8 +134,8 @@ function AppCard({ return (
@@ -176,7 +176,6 @@ function AppCard({ {isApp && } {/* button copy link/ button regenerate */} @@ -202,8 +201,8 @@ function AppCard({ onClick={() => setShowConfirmDelete(true)} >
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 02eff06c0a..a501d06ce4 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -43,6 +43,7 @@ export type ConfigParams = { icon: string icon_background?: string show_workflow_steps: boolean + use_icon_as_answer_icon: boolean enable_sso?: boolean } @@ -72,6 +73,7 @@ const SettingsModal: FC = ({ custom_disclaimer, default_language, show_workflow_steps, + use_icon_as_answer_icon, } = appInfo.site const [inputInfo, setInputInfo] = useState({ title, @@ -82,6 +84,7 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + use_icon_as_answer_icon, enable_sso: appInfo.enable_sso, }) const [language, setLanguage] = useState(default_language) @@ -94,6 +97,7 @@ const SettingsModal: FC = ({ ? { type: 'image', url: icon_url!, fileId: icon } : { type: 'emoji', icon, background: icon_background! }, ) + const isChatBot = appInfo.mode === 'chat' || appInfo.mode === 'advanced-chat' || appInfo.mode === 'agent-chat' useEffect(() => { setInputInfo({ @@ -105,6 +109,7 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + use_icon_as_answer_icon, enable_sso: appInfo.enable_sso, }) setLanguage(default_language) @@ -157,6 +162,7 @@ const SettingsModal: FC = ({ icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, show_workflow_steps: inputInfo.show_workflow_steps, + use_icon_as_answer_icon: inputInfo.use_icon_as_answer_icon, enable_sso: inputInfo.enable_sso, } await onSave?.(params) @@ -209,6 +215,18 @@ const SettingsModal: FC = ({ onChange={onChange('desc')} placeholder={t(`${prefixSettings}.webDescPlaceholder`) as string} /> + {isChatBot && ( +
+
+
{t('app.answerIcon.title')}
+ setInputInfo({ ...inputInfo, use_icon_as_answer_icon: v })} + /> +
+

{t('app.answerIcon.description')}

+
+ )}
{t(`${prefixSettings}.language`)}
item.supported)} diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 9794967d9d..e3bd8eadc5 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -16,13 +16,13 @@ import { Markdown } from '@/app/components/base/markdown' import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import AudioBtn from '@/app/components/base/audio-btn' -import type { Feedbacktype } from '@/app/components/base/chat/chat/type' +import type { FeedbackType } from '@/app/components/base/chat/chat/type' import { fetchMoreLikeThis, updateFeedback } from '@/service/share' import { File02 } from '@/app/components/base/icons/src/vender/line/files' import { Bookmark } from '@/app/components/base/icons/src/vender/line/general' import { Stars02 } from '@/app/components/base/icons/src/vender/line/weather' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' -import { fetchTextGenerationMessge } from '@/service/debug' +import { fetchTextGenerationMessage } from '@/service/debug' import AnnotationCtrlBtn from '@/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn' import EditReplyModal from '@/app/components/app/annotation/edit-annotation-modal' import { useStore as useAppStore } from '@/app/components/app/store' @@ -47,8 +47,8 @@ export type IGenerationItemProps = { isInWebApp?: boolean moreLikeThis?: boolean depth?: number - feedback?: Feedbacktype - onFeedback?: (feedback: Feedbacktype) => void + feedback?: FeedbackType + onFeedback?: (feedback: FeedbackType) => void onSave?: (messageId: string) => void isMobile?: boolean isInstalledApp: boolean @@ -125,7 +125,7 @@ const GenerationItem: FC = ({ const [completionRes, setCompletionRes] = useState('') const [childMessageId, setChildMessageId] = useState(null) const hasChild = !!childMessageId - const [childFeedback, setChildFeedback] = useState({ + const [childFeedback, setChildFeedback] = useState({ rating: null, }) const { @@ -135,7 +135,7 @@ const GenerationItem: FC = ({ const setCurrentLogItem = useAppStore(s => s.setCurrentLogItem) const setShowPromptLogModal = useAppStore(s => s.setShowPromptLogModal) - const handleFeedback = async (childFeedback: Feedbacktype) => { + const handleFeedback = async (childFeedback: FeedbackType) => { await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, isInstalledApp, installedAppId) setChildFeedback(childFeedback) } @@ -205,7 +205,7 @@ const GenerationItem: FC = ({ }, [isLoading]) const handleOpenLogModal = async () => { - const data = await fetchTextGenerationMessge({ + const data = await fetchTextGenerationMessage({ appId: params.appId as string, messageId: messageId!, }) diff --git a/web/app/components/app/type-selector/index.tsx b/web/app/components/app/type-selector/index.tsx index 2bd4f8d082..a09e189f50 100644 --- a/web/app/components/app/type-selector/index.tsx +++ b/web/app/components/app/type-selector/index.tsx @@ -9,7 +9,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { Check, DotsGrid } from '@/app/components/base/icons/src/vender/line/general' import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' -import { ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' 
import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' export type AppSelectorProps = { value: string @@ -65,7 +65,7 @@ const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => { {value === 'agent' && ( <>
- +
{t('app.typeSelector.agent')}
{ @@ -106,7 +106,7 @@ const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => { onChange('agent') setOpen(false) }}> - +
{t('app.typeSelector.agent')}
{value === 'agent' && }
diff --git a/web/app/components/app/workflow-log/index.tsx b/web/app/components/app/workflow-log/index.tsx index 303c2069b2..b7e20f0a0b 100644 --- a/web/app/components/app/workflow-log/index.tsx +++ b/web/app/components/app/workflow-log/index.tsx @@ -93,8 +93,8 @@ const Logs: FC = ({ appDetail }) => { middlePagesSiblingCount={1} setCurrentPage={setCurrPage} totalPages={Math.ceil(total / APP_PAGE_LIMIT)} - truncableClassName="w-8 px-0.5 text-center" - truncableText="..." + truncatableClassName="w-8 px-0.5 text-center" + truncatableText="..." > = ({ onClick={() => switchTab('TRACING')} >{t('runLog.tracing')}
-          {/* panel detal */}
+          {/* panel detail */}
          {loading && (
{loading && (
diff --git a/web/app/components/base/answer-icon/index.tsx b/web/app/components/base/answer-icon/index.tsx new file mode 100644 index 0000000000..8c6363e05c --- /dev/null +++ b/web/app/components/base/answer-icon/index.tsx @@ -0,0 +1,47 @@ +'use client' + +import type { FC } from 'react' +import { init } from 'emoji-mart' +import data from '@emoji-mart/data' +import classNames from '@/utils/classnames' +import type { AppIconType } from '@/types/app' + +init({ data }) + +export type AnswerIconProps = { + iconType?: AppIconType | null + icon?: string | null + background?: string | null + imageUrl?: string | null +} + +const AnswerIcon: FC = ({ + iconType, + icon, + background, + imageUrl, +}) => { + const wrapperClassName = classNames( + 'flex', + 'items-center', + 'justify-center', + 'w-full', + 'h-full', + 'rounded-full', + 'border-[0.5px]', + 'border-black/5', + 'text-xl', + ) + const isValidImageIcon = iconType === 'image' && imageUrl + return
+ {isValidImageIcon + ? answer icon + : (icon && icon !== '') ? : + } +
+} + +export default AnswerIcon diff --git a/web/app/components/base/app-unavailable.tsx b/web/app/components/base/app-unavailable.tsx index c5ea7be0b9..b8b42108a9 100644 --- a/web/app/components/base/app-unavailable.tsx +++ b/web/app/components/base/app-unavailable.tsx @@ -22,7 +22,7 @@ const AppUnavailable: FC = ({ style={{ borderRight: '1px solid rgba(0,0,0,.3)', }}>{code} -
{unknownReason || (isUnknownReason ? t('share.common.appUnkonwError') : t('share.common.appUnavailable'))}
+      {unknownReason || (isUnknownReason ? t('share.common.appUnknownError') : t('share.common.appUnavailable'))}
) } diff --git a/web/app/components/base/audio-gallery/AudioPlayer.module.css b/web/app/components/base/audio-gallery/AudioPlayer.module.css new file mode 100644 index 0000000000..6c070e107c --- /dev/null +++ b/web/app/components/base/audio-gallery/AudioPlayer.module.css @@ -0,0 +1,119 @@ +.audioPlayer { + display: flex; + flex-direction: row; + align-items: center; + background-color: #ffffff; + border-radius: 10px; + padding: 8px; + min-width: 240px; + max-width: 420px; + max-height: 40px; + backdrop-filter: blur(5px); + border: 1px solid rgba(16, 24, 40, 0.08); + box-shadow: 0 1px 2px rgba(9, 9, 11, 0.05); + gap: 8px; +} + +.playButton { + display: inline-flex; + width: 16px; + height: 16px; + border-radius: 50%; + background-color: #296DFF; + color: white; + border: none; + cursor: pointer; + align-items: center; + justify-content: center; + transition: background-color 0.1s; + flex-shrink: 0; +} + +.playButton:hover { + background-color: #3367d6; +} + +.playButton:disabled { + background-color: #bdbdbf; +} + +.audioControls { + flex-grow: 1; + +} + +.progressBarContainer { + height: 32px; + display: flex; + align-items: center; + justify-content: center; +} + +.waveform { + position: relative; + display: flex; + cursor: pointer; + height: 24px; + width: 100%; + flex-grow: 1; + align-items: center; + justify-content: center; +} + +.progressBar { + position: absolute; + top: 0; + left: 0; + opacity: 0.5; + border-radius: 2px; + flex: none; + order: 55; + flex-grow: 0; + height: 100%; + background-color: rgba(66, 133, 244, 0.3); + pointer-events: none; +} + +.timeDisplay { + /* position: absolute; */ + color: #296DFF; + border-radius: 2px; + order: 0; + height: 100%; + width: 50px; + display: inline-flex; + align-items: center; + justify-content: center; +} + +/* .currentTime { + position: absolute; + bottom: calc(100% + 5px); + transform: translateX(-50%); + background-color: rgba(255,255,255,.8); + padding: 2px 4px; + border-radius:10px; + box-shadow: 0 1px 5px rgba(0, 0, 0, 0.08); +} */ + +.duration { + background-color: rgba(255, 255, 255, 0.8); + padding: 2px 4px; + border-radius: 10px; +} + +.source_unavailable { + border: none; + display: flex; + align-items: center; + justify-content: center; + width: 100%; + height: 100%; + position: absolute; + color: #bdbdbf; +} + +.playButton svg path, +.playButton svg rect{ + fill:currentColor; +} diff --git a/web/app/components/base/audio-gallery/AudioPlayer.tsx b/web/app/components/base/audio-gallery/AudioPlayer.tsx new file mode 100644 index 0000000000..c482981e8a --- /dev/null +++ b/web/app/components/base/audio-gallery/AudioPlayer.tsx @@ -0,0 +1,320 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react' +import { t } from 'i18next' +import styles from './AudioPlayer.module.css' +import Toast from '@/app/components/base/toast' + +type AudioPlayerProps = { + src: string +} + +const AudioPlayer: React.FC = ({ src }) => { + const [isPlaying, setIsPlaying] = useState(false) + const [currentTime, setCurrentTime] = useState(0) + const [duration, setDuration] = useState(0) + const [waveformData, setWaveformData] = useState([]) + const [bufferedTime, setBufferedTime] = useState(0) + const audioRef = useRef(null) + const canvasRef = useRef(null) + const [hasStartedPlaying, setHasStartedPlaying] = useState(false) + const [hoverTime, setHoverTime] = useState(0) + const [isAudioAvailable, setIsAudioAvailable] = useState(true) + + useEffect(() => { + const audio = audioRef.current + if (!audio) + return + + const handleError = () 
=> { + setIsAudioAvailable(false) + } + + const setAudioData = () => { + setDuration(audio.duration) + } + + const setAudioTime = () => { + setCurrentTime(audio.currentTime) + } + + const handleProgress = () => { + if (audio.buffered.length > 0) + setBufferedTime(audio.buffered.end(audio.buffered.length - 1)) + } + + const handleEnded = () => { + setIsPlaying(false) + } + + audio.addEventListener('loadedmetadata', setAudioData) + audio.addEventListener('timeupdate', setAudioTime) + audio.addEventListener('progress', handleProgress) + audio.addEventListener('ended', handleEnded) + audio.addEventListener('error', handleError) + + // Preload audio metadata + audio.load() + + // Delayed generation of waveform data + // eslint-disable-next-line @typescript-eslint/no-use-before-define + const timer = setTimeout(() => generateWaveformData(src), 1000) + + return () => { + audio.removeEventListener('loadedmetadata', setAudioData) + audio.removeEventListener('timeupdate', setAudioTime) + audio.removeEventListener('progress', handleProgress) + audio.removeEventListener('ended', handleEnded) + audio.removeEventListener('error', handleError) + clearTimeout(timer) + } + }, [src]) + + const generateWaveformData = async (audioSrc: string) => { + if (!window.AudioContext && !(window as any).webkitAudioContext) { + setIsAudioAvailable(false) + Toast.notify({ + type: 'error', + message: 'Web Audio API is not supported in this browser', + }) + return null + } + + const url = new URL(src) + const isHttp = url.protocol === 'http:' || url.protocol === 'https:' + if (!isHttp) { + setIsAudioAvailable(false) + return null + } + + const audioContext = new (window.AudioContext || (window as any).webkitAudioContext)() + const samples = 70 + + try { + const response = await fetch(audioSrc, { mode: 'cors' }) + if (!response || !response.ok) { + setIsAudioAvailable(false) + return null + } + + const arrayBuffer = await response.arrayBuffer() + const audioBuffer = await audioContext.decodeAudioData(arrayBuffer) + const channelData = audioBuffer.getChannelData(0) + const blockSize = Math.floor(channelData.length / samples) + const waveformData: number[] = [] + + for (let i = 0; i < samples; i++) { + let sum = 0 + for (let j = 0; j < blockSize; j++) + sum += Math.abs(channelData[i * blockSize + j]) + + // Apply nonlinear scaling to enhance small amplitudes + waveformData.push((sum / blockSize) * 5) + } + + // Normalized waveform data + const maxAmplitude = Math.max(...waveformData) + const normalizedWaveform = waveformData.map(amp => amp / maxAmplitude) + + setWaveformData(normalizedWaveform) + setIsAudioAvailable(true) + } + catch (error) { + const waveform: number[] = [] + let prevValue = Math.random() + + for (let i = 0; i < samples; i++) { + const targetValue = Math.random() + const interpolatedValue = prevValue + (targetValue - prevValue) * 0.3 + waveform.push(interpolatedValue) + prevValue = interpolatedValue + } + + const maxAmplitude = Math.max(...waveform) + const randomWaveform = waveform.map(amp => amp / maxAmplitude) + + setWaveformData(randomWaveform) + setIsAudioAvailable(true) + } + finally { + await audioContext.close() + } + } + + const togglePlay = useCallback(() => { + const audio = audioRef.current + if (audio && isAudioAvailable) { + if (isPlaying) { + setHasStartedPlaying(false) + audio.pause() + } + else { + setHasStartedPlaying(true) + audio.play().catch(error => console.error('Error playing audio:', error)) + } + + setIsPlaying(!isPlaying) + } + else { + Toast.notify({ + type: 'error', + message: 
'Audio element not found', + }) + setIsAudioAvailable(false) + } + }, [isAudioAvailable, isPlaying]) + + const handleCanvasInteraction = useCallback((e: React.MouseEvent | React.TouchEvent) => { + e.preventDefault() + + const getClientX = (event: React.MouseEvent | React.TouchEvent): number => { + if ('touches' in event) + return event.touches[0].clientX + return event.clientX + } + + const updateProgress = (clientX: number) => { + const canvas = canvasRef.current + const audio = audioRef.current + if (!canvas || !audio) + return + + const rect = canvas.getBoundingClientRect() + const percent = Math.min(Math.max(0, clientX - rect.left), rect.width) / rect.width + const newTime = percent * duration + + // Removes the buffer check, allowing drag to any location + audio.currentTime = newTime + setCurrentTime(newTime) + + if (!isPlaying) { + setIsPlaying(true) + audio.play().catch((error) => { + Toast.notify({ + type: 'error', + message: `Error playing audio: ${error}`, + }) + setIsPlaying(false) + }) + } + } + + updateProgress(getClientX(e)) + }, [duration, isPlaying]) + + const formatTime = (time: number) => { + const minutes = Math.floor(time / 60) + const seconds = Math.floor(time % 60) + return `${minutes}:${seconds.toString().padStart(2, '0')}` + } + + const drawWaveform = useCallback(() => { + const canvas = canvasRef.current + if (!canvas) + return + + const ctx = canvas.getContext('2d') + if (!ctx) + return + + const width = canvas.width + const height = canvas.height + const data = waveformData + + ctx.clearRect(0, 0, width, height) + + const barWidth = width / data.length + const playedWidth = (currentTime / duration) * width + const cornerRadius = 2 + + // Draw waveform bars + data.forEach((value, index) => { + let color + + if (index * barWidth <= playedWidth) + color = '#296DFF' + else if ((index * barWidth / width) * duration <= hoverTime) + color = 'rgba(21,90,239,.40)' + else + color = 'rgba(21,90,239,.20)' + + const barHeight = value * height + const rectX = index * barWidth + const rectY = (height - barHeight) / 2 + const rectWidth = barWidth * 0.5 + const rectHeight = barHeight + + ctx.lineWidth = 1 + ctx.fillStyle = color + if (ctx.roundRect) { + ctx.beginPath() + ctx.roundRect(rectX, rectY, rectWidth, rectHeight, cornerRadius) + ctx.fill() + } + else { + ctx.fillRect(rectX, rectY, rectWidth, rectHeight) + } + }) + }, [currentTime, duration, hoverTime, waveformData]) + + useEffect(() => { + drawWaveform() + }, [drawWaveform, bufferedTime, hasStartedPlaying]) + + const handleMouseMove = useCallback((e: React.MouseEvent) => { + const canvas = canvasRef.current + const audio = audioRef.current + if (!canvas || !audio) + return + + const rect = canvas.getBoundingClientRect() + const percent = Math.min(Math.max(0, e.clientX - rect.left), rect.width) / rect.width + const time = percent * duration + + // Check if the hovered position is within a buffered range before updating hoverTime + for (let i = 0; i < audio.buffered.length; i++) { + if (time >= audio.buffered.start(i) && time <= audio.buffered.end(i)) { + setHoverTime(time) + break + } + } + }, [duration]) + + return ( +
+
+ ) +} + +export default AudioPlayer diff --git a/web/app/components/base/audio-gallery/index.tsx b/web/app/components/base/audio-gallery/index.tsx new file mode 100644 index 0000000000..6e11d43164 --- /dev/null +++ b/web/app/components/base/audio-gallery/index.tsx @@ -0,0 +1,12 @@ +import React from 'react' +import AudioPlayer from './AudioPlayer' + +type Props = { + srcs: string[] +} + +const AudioGallery: React.FC = ({ srcs }) => { + return (<>
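
Note on the waveform generation above: generateWaveformData decodes the fetched audio, block-averages the absolute sample amplitudes into 70 buckets, then divides by the peak so the tallest bar spans the canvas. Incidentally, the x5 boost applied before normalization cancels out once values are divided by the peak, so the sketch below omits it. A minimal standalone sketch of the downsampling step (helper name is illustrative, not part of the component):

// Sketch: reduce a decoded channel to `samples` normalized bars in [0, 1].
function downsampleToBars(channelData: Float32Array, samples = 70): number[] {
  const blockSize = Math.floor(channelData.length / samples)
  const bars: number[] = []
  for (let i = 0; i < samples; i++) {
    let sum = 0
    for (let j = 0; j < blockSize; j++)
      sum += Math.abs(channelData[i * blockSize + j]) // mean absolute amplitude per block
    bars.push(sum / blockSize)
  }
  const peak = Math.max(...bars)
  return peak > 0 ? bars.map(v => v / peak) : bars // normalize by the peak
}
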
{srcs.map((src, index) => ())}) +} + +export default React.memo(AudioGallery) diff --git a/web/app/components/base/block-input/index.tsx b/web/app/components/base/block-input/index.tsx index 79ff646bd1..43c14de4c9 100644 --- a/web/app/components/base/block-input/index.tsx +++ b/web/app/components/base/block-input/index.tsx @@ -53,7 +53,7 @@ const BlockInput: FC = ({ const [isEditing, setIsEditing] = useState(false) useEffect(() => { if (isEditing && contentEditableRef.current) { - // TODO: Focus at the click positon + // TODO: Focus at the click position if (currentValue) contentEditableRef.current.setSelectionRange(currentValue.length, currentValue.length) @@ -119,7 +119,7 @@ const BlockInput: FC = ({ onBlur={() => { blur() setIsEditing(false) - // click confirm also make blur. Then outter value is change. So below code has problem. + // click confirm also make blur. Then outer value is change. So below code has problem. // setTimeout(() => { // handleCancel() // }, 1000) diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 8eda66c52a..5a7bf1f17e 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -13,6 +13,7 @@ import { getUrl, stopChatMessageResponding, } from '@/service/share' +import AnswerIcon from '@/app/components/base/answer-icon' const ChatWrapper = () => { const { @@ -128,21 +129,31 @@ const ChatWrapper = () => { isMobile, ]) + const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon) + ? + : null + return ( diff --git a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx index 05f253290f..c864a3925d 100644 --- a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx +++ b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx @@ -152,7 +152,7 @@ const ConfigPanel = () => { : (
- {t('share.chat.powerBy')} + {t('share.chat.poweredBy')} { customConfig?.replace_webapp_logo ? logo diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 624cc53a18..1e05cc39ef 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -65,6 +65,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { prompt_public: false, copyright: '', show_workflow_steps: true, + use_icon_as_answer_icon: app.use_icon_as_answer_icon, }, plan: 'basic', } as AppData @@ -216,12 +217,12 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [newConversation]) const currentConversationItem = useMemo(() => { - let coversationItem = conversationList.find(item => item.id === currentConversationId) + let conversationItem = conversationList.find(item => item.id === currentConversationId) - if (!coversationItem && pinnedConversationList.length) - coversationItem = pinnedConversationList.find(item => item.id === currentConversationId) + if (!conversationItem && pinnedConversationList.length) + conversationItem = pinnedConversationList.find(item => item.id === currentConversationId) - return coversationItem + return conversationItem }, [conversationList, currentConversationId, pinnedConversationList]) const { notify } = useToastContext() diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 78a0842595..270cd553c2 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -22,6 +22,7 @@ import Citation from '@/app/components/base/chat/chat/citation' import { EditTitle } from '@/app/components/app/annotation/edit-annotation-modal/edit-item' import type { Emoji } from '@/app/components/tools/types' import type { AppData } from '@/models/share' +import AnswerIcon from '@/app/components/base/answer-icon' type AnswerProps = { item: ChatItem @@ -89,11 +90,7 @@ const Answer: FC = ({
{ - answerIcon || ( -
- 🤖 -
- ) + answerIcon || } { responding && ( diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 5f36e40c40..1f17798f83 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -11,10 +11,10 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import type { ChatItem, WorkflowProcess } from '../../types' +import TracingPanel from '@/app/components/workflow/run/tracing-panel' import cn from '@/utils/classnames' import { CheckCircle } from '@/app/components/base/icons/src/vender/solid/general' import { WorkflowRunningStatus } from '@/app/components/workflow/types' -import NodePanel from '@/app/components/workflow/run/node' import { useStore as useAppStore } from '@/app/components/app/store' type WorkflowProcessProps = { @@ -107,16 +107,12 @@ const WorkflowProcessItem = ({ !collapse && (
{ - data.tracing.map(node => ( -
- -
- )) + }
) diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index c4578fab62..fdb09dc3ae 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -32,18 +32,21 @@ import { useDraggableUploader, useImageFiles, } from '@/app/components/base/image-uploader/hooks' +import cn from '@/utils/classnames' type ChatInputProps = { visionConfig?: VisionConfig speechToTextConfig?: EnableType onSend?: OnSend theme?: Theme | null + noSpacing?: boolean } const ChatInput: FC = ({ visionConfig, speechToTextConfig, onSend, theme, + noSpacing, }) => { const { appData } = useChatWithHistoryContext() const { t } = useTranslation() @@ -146,7 +149,7 @@ const ChatInput: FC = ({ return ( <> -
+
= ({ { visionConfig?.enabled && ( <> -
+
= ({ onDrop={onDrop} autoSize /> -
+
{query.trim().length}
diff --git a/web/app/components/base/chat/chat/citation/index.tsx b/web/app/components/base/chat/chat/citation/index.tsx index 4bed9638d3..2ca7b80ae7 100644 --- a/web/app/components/base/chat/chat/citation/index.tsx +++ b/web/app/components/base/chat/chat/citation/index.tsx @@ -24,7 +24,7 @@ const Citation: FC = ({ }) => { const { t } = useTranslation() const elesRef = useRef([]) - const [limitNumberInOneLine, setlimitNumberInOneLine] = useState(0) + const [limitNumberInOneLine, setLimitNumberInOneLine] = useState(0) const [showMore, setShowMore] = useState(false) const resources = useMemo(() => data.reduce((prev: Resources[], next) => { const documentId = next.document_id @@ -57,14 +57,14 @@ const Citation: FC = ({ totalWidth -= elesRef.current[i].clientWidth if (totalWidth + 34 > containerWidth!) - setlimitNumberInOneLine(i - 1) + setLimitNumberInOneLine(i - 1) else - setlimitNumberInOneLine(i) + setLimitNumberInOneLine(i) break } else { - setlimitNumberInOneLine(i + 1) + setLimitNumberInOneLine(i + 1) } } } diff --git a/web/app/components/base/chat/chat/citation/popup.tsx b/web/app/components/base/chat/chat/citation/popup.tsx index d039d98e6b..b61bf623fe 100644 --- a/web/app/components/base/chat/chat/citation/popup.tsx +++ b/web/app/components/base/chat/chat/citation/popup.tsx @@ -53,72 +53,74 @@ const Popup: FC = ({
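
For context on the setLimitNumberInOneLine rename above: the effect sums the rendered widths of citation chips until the container overflows, reserving roughly 34px for the show-more control at the end of the line. A simplified sketch of that measurement (chip widths assumed pre-measured via clientWidth; the constant mirrors the component):

// Sketch: how many citation chips fit on one line before truncation.
function countChipsThatFit(chipWidths: number[], containerWidth: number): number {
  let total = 0
  for (let i = 0; i < chipWidths.length; i++) {
    total += chipWidths[i]
    if (total > containerWidth) {
      total -= chipWidths[i]
      // if the remaining space cannot also host the show-more control, drop one more chip
      return total + 34 > containerWidth ? i - 1 : i
    }
  }
  return chipWidths.length // everything fits
}
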
-
+
{data.documentName}
-
- { - data.sources.map((source, index) => ( - -
-
-
- -
- {source.segment_position || index + 1} +
+
+ { + data.sources.map((source, index) => ( + +
+
+
+ +
+ {source.segment_position || index + 1} +
+ { + showHitInfo && ( + + {t('common.chat.citation.linkToDataset')} + + + ) + }
+
{source.content}
{ showHitInfo && ( - - {t('common.chat.citation.linkToDataset')} - - +
+ } + /> + } + /> + } + /> + { + source.score && ( + + ) + } +
) }
-
{source.content}
{ - showHitInfo && ( -
- } - /> - } - /> - } - /> - { - source.score && ( - - ) - } -
+ index !== data.sources.length - 1 && ( +
) } -
- { - index !== data.sources.length - 1 && ( -
- ) - } - - )) - } + + )) + } +
diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index a70b1c2ed0..892f88c4ad 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -26,7 +26,7 @@ import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player type GetAbortController = (abortController: AbortController) => void type SendCallback = { - onGetConvesationMessages?: (conversationId: string, getAbortController: GetAbortController) => Promise + onGetConversationMessages?: (conversationId: string, getAbortController: GetAbortController) => Promise onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise onConversationComplete?: (conversationId: string) => void isPublicAPI?: boolean @@ -198,7 +198,7 @@ export const useChat = ( url: string, data: any, { - onGetConvesationMessages, + onGetConversationMessages, onGetSuggestedQuestions, onConversationComplete, isPublicAPI, @@ -241,8 +241,6 @@ export const useChat = ( isAnswer: true, } - let isInIteration = false - handleResponding(true) hasStopResponded.current = false @@ -324,8 +322,8 @@ export const useChat = ( if (onConversationComplete) onConversationComplete(conversationId.current) - if (conversationId.current && !hasStopResponded.current && onGetConvesationMessages) { - const { data }: any = await onGetConvesationMessages( + if (conversationId.current && !hasStopResponded.current && onGetConversationMessages) { + const { data }: any = await onGetConversationMessages( conversationId.current, newAbortController => conversationMessagesAbortControllerRef.current = newAbortController, ) @@ -372,11 +370,16 @@ export const useChat = ( handleUpdateChatList(newChatList) } if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) { - const { data }: any = await onGetSuggestedQuestions( - responseItem.id, - newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, - ) - setSuggestQuestions(data) + try { + const { data }: any = await onGetSuggestedQuestions( + responseItem.id, + newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, + ) + setSuggestQuestions(data) + } + catch (e) { + setSuggestQuestions([]) + } } }, onFile(file) { @@ -498,12 +501,13 @@ export const useChat = ( ...responseItem, } })) - isInIteration = true }, onIterationFinish: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! - tracing[tracing.length - 1] = { - ...tracing[tracing.length - 1], + const iterationIndex = tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! 
+ tracing[iterationIndex] = { + ...tracing[iterationIndex], ...data, status: WorkflowRunningStatus.Succeeded, } as any @@ -515,10 +519,9 @@ export const useChat = ( ...responseItem, } })) - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return responseItem.workflowProcess!.tracing!.push({ @@ -534,10 +537,15 @@ export const useChat = ( })) }, onNodeFinished: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return - const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id) + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata.parallel_id) + }) responseItem.workflowProcess!.tracing[currentIndex] = data as any handleUpdateChatList(produce(chatListRef.current, (draft) => { const currentIndex = draft.findIndex(item => item.id === responseItem.id) diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index c5d9af45c2..65e49eff67 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -60,6 +60,7 @@ export type ChatProps = { hideProcessDetail?: boolean hideLogModal?: boolean themeBuilder?: ThemeBuilder + noSpacing?: boolean } const Chat: FC = ({ @@ -89,6 +90,7 @@ const Chat: FC = ({ hideProcessDetail, hideLogModal, themeBuilder, + noSpacing, }) => { const { t } = useTranslation() const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({ @@ -106,7 +108,7 @@ const Chat: FC = ({ const chatFooterInnerRef = useRef(null) const userScrolledRef = useRef(false) - const handleScrolltoBottom = useCallback(() => { + const handleScrollToBottom = useCallback(() => { if (chatContainerRef.current && !userScrolledRef.current) chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight }, []) @@ -123,14 +125,14 @@ const Chat: FC = ({ }, []) useEffect(() => { - handleScrolltoBottom() + handleScrollToBottom() handleWindowResize() - }, [handleScrolltoBottom, handleWindowResize]) + }, [handleScrollToBottom, handleWindowResize]) useEffect(() => { if (chatContainerRef.current) { requestAnimationFrame(() => { - handleScrolltoBottom() + handleScrollToBottom() handleWindowResize() }) } @@ -148,7 +150,7 @@ const Chat: FC = ({ const { blockSize } = entry.borderBoxSize[0] chatContainerRef.current!.style.paddingBottom = `${blockSize}px` - handleScrolltoBottom() + handleScrollToBottom() } }) @@ -158,7 +160,7 @@ const Chat: FC = ({ resizeObserver.disconnect() } } - }, [handleScrolltoBottom]) + }, [handleScrollToBottom]) useEffect(() => { const chatContainer = chatContainerRef.current @@ -192,12 +194,12 @@ const Chat: FC = ({
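
Note on the useChat change above: the removed isInIteration flag assumed the entry to update was always the last one pushed, which breaks once parallel branches interleave their events. The new lookup keys on node_id plus the parallel_id carried in execution_metadata. A reduced sketch of that matching (the tracing type is narrowed to the fields compared):

type TraceEntry = {
  node_id: string
  parallel_id?: string
  execution_metadata?: { parallel_id?: string }
}

// Sketch: locate the tracing entry a finish event belongs to.
// Entries without parallel metadata fall back to matching on node_id alone.
function findTraceIndex(tracing: TraceEntry[], data: TraceEntry): number {
  return tracing.findIndex((item) => {
    if (!item.execution_metadata?.parallel_id)
      return item.node_id === data.node_id
    return item.node_id === data.node_id
      && item.execution_metadata.parallel_id === data.execution_metadata?.parallel_id
  })
}
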
{chatNode}
{ chatList.map((item, index) => { @@ -268,6 +270,7 @@ const Chat: FC = ({ speechToTextConfig={config?.speech_to_text} onSend={onSend} theme={themeBuilder?.theme} + noSpacing={noSpacing} /> ) } diff --git a/web/app/components/base/chat/chat/loading-anim/index.tsx b/web/app/components/base/chat/chat/loading-anim/index.tsx index 09f8a54789..dd43ef9c14 100644 --- a/web/app/components/base/chat/chat/loading-anim/index.tsx +++ b/web/app/components/base/chat/chat/loading-anim/index.tsx @@ -3,15 +3,15 @@ import type { FC } from 'react' import React from 'react' import s from './style.module.css' -export type ILoaidingAnimProps = { +export type ILoadingAnimProps = { type: 'text' | 'avatar' } -const LoaidingAnim: FC = ({ +const LoadingAnim: FC = ({ type, }) => { return (
) } -export default React.memo(LoaidingAnim) +export default React.memo(LoadingAnim) diff --git a/web/app/components/base/chat/chat/type.ts b/web/app/components/base/chat/chat/type.ts index 16ccff4d4d..b2cb18011c 100644 --- a/web/app/components/base/chat/chat/type.ts +++ b/web/app/components/base/chat/chat/type.ts @@ -8,14 +8,14 @@ export type MessageMore = { latency: number | string } -export type Feedbacktype = { +export type FeedbackType = { rating: MessageRating content?: string | null } export type FeedbackFunc = ( messageId: string, - feedback: Feedbacktype + feedback: FeedbackType ) => Promise export type SubmitAnnotationFunc = ( messageId: string, @@ -71,11 +71,11 @@ export type IChatItem = { /** * The user feedback result of this message */ - feedback?: Feedbacktype + feedback?: FeedbackType /** * The admin feedback result of this message */ - adminFeedback?: Feedbacktype + adminFeedback?: FeedbackType /** * Whether to hide the feedback area */ diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index 68646a17db..48ee411058 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -14,7 +14,8 @@ import { getUrl, stopChatMessageResponding, } from '@/service/share' -import LogoAvatar from '@/app/components/base/logo/logo-embeded-chat-avatar' +import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar' +import AnswerIcon from '@/app/components/base/answer-icon' const ChatWrapper = () => { const { @@ -98,7 +99,7 @@ const ChatWrapper = () => { return ( <> {!currentConversationId && ( -
+
{ return null }, [currentConversationId, inputsForms, isMobile]) + const answerIcon = isDify() + ? + : (appData?.site && appData.site.use_icon_as_answer_icon) + ? + : null + return ( : null} + answerIcon={answerIcon} hideProcessDetail themeBuilder={themeBuilder} /> diff --git a/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx b/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx index df5d12ef14..2cc46cadf8 100644 --- a/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/config-panel/index.tsx @@ -160,7 +160,7 @@ const ConfigPanel = () => { : (
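
The answer-icon decision introduced above reduces to a three-way precedence: the Dify avatar on Dify's own host, the app icon when the site opts in via use_icon_as_answer_icon, and null to keep the built-in robot fallback. A plain-TypeScript sketch of that ordering (the site field name comes from the hook; the return labels are illustrative):

type Site = { use_icon_as_answer_icon?: boolean }

// Sketch: which avatar the answer bubble gets in the embedded chatbot.
function pickAnswerIcon(isDifyHost: boolean, site?: Site): 'dify-logo' | 'app-icon' | null {
  if (isDifyHost)
    return 'dify-logo' // Dify's own deployment keeps the brand avatar
  return site?.use_icon_as_answer_icon ? 'app-icon' : null // null -> default robot
}
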
- {t('share.chat.powerBy')} + {t('share.chat.poweredBy')} { customConfig?.replace_webapp_logo ? logo diff --git a/web/app/components/base/chat/embedded-chatbot/index.tsx b/web/app/components/base/chat/embedded-chatbot/index.tsx index 480adaae2d..407c0de6d8 100644 --- a/web/app/components/base/chat/embedded-chatbot/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/index.tsx @@ -17,7 +17,7 @@ import { checkOrSetAccessToken } from '@/app/components/share/utils' import AppUnavailable from '@/app/components/base/app-unavailable' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Loading from '@/app/components/base/loading' -import LogoHeader from '@/app/components/base/logo/logo-embeded-chat-header' +import LogoHeader from '@/app/components/base/logo/logo-embedded-chat-header' import Header from '@/app/components/base/chat/embedded-chatbot/header' import ConfigPanel from '@/app/components/base/chat/embedded-chatbot/config-panel' import ChatWrapper from '@/app/components/base/chat/embedded-chatbot/chat-wrapper' diff --git a/web/app/components/base/chat/types.ts b/web/app/components/base/chat/types.ts index baffe42843..21277fec57 100644 --- a/web/app/components/base/chat/types.ts +++ b/web/app/components/base/chat/types.ts @@ -22,7 +22,7 @@ export type UserInputForm = { } export type UserInputFormTextInput = { - 'text-inpput': UserInputForm & { + 'text-input': UserInputForm & { max_length: number } } diff --git a/web/app/components/base/features/feature-panel/moderation/moderation-setting-modal.tsx b/web/app/components/base/features/feature-panel/moderation/moderation-setting-modal.tsx index 9acc51b5f3..635506c053 100644 --- a/web/app/components/base/features/feature-panel/moderation/moderation-setting-modal.tsx +++ b/web/app/components/base/features/feature-panel/moderation/moderation-setting-modal.tsx @@ -67,7 +67,7 @@ const ModerationSettingModal: FC = ({ const systemOpenaiProviderQuota = systemOpenaiProviderEnabled ? openaiProvider?.system_configuration.quota_configurations.find(item => item.quota_type === openaiProvider.system_configuration.current_quota_type) : undefined const systemOpenaiProviderCanUse = systemOpenaiProviderQuota?.is_valid const customOpenaiProvidersCanUse = openaiProvider?.custom_configuration.status === CustomConfigurationStatusEnum.active - const openaiProviderConfiged = customOpenaiProvidersCanUse || systemOpenaiProviderCanUse + const isOpenAIProviderConfigured = customOpenaiProvidersCanUse || systemOpenaiProviderCanUse const providers: Provider[] = [ { key: 'openai_moderation', @@ -193,7 +193,7 @@ const ModerationSettingModal: FC = ({ } const handleSave = () => { - if (localeData.type === 'openai_moderation' && !openaiProviderConfiged) + if (localeData.type === 'openai_moderation' && !isOpenAIProviderConfigured) return if (!localeData.config?.inputs_config?.enabled && !localeData.config?.outputs_config?.enabled) { @@ -257,7 +257,7 @@ const ModerationSettingModal: FC = ({ className={` flex items-center px-3 py-2 rounded-lg text-sm text-gray-900 cursor-pointer ${localeData.type === provider.key ? 
'bg-white border-[1.5px] border-primary-400 shadow-sm' : 'border border-gray-100 bg-gray-25'} - ${localeData.type === 'openai_moderation' && provider.key === 'openai_moderation' && !openaiProviderConfiged && 'opacity-50'} + ${localeData.type === 'openai_moderation' && provider.key === 'openai_moderation' && !isOpenAIProviderConfigured && 'opacity-50'} `} onClick={() => handleDataTypeChange(provider.key)} > @@ -270,7 +270,7 @@ const ModerationSettingModal: FC = ({ }
{ - !isLoading && !openaiProviderConfiged && localeData.type === 'openai_moderation' && ( + !isLoading && !isOpenAIProviderConfigured && localeData.type === 'openai_moderation' && (
@@ -364,7 +364,7 @@ const ModerationSettingModal: FC = ({ diff --git a/web/app/components/base/features/feature-panel/opening-statement/index.tsx b/web/app/components/base/features/feature-panel/opening-statement/index.tsx index 54bf8bd937..b039165c9e 100644 --- a/web/app/components/base/features/feature-panel/opening-statement/index.tsx +++ b/web/app/components/base/features/feature-panel/opening-statement/index.tsx @@ -248,7 +248,7 @@ const OpeningStatement: FC = ({ onClick={() => { setTempSuggestedQuestions([...tempSuggestedQuestions, '']) }} className='mt-1 flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100 hover:bg-gray-200'> -
{t('appDebug.variableConig.addOption')}
+
{t('appDebug.variableConfig.addOption')}
)}
@@ -308,7 +308,7 @@ const OpeningStatement: FC = ({ {isShowConfirmAddVar && ( diff --git a/web/app/components/base/file-icon/index.tsx b/web/app/components/base/file-icon/index.tsx index 874637ca7a..21e48b3dd4 100644 --- a/web/app/components/base/file-icon/index.tsx +++ b/web/app/components/base/file-icon/index.tsx @@ -8,7 +8,7 @@ import { Md, Pdf, Txt, - Unknow, + Unknown, Xlsx, } from '@/app/components/base/icons/src/public/files' import { Notion } from '@/app/components/base/icons/src/public/common' @@ -47,7 +47,7 @@ const FileIcon: FC = ({ case 'notion': return default: - return + return } } diff --git a/web/app/components/base/icons/assets/vender/solid/communication/cute-robote.svg b/web/app/components/base/icons/assets/vender/solid/communication/cute-robote.svg index 5eb7476085..8fa74ce264 100644 --- a/web/app/components/base/icons/assets/vender/solid/communication/cute-robote.svg +++ b/web/app/components/base/icons/assets/vender/solid/communication/cute-robote.svg @@ -1,5 +1,5 @@ - + diff --git a/web/app/components/base/icons/src/public/files/Unknow.json b/web/app/components/base/icons/src/public/files/Unknown.json similarity index 99% rename from web/app/components/base/icons/src/public/files/Unknow.json rename to web/app/components/base/icons/src/public/files/Unknown.json index 33067fa96f..c39df990d0 100644 --- a/web/app/components/base/icons/src/public/files/Unknow.json +++ b/web/app/components/base/icons/src/public/files/Unknown.json @@ -195,5 +195,5 @@ } ] }, - "name": "Unknow" + "name": "Unknown" } \ No newline at end of file diff --git a/web/app/components/base/icons/src/public/files/Unknow.tsx b/web/app/components/base/icons/src/public/files/Unknown.tsx similarity index 87% rename from web/app/components/base/icons/src/public/files/Unknow.tsx rename to web/app/components/base/icons/src/public/files/Unknown.tsx index ce84d344bf..de909ed65e 100644 --- a/web/app/components/base/icons/src/public/files/Unknow.tsx +++ b/web/app/components/base/icons/src/public/files/Unknown.tsx @@ -2,7 +2,7 @@ // DON NOT EDIT IT MANUALLY import * as React from 'react' -import data from './Unknow.json' +import data from './Unknown.json' import IconBase from '@/app/components/base/icons/IconBase' import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' @@ -11,6 +11,6 @@ const Icon = React.forwardRef, Omit ) -Icon.displayName = 'Unknow' +Icon.displayName = 'Unknown' export default Icon diff --git a/web/app/components/base/icons/src/public/files/index.ts b/web/app/components/base/icons/src/public/files/index.ts index 2814c4ae39..f38c28cbdb 100644 --- a/web/app/components/base/icons/src/public/files/index.ts +++ b/web/app/components/base/icons/src/public/files/index.ts @@ -6,6 +6,6 @@ export { default as Json } from './Json' export { default as Md } from './Md' export { default as Pdf } from './Pdf' export { default as Txt } from './Txt' -export { default as Unknow } from './Unknow' +export { default as Unknown } from './Unknown' export { default as Xlsx } from './Xlsx' export { default as Yaml } from './Yaml' diff --git a/web/app/components/base/icons/src/vender/solid/communication/CuteRobote.json b/web/app/components/base/icons/src/vender/solid/communication/CuteRobot.json similarity index 96% rename from web/app/components/base/icons/src/vender/solid/communication/CuteRobote.json rename to web/app/components/base/icons/src/vender/solid/communication/CuteRobot.json index f8c92c3174..5b36575f56 100644 --- 
a/web/app/components/base/icons/src/vender/solid/communication/CuteRobote.json +++ b/web/app/components/base/icons/src/vender/solid/communication/CuteRobot.json @@ -15,7 +15,7 @@ "type": "element", "name": "g", "attributes": { - "id": "cute-robote" + "id": "cute-robot" }, "children": [ { @@ -34,5 +34,5 @@ } ] }, - "name": "CuteRobote" + "name": "CuteRobot" } \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/solid/communication/CuteRobote.tsx b/web/app/components/base/icons/src/vender/solid/communication/CuteRobot.tsx similarity index 86% rename from web/app/components/base/icons/src/vender/solid/communication/CuteRobote.tsx rename to web/app/components/base/icons/src/vender/solid/communication/CuteRobot.tsx index d416fb5b66..49994048b7 100644 --- a/web/app/components/base/icons/src/vender/solid/communication/CuteRobote.tsx +++ b/web/app/components/base/icons/src/vender/solid/communication/CuteRobot.tsx @@ -2,7 +2,7 @@ // DON NOT EDIT IT MANUALLY import * as React from 'react' -import data from './CuteRobote.json' +import data from './CuteRobot.json' import IconBase from '@/app/components/base/icons/IconBase' import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' @@ -11,6 +11,6 @@ const Icon = React.forwardRef, Omit ) -Icon.displayName = 'CuteRobote' +Icon.displayName = 'CuteRobot' export default Icon diff --git a/web/app/components/base/icons/src/vender/solid/communication/index.ts b/web/app/components/base/icons/src/vender/solid/communication/index.ts index 854953c116..673de27463 100644 --- a/web/app/components/base/icons/src/vender/solid/communication/index.ts +++ b/web/app/components/base/icons/src/vender/solid/communication/index.ts @@ -1,6 +1,6 @@ export { default as AiText } from './AiText' export { default as ChatBot } from './ChatBot' -export { default as CuteRobote } from './CuteRobote' +export { default as CuteRobot } from './CuteRobot' export { default as EditList } from './EditList' export { default as MessageDotsCircle } from './MessageDotsCircle' export { default as MessageFast } from './MessageFast' diff --git a/web/app/components/base/image-uploader/audio-preview.tsx b/web/app/components/base/image-uploader/audio-preview.tsx new file mode 100644 index 0000000000..24ede8aa43 --- /dev/null +++ b/web/app/components/base/image-uploader/audio-preview.tsx @@ -0,0 +1,38 @@ +import type { FC } from 'react' +import { createPortal } from 'react-dom' +import { RiCloseLine } from '@remixicon/react' + +type AudioPreviewProps = { + url: string + title: string + onCancel: () => void +} +const AudioPreview: FC = ({ + url, + title, + onCancel, +}) => { + return createPortal( +
e.stopPropagation()}> +
+ +
+
+ +
+
+ , + document.body, + ) +} + +export default AudioPreview diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts index bb929f4a40..41074000a2 100644 --- a/web/app/components/base/image-uploader/hooks.ts +++ b/web/app/components/base/image-uploader/hooks.ts @@ -199,7 +199,7 @@ export const useClipboardUploader = ({ visionConfig, onUpload, files }: useClipb const handleClipboardPaste = useCallback((e: ClipboardEvent) => { // reserve native text copy behavior const file = e.clipboardData?.files[0] - // when copyed file, prevent default action + // when copied file, prevent default action if (file) { e.preventDefault() handleLocalFileUpload(file) diff --git a/web/app/components/base/image-uploader/image-preview.tsx b/web/app/components/base/image-uploader/image-preview.tsx index 5b33832a14..41f29fda2e 100644 --- a/web/app/components/base/image-uploader/image-preview.tsx +++ b/web/app/components/base/image-uploader/image-preview.tsx @@ -1,19 +1,43 @@ import type { FC } from 'react' +import { useRef } from 'react' +import { t } from 'i18next' import { createPortal } from 'react-dom' -import { RiCloseLine } from '@remixicon/react' +import { RiCloseLine, RiExternalLinkLine } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' +import { randomString } from '@/utils' type ImagePreviewProps = { url: string + title: string onCancel: () => void } const ImagePreview: FC = ({ url, + title, onCancel, }) => { + const selector = useRef(`copy-tooltip-${randomString(4)}`) + + const openInNewTab = () => { + // Open in a new window, considering the case when the page is inside an iframe + if (url.startsWith('http')) { + window.open(url, '_blank') + } + else if (url.startsWith('data:image')) { + // Base64 image + const win = window.open() + win?.document.write(`${title}`) + } + else { + console.error('Unable to open image', url) + } + } + return createPortal(
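
openInNewTab above branches on the URL scheme: plain window.open suffices for http(s) sources, while a base64 data:image URL is written into a freshly opened document instead (the component's comment notes this also covers the page running inside an iframe). A condensed sketch; the exact markup written in the data-URL branch is an assumption, since only the title interpolation survives in the diff:

// Sketch: open a preview image in a new tab, handling data: URLs.
function openImage(url: string, title: string): void {
  if (url.startsWith('http')) {
    window.open(url, '_blank')
  }
  else if (url.startsWith('data:image')) {
    // Base64 payloads cannot always be opened directly; host them in a tiny page.
    const win = window.open()
    win?.document.write(`<html><head><title>${title}</title></head><body><img src="${url}" alt="${title}"></body></html>`)
  }
  else {
    console.error('Unable to open image', url)
  }
}
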
e.stopPropagation()}> + {/* eslint-disable-next-line @next/next/no-img-element */} preview image @@ -23,6 +47,18 @@ const ImagePreview: FC = ({ >
+ +
+ +
+
, document.body, ) diff --git a/web/app/components/base/image-uploader/video-preview.tsx b/web/app/components/base/image-uploader/video-preview.tsx new file mode 100644 index 0000000000..15291412eb --- /dev/null +++ b/web/app/components/base/image-uploader/video-preview.tsx @@ -0,0 +1,38 @@ +import type { FC } from 'react' +import { createPortal } from 'react-dom' +import { RiCloseLine } from '@remixicon/react' + +type VideoPreviewProps = { + url: string + title: string + onCancel: () => void +} +const VideoPreview: FC = ({ + url, + title, + onCancel, +}) => { + return createPortal( +
e.stopPropagation()}> +
+ +
+
+ +
+
+ , + document.body, + ) +} + +export default VideoPreview diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 5ab8249446..cac5efd879 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -1,43 +1,83 @@ -'use client' -import type { SVGProps } from 'react' -import React, { useState } from 'react' +import type { CSSProperties } from 'react' +import React from 'react' import { useTranslation } from 'react-i18next' -import cn from 'classnames' +import { RiCloseCircleFill, RiErrorWarningLine, RiSearchLine } from '@remixicon/react' +import { type VariantProps, cva } from 'class-variance-authority' +import cn from '@/utils/classnames' -type InputProps = { - placeholder?: string - value?: string - defaultValue?: string - onChange?: (v: string) => void - className?: string - wrapperClassName?: string - type?: string - showPrefix?: React.ReactNode - prefixIcon?: React.ReactNode -} - -const GlassIcon = ({ className }: SVGProps) => ( - - - +export const inputVariants = cva( + '', + { + variants: { + size: { + regular: 'px-3 radius-md system-sm-regular', + large: 'px-4 radius-lg system-md-regular', + }, + }, + defaultVariants: { + size: 'regular', + }, + }, ) -const Input = ({ value, defaultValue, onChange, className = '', wrapperClassName = '', placeholder, type, showPrefix, prefixIcon }: InputProps) => { - const [localValue, setLocalValue] = useState(value ?? defaultValue) +export type InputProps = { + showLeftIcon?: boolean + showClearIcon?: boolean + onClear?: () => void + disabled?: boolean + destructive?: boolean + wrapperClassName?: string + styleCss?: CSSProperties +} & React.InputHTMLAttributes & VariantProps + +const Input = ({ + size, + disabled, + destructive, + showLeftIcon, + showClearIcon, + onClear, + wrapperClassName, + className, + styleCss, + value, + placeholder, + onChange, + ...props +}: InputProps) => { const { t } = useTranslation() return ( -
- {showPrefix && {prefixIcon ?? }} +
+ {showLeftIcon && } { - setLocalValue(e.target.value) - onChange && onChange(e.target.value) - }} + style={styleCss} + className={cn( + 'w-full py-[7px] bg-components-input-bg-normal border border-transparent text-components-input-text-filled hover:bg-components-input-bg-hover hover:border-components-input-border-hover focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs placeholder:text-components-input-text-placeholder appearance-none outline-none caret-primary-600', + inputVariants({ size }), + showLeftIcon && 'pl-[26px]', + showLeftIcon && size === 'large' && 'pl-7', + showClearIcon && value && 'pr-[26px]', + showClearIcon && value && size === 'large' && 'pr-7', + destructive && 'pr-[26px]', + destructive && size === 'large' && 'pr-7', + disabled && 'bg-components-input-bg-disabled border-transparent text-components-input-text-filled-disabled cursor-not-allowed hover:bg-components-input-bg-disabled hover:border-transparent', + destructive && 'bg-components-input-bg-destructive border-components-input-border-destructive text-components-input-text-filled hover:bg-components-input-bg-destructive hover:border-components-input-border-destructive focus:bg-components-input-bg-destructive focus:border-components-input-border-destructive', + className, + )} + placeholder={placeholder ?? (showLeftIcon ? t('common.operation.search') ?? '' : 'please input')} + value={value} + onChange={onChange} + disabled={disabled} + {...props} /> + {showClearIcon && value && !disabled && !destructive && ( +
+ +
+ )} + {destructive && ( + + )}
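
The Input rewrite above leans on class-variance-authority: inputVariants declares a size variant with a default, and VariantProps folds the accepted options into the prop type. A self-contained sketch of the pattern (class strings copied from the diff; using Omit on the native size attribute is a sketch-level choice, the component intersects the types directly):

import type { InputHTMLAttributes } from 'react'
import { type VariantProps, cva } from 'class-variance-authority'

// Declare the variants once; the prop type is derived from them.
const inputVariants = cva('', {
  variants: {
    size: {
      regular: 'px-3 radius-md system-sm-regular',
      large: 'px-4 radius-lg system-md-regular',
    },
  },
  defaultVariants: { size: 'regular' },
})

// Omit the native size?: number attribute so it cannot clash with the variant.
type SketchInputProps = Omit<InputHTMLAttributes<HTMLInputElement>, 'size'>
  & VariantProps<typeof inputVariants>

// inputVariants({ size: 'large' }) -> 'px-4 radius-lg system-md-regular'
// inputVariants({}) falls back to the regular classes via defaultVariants.
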
) } diff --git a/web/app/components/base/logo/logo-embedded-chat-avatar.tsx b/web/app/components/base/logo/logo-embedded-chat-avatar.tsx new file mode 100644 index 0000000000..7fd94827eb --- /dev/null +++ b/web/app/components/base/logo/logo-embedded-chat-avatar.tsx @@ -0,0 +1,18 @@ +import type { FC } from 'react' + +type LogoEmbeddedChatAvatarProps = { + className?: string +} +const LogoEmbeddedChatAvatar: FC = ({ + className, +}) => { + return ( + logo + ) +} + +export default LogoEmbeddedChatAvatar diff --git a/web/app/components/base/logo/logo-embedded-chat-header.tsx b/web/app/components/base/logo/logo-embedded-chat-header.tsx new file mode 100644 index 0000000000..976ce0c77a --- /dev/null +++ b/web/app/components/base/logo/logo-embedded-chat-header.tsx @@ -0,0 +1,19 @@ +import type { FC } from 'react' + +type LogoEmbeddedChatHeaderProps = { + className?: string +} + +const LogoEmbeddedChatHeader: FC = ({ + className, +}) => { + return ( + logo + ) +} + +export default LogoEmbeddedChatHeader diff --git a/web/app/components/base/logo/logo-embeded-chat-avatar.tsx b/web/app/components/base/logo/logo-embeded-chat-avatar.tsx deleted file mode 100644 index c5880e7bf6..0000000000 --- a/web/app/components/base/logo/logo-embeded-chat-avatar.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import type { FC } from 'react' - -type LogoEmbededChatAvatarProps = { - className?: string -} -const LogoEmbededChatAvatar: FC = ({ - className, -}) => { - return ( - logo - ) -} - -export default LogoEmbededChatAvatar diff --git a/web/app/components/base/logo/logo-embeded-chat-header.tsx b/web/app/components/base/logo/logo-embeded-chat-header.tsx deleted file mode 100644 index f979d501c2..0000000000 --- a/web/app/components/base/logo/logo-embeded-chat-header.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import type { FC } from 'react' - -type LogoEmbededChatHeaderProps = { - className?: string -} -const LogoEmbededChatHeader: FC = ({ - className, -}) => { - return ( - logo - ) -} - -export default LogoEmbededChatHeader diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index af4b13ff70..11bcd84e18 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -8,13 +8,16 @@ import RemarkGfm from 'remark-gfm' import SyntaxHighlighter from 'react-syntax-highlighter' import { atelierHeathLight } from 'react-syntax-highlighter/dist/esm/styles/hljs' import type { RefObject } from 'react' -import { memo, useEffect, useMemo, useRef, useState } from 'react' +import { Component, memo, useEffect, useMemo, useRef, useState } from 'react' import type { CodeComponent } from 'react-markdown/lib/ast-to-react' import cn from '@/utils/classnames' import CopyBtn from '@/app/components/base/copy-btn' import SVGBtn from '@/app/components/base/svg' import Flowchart from '@/app/components/base/mermaid' import ImageGallery from '@/app/components/base/image-gallery' +import { useChatContext } from '@/app/components/base/chat/chat/context' +import VideoGallery from '@/app/components/base/video-gallery' +import AudioGallery from '@/app/components/base/audio-gallery' // Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD const capitalizationLanguageNameMap: Record = { @@ -33,6 +36,10 @@ const capitalizationLanguageNameMap: Record = { markdown: 'MarkDown', makefile: 'MakeFile', echarts: 'ECharts', + shell: 'Shell', + powershell: 'PowerShell', + json: 'JSON', + latex: 'Latex', } const 
getCorrectCapitalizationLanguageName = (language: string) => { if (!language) @@ -65,6 +72,7 @@ export function PreCode(props: { children: any }) { ) } +// eslint-disable-next-line unused-imports/no-unused-vars const useLazyLoad = (ref: RefObject): boolean => { const [isIntersecting, setIntersecting] = useState(false) @@ -104,7 +112,7 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') - let chartData = JSON.parse(String('{"title":{"text":"Something went wrong."}}').replace(/\n$/, '')) + let chartData = JSON.parse(String('{"title":{"text":"ECharts error - Wrong JSON format."}}').replace(/\n$/, '')) if (language === 'echarts') { try { chartData = JSON.parse(String(children).replace(/\n$/, '')) @@ -126,12 +134,7 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } >
{languageShowName}
- {language === 'mermaid' - && - } + {language === 'mermaid' && } ) : ( (language === 'echarts') - ? (
-
) +
) : ( { + const srcs = node.children.filter(child => 'properties' in child).map(child => (child as any).properties.src) + if (srcs.length === 0) + return null + return +}) +VideoBlock.displayName = 'VideoBlock' + +const AudioBlock: CodeComponent = memo(({ node }) => { + const srcs = node.children.filter(child => 'properties' in child).map(child => (child as any).properties.src) + if (srcs.length === 0) + return null + return +}) +AudioBlock.displayName = 'AudioBlock' + +const Paragraph = (paragraph: any) => { + const { node }: any = paragraph + const children_node = node.children + if (children_node && children_node[0] && 'tagName' in children_node[0] && children_node[0].tagName === 'img') { + return ( + <> + +
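
VideoBlock and AudioBlock above walk the hast children of the rendered node, keep only element children (those carrying a properties bag), and collect their src attributes before handing them to VideoGallery or AudioGallery; both render nothing for an empty list. The extraction step on its own, with the same loose typing as the component:

type HastChild = { properties?: { src?: string } }

// Sketch: pull src attributes off element children, as the markdown
// renderer does for video/audio nodes before rendering a gallery.
function collectSrcs(children: HastChild[]): string[] {
  return children
    .filter(child => 'properties' in child)
    .map(child => (child as any).properties.src)
}
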
{paragraph.children.slice(1)}
+ + ) + } + return
{paragraph.children}
+} + +const Img = ({ src }: any) => { + return () +} + +const Link = ({ node, ...props }: any) => { + if (node.properties?.href && node.properties.href?.toString().startsWith('abbr')) { + // eslint-disable-next-line react-hooks/rules-of-hooks + const { onSend } = useChatContext() + const hidden_text = decodeURIComponent(node.properties.href.toString().split('abbr:')[1]) + + return onSend?.(hidden_text)} title={node.children[0]?.value}>{node.children[0]?.value} + } + else { + return {node.children[0] ? node.children[0]?.value : 'Download'} + } +} + export function Markdown(props: { content: string; className?: string }) { const latexContent = preprocessLaTeX(props.content) return (
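
The custom Link renderer above overloads anchors: an href of the form abbr:<url-encoded text> becomes a click-to-send shortcut that pipes the decoded text into the chat's onSend, and anything else renders as an ordinary external link. The dispatch, reduced to a pure function for illustration (the component itself returns JSX):

// Sketch: classify an href the way the markdown Link component does.
function classifyHref(href: string): { kind: 'send'; text: string } | { kind: 'link' } {
  if (href.startsWith('abbr'))
    return { kind: 'send', text: decodeURIComponent(href.split('abbr:')[1]) }
  return { kind: 'link' } // rendered as <a target="_blank" rel="noopener noreferrer">
}
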
{ + return (tree) => { + const iterate = (node: any) => { + if (node.type === 'element' && !node.properties?.src && node.properties?.ref && node.properties.ref.startsWith('{') && node.properties.ref.endsWith('}')) + delete node.properties.ref + + if (node.children) + node.children.forEach(iterate) + } + tree.children.forEach(iterate) + } + }, ]} components={{ code: CodeBlock, - img({ src }) { - return ( - - ) - }, - p: (paragraph) => { - const { node }: any = paragraph - if (node.children[0].tagName === 'img') { - const image = node.children[0] - - return ( - <> - -

{paragraph.children.slice(1)}

- - ) - } - return

{paragraph.children}

- }, + img: Img, + video: VideoBlock, + audio: AudioBlock, + a: Link, + p: Paragraph, }} linkTarget='_blank' > @@ -211,3 +260,25 @@ export function Markdown(props: { content: string; className?: string }) {
) } + +// Add an ECharts runtime error handler +// Avoid error #7832 (Crash when ECharts accesses undefined objects) +// This can happen when a component attempts to access an undefined object that references an unregistered map, causing the program to crash. + +export default class ErrorBoundary extends Component { + constructor(props) { + super(props) + this.state = { hasError: false } + } + + componentDidCatch(error, errorInfo) { + this.setState({ hasError: true }) + console.error(error, errorInfo) + } + + render() { + if (this.state.hasError) + return
Oops! ECharts reported a runtime error.
(see the browser console for more information)
+ return this.props.children + } +} diff --git a/web/app/components/base/notion-page-selector/base.tsx b/web/app/components/base/notion-page-selector/base.tsx index 63aff09a93..e3b321b120 100644 --- a/web/app/components/base/notion-page-selector/base.tsx +++ b/web/app/components/base/notion-page-selector/base.tsx @@ -72,7 +72,7 @@ const NotionPageSelector = ({ const handleSelectWorkspace = useCallback((workspaceId: string) => { setCurrentWorkspaceId(workspaceId) }, []) - const handleSelecPages = (newSelectedPagesId: Set) => { + const handleSelectPages = (newSelectedPagesId: Set) => { const selectedPages = Array.from(newSelectedPagesId).map(pageId => getPagesMapAndSelectedPagesId[0][pageId]) setSelectedPagesId(new Set(Array.from(newSelectedPagesId))) @@ -117,7 +117,7 @@ const NotionPageSelector = ({ searchValue={searchValue} list={currentWorkspace?.pages || []} pagesMap={getPagesMapAndSelectedPagesId[0]} - onSelect={handleSelecPages} + onSelect={handleSelectPages} canPreview={canPreview} previewPageId={previewPageId} onPreview={handlePreviewPage} diff --git a/web/app/components/base/notion-page-selector/page-selector/index.tsx b/web/app/components/base/notion-page-selector/page-selector/index.tsx index b61fa34567..8f398790e7 100644 --- a/web/app/components/base/notion-page-selector/page-selector/index.tsx +++ b/web/app/components/base/notion-page-selector/page-selector/index.tsx @@ -22,13 +22,13 @@ type PageSelectorProps = { type NotionPageTreeItem = { children: Set descendants: Set - deepth: number + depth: number ancestors: string[] } & DataSourceNotionPage type NotionPageTreeMap = Record type NotionPageItem = { expand: boolean - deepth: number + depth: number } & DataSourceNotionPage const recursivePushInParentDescendants = ( @@ -51,7 +51,7 @@ const recursivePushInParentDescendants = ( ...pagesMap[parentId], children, descendants, - deepth: 0, + depth: 0, ancestors: [], } } @@ -60,7 +60,7 @@ const recursivePushInParentDescendants = ( listTreeMap[parentId].descendants.add(pageId) listTreeMap[parentId].descendants.add(leafItem.page_id) } - leafItem.deepth++ + leafItem.depth++ leafItem.ancestors.unshift(listTreeMap[parentId].page_name) if (listTreeMap[parentId].parent_id !== 'root') @@ -95,7 +95,7 @@ const ItemComponent = ({ index, style, data }: ListChildComponentProps<{ return (
handleToggle(index)} /> ) @@ -106,7 +106,7 @@ const ItemComponent = ({ index, style, data }: ListChildComponentProps<{ ) } return ( -
+
) } @@ -185,7 +185,7 @@ const PageSelector = ({ return { ...item, expand: false, - deepth: 0, + depth: 0, } })) } @@ -195,7 +195,7 @@ const PageSelector = ({ return { ...item, expand: false, - deepth: 0, + depth: 0, } }) const currentDataList = searchValue ? searchDataList : dataList @@ -205,7 +205,7 @@ const PageSelector = ({ return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => { const pageId = next.page_id if (!prev[pageId]) - prev[pageId] = { ...next, children: new Set(), descendants: new Set(), deepth: 0, ancestors: [] } + prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] } recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId]) return prev @@ -233,7 +233,7 @@ const PageSelector = ({ ...childrenIds.map(item => ({ ...pagesMap[item], expand: false, - deepth: listMapWithChildrenAndDescendants[item].deepth, + depth: listMapWithChildrenAndDescendants[item].depth, })), ...dataList.slice(index + 1)] } diff --git a/web/app/components/base/pagination/index.tsx b/web/app/components/base/pagination/index.tsx index 98b1e266ae..f8c5684b55 100644 --- a/web/app/components/base/pagination/index.tsx +++ b/web/app/components/base/pagination/index.tsx @@ -23,8 +23,8 @@ const CustomizedPagination: FC = ({ current, onChange, total, limit = 10 middlePagesSiblingCount={1} setCurrentPage={onChange} totalPages={totalPages} - truncableClassName="w-8 px-0.5 text-center" - truncableText="..." + truncatableClassName="w-8 px-0.5 text-center" + truncatableText="..." > ) => [RefObject, boolean] -export const useSelectOrDelete: UseSelectOrDeleteHanlder = (nodeKey: string, command?: LexicalCommand) => { +export type UseSelectOrDeleteHandler = (nodeKey: string, command?: LexicalCommand) => [RefObject, boolean] +export const useSelectOrDelete: UseSelectOrDeleteHandler = (nodeKey: string, command?: LexicalCommand) => { const ref = useRef(null) const [editor] = useLexicalComposerContext() const [isSelected, setSelected, clearSelection] = useLexicalNodeSelection(nodeKey) diff --git a/web/app/components/base/search-input/index.tsx b/web/app/components/base/search-input/index.tsx index a85bc2db8a..4b3821da5a 100644 --- a/web/app/components/base/search-input/index.tsx +++ b/web/app/components/base/search-input/index.tsx @@ -25,9 +25,9 @@ const SearchInput: FC = ({ return (
diff --git a/web/app/components/base/select/locale.tsx b/web/app/components/base/select/locale.tsx index 3066364316..8b89c66950 100644 --- a/web/app/components/base/select/locale.tsx +++ b/web/app/components/base/select/locale.tsx @@ -77,7 +77,7 @@ export function InputSelect({
- + {item?.name}
diff --git a/web/app/components/base/slider/index.tsx b/web/app/components/base/slider/index.tsx index 18ef3a7a29..2b8f181633 100644 --- a/web/app/components/base/slider/index.tsx +++ b/web/app/components/base/slider/index.tsx @@ -32,7 +32,7 @@ const Slider: React.FC = ({ max={max || 100} step={step || 1} className={cn('relative slider', className)} - thumbClassName={cn('absolute top-[-9px] w-2 h-5 border-[0.5px] border-components-silder-knob-border rounded-[3px] bg-components-silder-knob shadow-sm focus:outline-none', !disabled && 'cursor-pointer', thumbClassName)} + thumbClassName={cn('absolute top-[-9px] w-2 h-5 border-[0.5px] border-components-slider-knob-border rounded-[3px] bg-components-slider-knob shadow-sm focus:outline-none', !disabled && 'cursor-pointer', thumbClassName)} trackClassName={cn('h-0.5 rounded-full slider-track', trackClassName)} onChange={onChange} /> diff --git a/web/app/components/base/slider/style.css b/web/app/components/base/slider/style.css index 6b4394ed93..e215a9914e 100644 --- a/web/app/components/base/slider/style.css +++ b/web/app/components/base/slider/style.css @@ -3,9 +3,9 @@ } .slider-track { - background-color: var(--color-components-silder-range); + background-color: var(--color-components-slider-range); } .slider-track-1 { - background-color: var(--color-components-silder-track); + background-color: var(--color-components-slider-track); } \ No newline at end of file diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx index 7eab355a0d..404fd89f38 100644 --- a/web/app/components/base/tag-input/index.tsx +++ b/web/app/components/base/tag-input/index.tsx @@ -46,16 +46,16 @@ const TagInput: FC = ({ if (isSpecialMode) e.preventDefault() - const valueTrimed = value.trim() - if (!valueTrimed || (items.find(item => item === valueTrimed))) + const valueTrimmed = value.trim() + if (!valueTrimmed || (items.find(item => item === valueTrimmed))) return - if (valueTrimed.length > 20) { + if (valueTrimmed.length > 20) { notify({ type: 'error', message: t('datasetDocuments.segment.keywordError') }) return } - onChange([...items, valueTrimed]) + onChange([...items, valueTrimmed]) setTimeout(() => { setValue('') }) diff --git a/web/app/components/base/text-generation/types.ts b/web/app/components/base/text-generation/types.ts index 82a4177592..1e2c04f6e8 100644 --- a/web/app/components/base/text-generation/types.ts +++ b/web/app/components/base/text-generation/types.ts @@ -15,7 +15,7 @@ export type UserInputForm = { } export type UserInputFormTextInput = { - 'text-inpput': UserInputForm & { + 'text-input': UserInputForm & { max_length: number } } diff --git a/web/app/components/base/video-gallery/VideoPlayer.module.css b/web/app/components/base/video-gallery/VideoPlayer.module.css new file mode 100644 index 0000000000..04c4a367d6 --- /dev/null +++ b/web/app/components/base/video-gallery/VideoPlayer.module.css @@ -0,0 +1,188 @@ +.videoPlayer { + position: relative; + width: 100%; + max-width: 800px; + margin: 0 auto; + border-radius: 8px; + overflow: hidden; +} + +.video { + width: 100%; + display: block; +} + +.controls { + position: absolute; + bottom: 0; + left: 0; + right: 0; + width: 100%; + height: 100%; + display: flex; + flex-direction: column; + justify-content: flex-end; + transition: opacity 0.3s ease; +} + +.controls.hidden { + opacity: 0; +} + +.controls.visible { + opacity: 1; +} + +.overlay { + background: linear-gradient(to top, rgba(0, 0, 0, 0.7) 0%, transparent 100%); + padding: 20px; + display: 
flex; + flex-direction: column; +} + +.progressBarContainer { + width: 100%; + margin-bottom: 10px; +} + +.controlsContent { + display: flex; + justify-content: space-between; + align-items: center; +} + +.leftControls, .rightControls { + display: flex; + align-items: center; +} + +.playPauseButton, .muteButton, .fullscreenButton { + background: none; + border: none; + color: white; + cursor: pointer; + padding: 4px; + margin-right: 10px; + display: flex; + align-items: center; + justify-content: center; +} + +.playPauseButton:hover, .muteButton:hover, .fullscreenButton:hover { + background-color: rgba(255, 255, 255, 0.1); + border-radius: 50%; +} + +.time { + color: white; + font-size: 14px; + margin-left: 8px; +} + +.volumeControl { + display: flex; + align-items: center; + margin-right: 16px; +} + +.volumeSlider { + width: 60px; + height: 4px; + background: rgba(255, 255, 255, 0.3); + border-radius: 2px; + cursor: pointer; + margin-left: 12px; + position: relative; +} + +.volumeLevel { + position: absolute; + top: 0; + left: 0; + height: 100%; + background: #ffffff; + border-radius: 2px; +} + +.progressBar { + position: relative; + width: 100%; + height: 4px; + background: rgba(255, 255, 255, 0.3); + cursor: pointer; + border-radius: 2px; + overflow: visible; + transition: height 0.2s ease; +} + +.progressBar:hover { + height: 6px; +} + +.progress { + height: 100%; + background: #ffffff; + transition: width 0.1s ease-in-out; +} + +.hoverTimeIndicator { + position: absolute; + bottom: 100%; + transform: translateX(-50%); + background-color: rgba(0, 0, 0, 0.7); + color: white; + padding: 4px 8px; + border-radius: 4px; + font-size: 12px; + pointer-events: none; + white-space: nowrap; + margin-bottom: 8px; +} + +.hoverTimeIndicator::after { + content: ''; + position: absolute; + top: 100%; + left: 50%; + margin-left: -4px; + border-width: 4px; + border-style: solid; + border-color: rgba(0, 0, 0, 0.7) transparent transparent transparent; +} + +.controls.smallSize .controlsContent { + justify-content: space-between; +} + +.controls.smallSize .leftControls, +.controls.smallSize .rightControls { + flex: 0 0 auto; + display: flex; + align-items: center; +} + +.controls.smallSize .rightControls { + justify-content: flex-end; +} + +.controls.smallSize .progressBarContainer { + margin-bottom: 4px; +} + +.controls.smallSize .playPauseButton, +.controls.smallSize .muteButton, +.controls.smallSize .fullscreenButton { + padding: 2px; + margin-right: 4px; +} + +.controls.smallSize .playPauseButton svg, +.controls.smallSize .muteButton svg, +.controls.smallSize .fullscreenButton svg { + width: 16px; + height: 16px; +} + +.controls.smallSize .muteButton { + order: -1; +} diff --git a/web/app/components/base/video-gallery/VideoPlayer.tsx b/web/app/components/base/video-gallery/VideoPlayer.tsx new file mode 100644 index 0000000000..d7c86a1af9 --- /dev/null +++ b/web/app/components/base/video-gallery/VideoPlayer.tsx @@ -0,0 +1,278 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react' +import styles from './VideoPlayer.module.css' + +type VideoPlayerProps = { + src: string +} + +const PlayIcon = () => ( + + + +) + +const PauseIcon = () => ( + + + +) + +const MuteIcon = () => ( + + + +) + +const UnmuteIcon = () => ( + + + +) + +const FullscreenIcon = () => ( + + + +) + +const VideoPlayer: React.FC = ({ src }) => { + const [isPlaying, setIsPlaying] = useState(false) + const [currentTime, setCurrentTime] = useState(0) + const [duration, setDuration] = useState(0) + const [isMuted, 
setIsMuted] = useState(false) + const [volume, setVolume] = useState(1) + const [isDragging, setIsDragging] = useState(false) + const [isControlsVisible, setIsControlsVisible] = useState(true) + const [hoverTime, setHoverTime] = useState(null) + const videoRef = useRef(null) + const progressRef = useRef(null) + const volumeRef = useRef(null) + const controlsTimeoutRef = useRef(null) + const [isSmallSize, setIsSmallSize] = useState(false) + const containerRef = useRef(null) + + useEffect(() => { + const video = videoRef.current + if (!video) + return + + const setVideoData = () => { + setDuration(video.duration) + setVolume(video.volume) + } + + const setVideoTime = () => { + setCurrentTime(video.currentTime) + } + + const handleEnded = () => { + setIsPlaying(false) + } + + video.addEventListener('loadedmetadata', setVideoData) + video.addEventListener('timeupdate', setVideoTime) + video.addEventListener('ended', handleEnded) + + return () => { + video.removeEventListener('loadedmetadata', setVideoData) + video.removeEventListener('timeupdate', setVideoTime) + video.removeEventListener('ended', handleEnded) + } + }, [src]) + + useEffect(() => { + return () => { + if (controlsTimeoutRef.current) + clearTimeout(controlsTimeoutRef.current) + } + }, []) + + const showControls = useCallback(() => { + setIsControlsVisible(true) + if (controlsTimeoutRef.current) + clearTimeout(controlsTimeoutRef.current) + + controlsTimeoutRef.current = setTimeout(() => setIsControlsVisible(false), 3000) + }, []) + + const togglePlayPause = useCallback(() => { + const video = videoRef.current + if (video) { + if (isPlaying) + video.pause() + else video.play().catch(error => console.error('Error playing video:', error)) + setIsPlaying(!isPlaying) + } + }, [isPlaying]) + + const toggleMute = useCallback(() => { + const video = videoRef.current + if (video) { + const newMutedState = !video.muted + video.muted = newMutedState + setIsMuted(newMutedState) + setVolume(newMutedState ? 0 : (video.volume > 0 ? video.volume : 1)) + video.volume = newMutedState ? 0 : (video.volume > 0 ? 
video.volume : 1) + } + }, []) + + const toggleFullscreen = useCallback(() => { + const video = videoRef.current + if (video) { + if (document.fullscreenElement) + document.exitFullscreen() + else video.requestFullscreen() + } + }, []) + + const formatTime = (time: number) => { + const minutes = Math.floor(time / 60) + const seconds = Math.floor(time % 60) + return `${minutes.toString().padStart(2, '0')}:${seconds.toString().padStart(2, '0')}` + } + + const updateVideoProgress = useCallback((clientX: number) => { + const progressBar = progressRef.current + const video = videoRef.current + if (progressBar && video) { + const rect = progressBar.getBoundingClientRect() + const pos = (clientX - rect.left) / rect.width + const newTime = pos * video.duration + if (newTime >= 0 && newTime <= video.duration) { + setHoverTime(newTime) + if (isDragging) + video.currentTime = newTime + } + } + }, [isDragging]) + + const handleMouseMove = useCallback((e: React.MouseEvent) => { + updateVideoProgress(e.clientX) + }, [updateVideoProgress]) + + const handleMouseLeave = useCallback(() => { + if (!isDragging) + setHoverTime(null) + }, [isDragging]) + + const handleMouseDown = useCallback((e: React.MouseEvent) => { + e.preventDefault() + setIsDragging(true) + updateVideoProgress(e.clientX) + }, [updateVideoProgress]) + + useEffect(() => { + const handleGlobalMouseMove = (e: MouseEvent) => { + if (isDragging) + updateVideoProgress(e.clientX) + } + + const handleGlobalMouseUp = () => { + setIsDragging(false) + setHoverTime(null) + } + + if (isDragging) { + document.addEventListener('mousemove', handleGlobalMouseMove) + document.addEventListener('mouseup', handleGlobalMouseUp) + } + + return () => { + document.removeEventListener('mousemove', handleGlobalMouseMove) + document.removeEventListener('mouseup', handleGlobalMouseUp) + } + }, [isDragging, updateVideoProgress]) + + const checkSize = useCallback(() => { + if (containerRef.current) + setIsSmallSize(containerRef.current.offsetWidth < 400) + }, []) + + useEffect(() => { + checkSize() + window.addEventListener('resize', checkSize) + return () => window.removeEventListener('resize', checkSize) + }, [checkSize]) + + const handleVolumeChange = useCallback((e: React.MouseEvent) => { + const volumeBar = volumeRef.current + const video = videoRef.current + if (volumeBar && video) { + const rect = volumeBar.getBoundingClientRect() + const newVolume = (e.clientX - rect.left) / rect.width + const clampedVolume = Math.max(0, Math.min(1, newVolume)) + video.volume = clampedVolume + setVolume(clampedVolume) + setIsMuted(clampedVolume === 0) + } + }, []) + + return ( +
+
) diff --git a/web/app/components/browser-initor.tsx b/web/app/components/browser-initor.tsx index 711ff62a94..939ddd567d 100644 --- a/web/app/components/browser-initor.tsx +++ b/web/app/components/browser-initor.tsx @@ -43,10 +43,10 @@ Object.defineProperty(globalThis, 'sessionStorage', { value: sessionStorage, }) -const BrowerInitor = ({ +const BrowserInitor = ({ children, }: { children: React.ReactElement }) => { return children } -export default BrowerInitor +export default BrowserInitor diff --git a/web/app/components/datasets/common/check-rerank-model.ts b/web/app/components/datasets/common/check-rerank-model.ts index 42810e4bf0..581c2bb69a 100644 --- a/web/app/components/datasets/common/check-rerank-model.ts +++ b/web/app/components/datasets/common/check-rerank-model.ts @@ -7,13 +7,13 @@ import { RerankingModeEnum } from '@/models/datasets' export const isReRankModelSelected = ({ rerankDefaultModel, - isRerankDefaultModelVaild, + isRerankDefaultModelValid, retrievalConfig, rerankModelList, indexMethod, }: { rerankDefaultModel?: DefaultModelResponse - isRerankDefaultModelVaild: boolean + isRerankDefaultModelValid: boolean retrievalConfig: RetrievalConfig rerankModelList: Model[] indexMethod?: string @@ -25,7 +25,7 @@ export const isReRankModelSelected = ({ return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) } - if (isRerankDefaultModelVaild) + if (isRerankDefaultModelValid) return !!rerankDefaultModel return false diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 1e407b62e1..20d93568ad 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -11,6 +11,11 @@ import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files import { useProviderContext } from '@/context/provider-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { + DEFAULT_WEIGHTED_SCORE, + RerankingModeEnum, + WeightedScoreEnum, +} from '@/models/datasets' type Props = { value: RetrievalConfig @@ -32,6 +37,18 @@ const RetrievalMethodConfig: FC = ({ reranking_provider_name: rerankDefaultModel?.provider.provider || '', reranking_model_name: rerankDefaultModel?.model || '', }, + reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore), + weights: passValue.weights || { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, + embedding_provider_name: '', + embedding_model_name: '', + }, + keyword_setting: { + keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, + }, + }, } } return passValue diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 54a0963f59..323e47f3b4 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -191,7 +191,7 @@ const RetrievalParamConfig: FC = ({
{option.label}
{option.tips}
} - triggerClassName='ml-0.5 w-3.5 h-4.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' />
)) diff --git a/web/app/components/datasets/create/assets/unknow.svg b/web/app/components/datasets/create/assets/unknown.svg similarity index 100% rename from web/app/components/datasets/create/assets/unknow.svg rename to web/app/components/datasets/create/assets/unknown.svg diff --git a/web/app/components/datasets/create/embedding-process/index.module.css b/web/app/components/datasets/create/embedding-process/index.module.css index a15b1310b4..1ebb006b54 100644 --- a/web/app/components/datasets/create/embedding-process/index.module.css +++ b/web/app/components/datasets/create/embedding-process/index.module.css @@ -83,7 +83,7 @@ .fileIcon { @apply w-4 h-4 mr-1 bg-center bg-no-repeat; - background-image: url(../assets/unknow.svg); + background-image: url(../assets/unknown.svg); background-size: 16px; } .fileIcon.csv { diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 574dd083c7..7786582085 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -13,8 +13,7 @@ import cn from '@/utils/classnames' import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' import Button from '@/app/components/base/button' import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' -import { formatNumber } from '@/utils/format' -import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchIndexingEstimateBatch, fetchProcessRule } from '@/service/datasets' +import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' import { DataSourceType } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' @@ -142,14 +141,6 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index }, apiParams => fetchProcessRule(omit(apiParams, 'action')), { revalidateOnFocus: false, }) - // get cost - const { data: indexingEstimateDetail } = useSWR({ - action: 'fetchIndexingEstimateBatch', - datasetId, - batchId, - }, apiParams => fetchIndexingEstimateBatch(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) const router = useRouter() const navToDocumentList = () => { @@ -190,28 +181,11 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index return ( <> -
+
{isEmbedding && t('datasetDocuments.embedding.processing')} {isEmbeddingCompleted && t('datasetDocuments.embedding.completed')}
-
- {indexingType === 'high_quality' && ( -
-
- {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} - {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens - (${formatNumber(indexingEstimateDetail?.total_price || 0)}) -
- )} - {indexingType === 'economy' && ( -
-
- {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} - 0tokens -
- )} -
{ enableBilling && plan.type !== Plan.team && ( diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index e9247c49df..23c3f0314a 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -32,7 +32,7 @@ const EmptyDatasetCreationModal = ({ return } if (inputValue.length > 40) { - notify({ type: 'error', message: t('datasetCreation.stepOne.modal.nameLengthInvaild') }) + notify({ type: 'error', message: t('datasetCreation.stepOne.modal.nameLengthInvalid') }) return } try { @@ -58,7 +58,7 @@ const EmptyDatasetCreationModal = ({
{t('datasetCreation.stepOne.modal.tip')}
{t('datasetCreation.stepOne.modal.input')}
- + setInputValue(e.target.value)} />
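A note on the VideoPlayer component added earlier in this diff: its `updateVideoProgress` derives the seek time from the pointer's position over the progress bar. A minimal standalone sketch of that calculation (the component itself reads the bar and the video from refs):

```ts
// Fraction of the bar's width under the pointer, scaled to the duration.
// Returns null when the pointer maps outside [0, duration], mirroring the
// component's range guard.
const seekTimeFromPointer = (clientX: number, bar: HTMLElement, duration: number): number | null => {
  const rect = bar.getBoundingClientRect()
  const pos = (clientX - rect.left) / rect.width
  const newTime = pos * duration
  return (newTime >= 0 && newTime <= duration) ? newTime : null
}
```

The same math serves both the hover preview (`hoverTimeIndicator`) and drag seeking; only the `isDragging` flag decides whether `video.currentTime` is actually written.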
diff --git a/web/app/components/datasets/create/file-uploader/index.module.css b/web/app/components/datasets/create/file-uploader/index.module.css index d141815c5a..bf5b7dcaf5 100644 --- a/web/app/components/datasets/create/file-uploader/index.module.css +++ b/web/app/components/datasets/create/file-uploader/index.module.css @@ -104,7 +104,7 @@ .fileIcon { @apply shrink-0 w-6 h-6 mr-2 bg-center bg-no-repeat; - background-image: url(../assets/unknow.svg); + background-image: url(../assets/unknown.svg); background-size: 24px; } diff --git a/web/app/components/datasets/create/step-two/index.module.css b/web/app/components/datasets/create/step-two/index.module.css index 24a62c8e3c..f89d6d67ea 100644 --- a/web/app/components/datasets/create/step-two/index.module.css +++ b/web/app/components/datasets/create/step-two/index.module.css @@ -30,7 +30,7 @@ } .indexItem { - min-height: 146px; + min-height: 126px; } .indexItem .disableMask { @@ -121,10 +121,6 @@ @apply pb-1; } -.radioItem.indexItem .typeHeader .tip { - @apply pb-3; -} - .radioItem .typeIcon { position: absolute; top: 18px; @@ -264,7 +260,7 @@ } .input { - @apply inline-flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal; + @apply inline-flex h-9 w-full py-1 px-2 pr-14 rounded-lg text-xs leading-normal; @apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400; } diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 10b378d8c5..15332b944d 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -14,7 +14,7 @@ import PreviewItem, { PreviewType } from './preview-item' import LanguageSelect from './language-select' import s from './index.module.css' import cn from '@/utils/classnames' -import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, IndexingEstimateResponse, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' +import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' import { createDocument, createFirstDocument, @@ -41,8 +41,10 @@ import { IS_CE_EDITION } from '@/config' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Tooltip from '@/app/components/base/tooltip' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { Globe01 } from '@/app/components/base/icons/src/vender/line/mapsAndTravel' @@ -109,7 +111,7 @@ const StepTwo = ({ const 
[previewScrolled, setPreviewScrolled] = useState(false) const [segmentationType, setSegmentationType] = useState(SegmentType.AUTO) const [segmentIdentifier, setSegmentIdentifier] = useState('\\n') - const [max, setMax] = useState(500) + const [max, setMax] = useState(5000) // default chunk length const [overlap, setOverlap] = useState(50) const [rules, setRules] = useState([]) const [defaultConfig, setDefaultConfig] = useState() @@ -131,7 +133,6 @@ const StepTwo = ({ const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState(null) const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState(null) - const [estimateTokes, setEstimateTokes] = useState | null>(null) const fileIndexingEstimate = (() => { return segmentationType === SegmentType.AUTO ? automaticFileIndexingEstimate : customFileIndexingEstimate @@ -192,13 +193,10 @@ const StepTwo = ({ const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => { // eslint-disable-next-line @typescript-eslint/no-use-before-define const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm)!) - if (segmentationType === SegmentType.CUSTOM) { + if (segmentationType === SegmentType.CUSTOM) setCustomFileIndexingEstimate(res) - } - else { + else setAutomaticFileIndexingEstimate(res) - indexType === IndexingType.QUALIFIED && setEstimateTokes({ tokens: res.tokens, total_price: res.total_price }) - } } const confirmChangeCustomConfig = () => { @@ -308,8 +306,21 @@ const StepTwo = ({ const { modelList: rerankModelList, defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelVaild, + currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: defaultEmbeddingModel } = useDefaultModel(ModelTypeEnum.textEmbedding) + const [embeddingModel, setEmbeddingModel] = useState( + currentDataset?.embedding_model + ? { + provider: currentDataset.embedding_model_provider, + model: currentDataset.embedding_model, + } + : { + provider: defaultEmbeddingModel?.provider.provider || '', + model: defaultEmbeddingModel?.model || '', + }, + ) const getCreationParams = () => { let params if (segmentationType === SegmentType.CUSTOM && overlap > max) { @@ -324,6 +335,8 @@ const StepTwo = ({ process_rule: getProcessRule(), // eslint-disable-next-line @typescript-eslint/no-use-before-define retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page. + embedding_model: embeddingModel.model, // Readonly + embedding_model_provider: embeddingModel.provider, // Readonly } as CreateDocumentReq } else { // create @@ -331,7 +344,7 @@ const StepTwo = ({ if ( !isReRankModelSelected({ rerankDefaultModel, - isRerankDefaultModelVaild: !!isRerankDefaultModelVaild, + isRerankDefaultModelValid: !!isRerankDefaultModelValid, rerankModelList, // eslint-disable-next-line @typescript-eslint/no-use-before-define retrievalConfig, @@ -360,6 +373,8 @@ const StepTwo = ({ doc_language: docLanguage, retrieval_model: postRetrievalConfig, + embedding_model: embeddingModel.model, + embedding_model_provider: embeddingModel.provider, } as CreateDocumentReq if (dataSourceType === DataSourceType.FILE) { params.data_source.info_list.file_info_list = { @@ -613,14 +628,17 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.maxLength')}
- setMax(parseInt(e.target.value.replace(/^0+/, ''), 10))} - /> +
+ setMax(parseInt(e.target.value.replace(/^0+/, ''), 10))} + /> +
Tokens
+
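The handlers above parse the raw input after stripping leading zeros. The same logic as a standalone helper (the NaN fallback is an addition here; the component passes the parsed value straight to `setMax`, so an empty field yields NaN):

```ts
// parseInt('', 10) is NaN, hence the explicit fallback (an assumption,
// not present in the component itself).
const toTokenCount = (raw: string): number => {
  const parsed = Number.parseInt(raw.replace(/^0+/, ''), 10)
  return Number.isNaN(parsed) ? 0 : parsed
}
```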
@@ -635,14 +653,17 @@ const StepTwo = ({ } />
- setOverlap(parseInt(e.target.value.replace(/^0+/, ''), 10))} - /> +
+ setOverlap(parseInt(e.target.value.replace(/^0+/, ''), 10))} + /> +
Tokens
+
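`getCreationParams` in this file rejects custom rules where the overlap exceeds the maximum chunk length. A sketch of that check as a pure function, with hypothetical i18n keys for the error messages:

```ts
type CustomRule = { maxLength: number, overlap: number }

// Returns an i18n key on failure, null when the rule is acceptable.
// Both keys are illustrative, not taken from the locale files.
const validateCustomRule = ({ maxLength, overlap }: CustomRule): string | null => {
  if (!Number.isFinite(maxLength) || maxLength <= 0)
    return 'datasetCreation.stepTwo.maxLengthInvalid'
  if (overlap > maxLength)
    return 'datasetCreation.stepTwo.overlapInvalid'
  return null
}
```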
@@ -675,7 +696,7 @@ const StepTwo = ({ !isAPIKeySet && s.disabled, !hasSetIndexType && indexType === IndexingType.QUALIFIED && s.active, hasSetIndexType && s.disabled, - hasSetIndexType && '!w-full', + hasSetIndexType && '!w-full !min-h-[96px]', )} onClick={() => { if (isAPIKeySet) @@ -690,16 +711,6 @@ const StepTwo = ({ {!hasSetIndexType && {t('datasetCreation.stepTwo.recommend')}}
{t('datasetCreation.stepTwo.qualifiedTip')}
-
{t('datasetCreation.stepTwo.emstimateCost')}
- { - estimateTokes - ? ( -
{formatNumber(estimateTokes.tokens)} tokens(${formatNumber(estimateTokes.total_price)})
- ) - : ( -
{t('datasetCreation.stepTwo.calculating')}
- ) - }
{!isAPIKeySet && (
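Both index-type cards now append `!w-full !min-h-[96px]` once an index type is locked in. The `cn()` pattern in use, condensed (base class names illustrative):

```ts
import cn from '@/utils/classnames'

// Falsy operands drop out of the class list, so each UI state contributes
// its classes declaratively; '!' prefixes mark Tailwind important overrides.
const indexCardClassName = (active: boolean, locked: boolean) =>
  cn(
    'radioItem indexItem',
    active && 'active',
    locked && 'disabled',
    locked && '!w-full !min-h-[96px]',
  )
```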
@@ -717,7 +728,7 @@ const StepTwo = ({ s.indexItem, !hasSetIndexType && indexType === IndexingType.ECONOMICAL && s.active, hasSetIndexType && s.disabled, - hasSetIndexType && '!w-full', + hasSetIndexType && '!w-full !min-h-[96px]', )} onClick={changeToEconomicalType} > @@ -726,15 +737,13 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.economical')}
{t('datasetCreation.stepTwo.economicalTip')}
-
{t('datasetCreation.stepTwo.emstimateCost')}
-
0 tokens
)}
- {hasSetIndexType && ( + {hasSetIndexType && indexType === IndexingType.ECONOMICAL && (
- {t('datasetCreation.stepTwo.indexSettedTip')} + {t('datasetCreation.stepTwo.indexSettingTip')} {t('datasetCreation.stepTwo.datasetSettingLink')}
)} @@ -767,12 +776,32 @@ const StepTwo = ({ )}
)} + {/* Embedding model */} + {indexType === IndexingType.QUALIFIED && ( +
+
{t('datasetSettings.form.embeddingModel')}
+ { + setEmbeddingModel(model) + }} + /> + {!!datasetId && ( +
+ {t('datasetCreation.stepTwo.indexSettingTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')} +
+ )} +
+ )} {/* Retrieval Method Config */}
{!datasetId ? (
- {t('datasetSettings.form.retrievalSetting.title')} +
{t('datasetSettings.form.retrievalSetting.title')}
{t('datasetSettings.form.retrievalSetting.learnMore')} {t('datasetSettings.form.retrievalSetting.longDescription')} @@ -861,7 +890,7 @@ const StepTwo = ({
-
{t('datasetCreation.stepTwo.emstimateSegment')}
+
{t('datasetCreation.stepTwo.estimateSegment')}
{ fileIndexingEstimate diff --git a/web/app/components/datasets/create/step-two/preview-item/index.tsx b/web/app/components/datasets/create/step-two/preview-item/index.tsx index fdec6c734d..56102b6540 100644 --- a/web/app/components/datasets/create/step-two/preview-item/index.tsx +++ b/web/app/components/datasets/create/step-two/preview-item/index.tsx @@ -41,14 +41,14 @@ const PreviewItem: FC = ({ const charNums = type === PreviewType.TEXT ? (content || '').length : (qa?.answer || '').length + (qa?.question || '').length - const formatedIndex = (() => String(index).padStart(3, '0'))() + const formattedIndex = (() => String(index).padStart(3, '0'))() return (
{sharpIcon} - {formatedIndex} + {formattedIndex}
{textIcon} diff --git a/web/app/components/datasets/create/steps-nav-bar/index.tsx b/web/app/components/datasets/create/steps-nav-bar/index.tsx index 70724a308c..b676f3ace4 100644 --- a/web/app/components/datasets/create/steps-nav-bar/index.tsx +++ b/web/app/components/datasets/create/steps-nav-bar/index.tsx @@ -49,7 +49,7 @@ const StepsNavBar = ({ key={item} className={cn(s.stepItem, s[`step${item}`], step === item && s.active, step > item && s.done, isMobile && 'px-0')} > -
{item}
+
{step > item ? '' : item}
{isMobile ? '' : t(STEP_T_MAP[item])}
))} diff --git a/web/app/components/datasets/create/website/firecrawl/base/field.tsx b/web/app/components/datasets/create/website/firecrawl/base/field.tsx index cac40798c1..5b5ca90c5d 100644 --- a/web/app/components/datasets/create/website/firecrawl/base/field.tsx +++ b/web/app/components/datasets/create/website/firecrawl/base/field.tsx @@ -38,7 +38,7 @@ const Field: FC = ({ popupContent={
{tooltip}
} - popupClassName='relative top-[3px] w-3 h-3 ml-1' + triggerClassName='ml-0.5 w-4 h-4' /> )}
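This Field change moves trigger sizing from `popupClassName` to `triggerClassName`, matching the other Tooltip call sites updated in this PR. A usage sketch limited to props visible in the diff, assuming the shared Tooltip renders its default trigger icon when no children are given:

```tsx
import Tooltip from '@/app/components/base/tooltip'

const FieldTip = ({ tooltip }: { tooltip: string }) => (
  <Tooltip
    popupContent={<div>{tooltip}</div>}
    triggerClassName='ml-0.5 w-4 h-4' // sizes the trigger, not the popup
  />
)
```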
diff --git a/web/app/components/datasets/create/website/firecrawl/base/input.tsx b/web/app/components/datasets/create/website/firecrawl/base/input.tsx index 06249f57e7..7d2d2b609f 100644 --- a/web/app/components/datasets/create/website/firecrawl/base/input.tsx +++ b/web/app/components/datasets/create/website/firecrawl/base/input.tsx @@ -9,7 +9,7 @@ type Props = { isNumber?: boolean } -const MIN_VALUE = 1 +const MIN_VALUE = 0 const Input: FC = ({ value, diff --git a/web/app/components/datasets/documents/detail/completed/index.tsx b/web/app/components/datasets/documents/detail/completed/index.tsx index f2addac2e2..0251dfa54e 100644 --- a/web/app/components/datasets/documents/detail/completed/index.tsx +++ b/web/app/components/datasets/documents/detail/completed/index.tsx @@ -24,7 +24,7 @@ import { ToastContext } from '@/app/components/base/toast' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' import { deleteSegment, disableSegment, enableSegment, fetchSegments, updateSegment } from '@/service/datasets' -import type { SegmentDetailModel, SegmentUpdator, SegmentsQuery, SegmentsResponse } from '@/models/datasets' +import type { SegmentDetailModel, SegmentUpdater, SegmentsQuery, SegmentsResponse } from '@/models/datasets' import { asyncRunSafe } from '@/utils' import type { CommonResponse } from '@/models/common' import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/common' @@ -322,7 +322,7 @@ const Completed: FC = ({ } const handleUpdateSegment = async (segmentId: string, question: string, answer: string, keywords: string[]) => { - const params: SegmentUpdator = { content: '' } + const params: SegmentUpdater = { content: '' } if (docForm === 'qa_model') { if (!question.trim()) return notify({ type: 'error', message: t('datasetDocuments.segment.questionEmpty') }) @@ -391,7 +391,7 @@ const Completed: FC = ({ defaultValue={'all'} className={s.select} wrapperClassName='h-fit w-[120px] mr-2' /> - + setSearchValue(e.target.value), 500)} />
= (
} -const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, indexingType, detailUpdate }) => { +const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, detailUpdate }) => { const onTop = stopPosition === 'top' const { t } = useTranslation() const { notify } = useContext(ToastContext) const { datasetId = '', documentId = '' } = useContext(DocumentContext) - const { indexingTechnique } = useContext(DatasetDetailContext) const localDatasetId = dstId ?? datasetId const localDocumentId = docId ?? documentId - const localIndexingTechnique = indexingType ?? indexingTechnique const [indexingStatusDetail, setIndexingStatusDetail] = useState(null) const fetchIndexingStatus = async () => { @@ -160,14 +156,6 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d } }, [startQueryStatus, stopQueryStatus]) - const { data: indexingEstimateDetail, error: indexingEstimateErr } = useSWR({ - action: 'fetchIndexingEstimate', - datasetId: localDatasetId, - documentId: localDocumentId, - }, apiParams => fetchIndexingEstimate(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) - const { data: ruleDetail, error: ruleError } = useSWR({ action: 'fetchProcessRule', params: { documentId: localDocumentId }, @@ -250,21 +238,6 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d
{t('datasetDocuments.embedding.segments')} {indexingStatusDetail?.completed_segments}/{indexingStatusDetail?.total_segments} · {percent}%
- {localIndexingTechnique === 'high_quaility' && ( -
-
- {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} - {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens - (${formatNumber(indexingEstimateDetail?.total_price || 0)}) -
- )} - {localIndexingTechnique === 'economy' && ( -
-
- {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} - 0tokens -
- )}
{!onTop && ( diff --git a/web/app/components/datasets/documents/detail/embedding/style.module.css b/web/app/components/datasets/documents/detail/embedding/style.module.css index 6dc1a5e80b..c24444ac12 100644 --- a/web/app/components/datasets/documents/detail/embedding/style.module.css +++ b/web/app/components/datasets/documents/detail/embedding/style.module.css @@ -31,7 +31,7 @@ @apply rounded-r-md; } .progressData { - @apply w-full flex justify-between items-center text-xs text-gray-700; + @apply w-full flex items-center text-xs text-gray-700; } .previewTip { @apply pb-1 pt-12 text-gray-900 text-sm font-medium; diff --git a/web/app/components/datasets/documents/detail/metadata/index.tsx b/web/app/components/datasets/documents/detail/metadata/index.tsx index 5a2da3efa9..2e0f8af961 100644 --- a/web/app/components/datasets/documents/detail/metadata/index.tsx +++ b/web/app/components/datasets/documents/detail/metadata/index.tsx @@ -79,7 +79,7 @@ export const FieldInfo: FC = ({ /> : onUpdate?.(e.target.value)} value={value} defaultValue={defaultValue} placeholder={`${t('datasetDocuments.metadata.placeholder.add')}${label}`} diff --git a/web/app/components/datasets/documents/detail/new-segment-modal.tsx b/web/app/components/datasets/documents/detail/new-segment-modal.tsx index 24e0ba3cdc..dae9cf19fb 100644 --- a/web/app/components/datasets/documents/detail/new-segment-modal.tsx +++ b/web/app/components/datasets/documents/detail/new-segment-modal.tsx @@ -9,7 +9,7 @@ import Button from '@/app/components/base/button' import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/common' import { Hash02 } from '@/app/components/base/icons/src/vender/line/general' import { ToastContext } from '@/app/components/base/toast' -import type { SegmentUpdator } from '@/models/datasets' +import type { SegmentUpdater } from '@/models/datasets' import { addSegment } from '@/service/datasets' import TagInput from '@/app/components/base/tag-input' @@ -42,7 +42,7 @@ const NewSegmentModal: FC = ({ } const handleSave = async () => { - const params: SegmentUpdator = { content: '' } + const params: SegmentUpdater = { content: '' } if (docForm === 'qa_model') { if (!question.trim()) return notify({ type: 'error', message: t('datasetDocuments.segment.questionEmpty') }) diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx index ce81e44792..81b85c8220 100644 --- a/web/app/components/datasets/documents/index.tsx +++ b/web/app/components/datasets/documents/index.tsx @@ -4,8 +4,9 @@ import React, { useMemo, useState } from 'react' import useSWR from 'swr' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' -import { debounce, groupBy, omit } from 'lodash-es' +import { groupBy, omit } from 'lodash-es' import { PlusIcon } from '@heroicons/react/24/solid' +import { useDebounce } from 'ahooks' import List from './list' import s from './style.module.css' import Loading from '@/app/components/base/loading' @@ -87,9 +88,11 @@ const Documents: FC = ({ datasetId }) => { const isDataSourceFile = dataset?.data_source_type === DataSourceType.FILE const embeddingAvailable = !!dataset?.embedding_available + const debouncedSearchValue = useDebounce(searchValue, { wait: 500 }) + const query = useMemo(() => { - return { page: currPage + 1, limit, keyword: searchValue, fetch: isDataSourceNotion ? 
true : '' } - }, [searchValue, currPage, isDataSourceNotion]) + return { page: currPage + 1, limit, keyword: debouncedSearchValue, fetch: isDataSourceNotion ? true : '' } + }, [currPage, debouncedSearchValue, isDataSourceNotion]) const { data: documentsRes, error, mutate } = useSWR( { @@ -106,15 +109,15 @@ const Documents: FC = ({ datasetId }) => { let percent = 0 const documentsData = documentsRes?.data?.map((documentItem) => { const { indexing_status, completed_segments, total_segments } = documentItem - const isEmbeddinged = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error' + const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error' - if (isEmbeddinged) + if (isEmbedded) completedNum++ const completedCount = completed_segments || 0 const totalCount = total_segments || 0 if (totalCount === 0 && completedCount === 0) { - percent = isEmbeddinged ? 100 : 0 + percent = isEmbedded ? 100 : 0 } else { const per = Math.round(completedCount * 100 / totalCount) @@ -201,10 +204,10 @@ const Documents: FC = ({ datasetId }) => {
setSearchValue(e.target.value)} value={searchValue} />
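documents/index.tsx swaps lodash's `debounce` for ahooks' `useDebounce`, so the input state updates immediately while the SWR query key settles 500 ms later. A condensed sketch of the resulting flow:

```tsx
import { useMemo, useState } from 'react'
import { useDebounce } from 'ahooks'

const useDocumentQuery = (currPage: number, limit: number, isDataSourceNotion: boolean) => {
  const [searchValue, setSearchValue] = useState('')
  // Re-renders with the settled value once typing pauses for 500 ms.
  const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })

  // The query object (and thus the SWR key) only changes when the debounced
  // value does, avoiding one request per keystroke.
  const query = useMemo(
    () => ({ page: currPage + 1, limit, keyword: debouncedSearchValue, fetch: isDataSourceNotion ? true : '' }),
    [currPage, limit, debouncedSearchValue, isDataSourceNotion],
  )

  return { query, searchValue, setSearchValue }
}
```

Debouncing the value rather than the handler also keeps the visible input exactly in sync with what the user typed.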
diff --git a/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx b/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx index be5c1be2e7..999f1cdf0d 100644 --- a/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx +++ b/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx @@ -39,14 +39,14 @@ const ModifyRetrievalModal: FC = ({ const { modelList: rerankModelList, defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelVaild, + currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const handleSave = () => { if ( !isReRankModelSelected({ rerankDefaultModel, - isRerankDefaultModelVaild: !!isRerankDefaultModelVaild, + isRerankDefaultModelValid: !!isRerankDefaultModelValid, rerankModelList, retrievalConfig, indexMethod, diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index 404a8ed6a0..0f6bdd0a59 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -73,7 +73,7 @@ const Form = () => { const { modelList: rerankModelList, defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelVaild, + currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) @@ -99,7 +99,7 @@ const Form = () => { if ( !isReRankModelSelected({ rerankDefaultModel, - isRerankDefaultModelVaild: !!isRerankDefaultModelVaild, + isRerankDefaultModelValid: !!isRerankDefaultModelValid, rerankModelList, retrievalConfig, indexMethod, diff --git a/web/app/components/develop/secret-key/secret-key-modal.tsx b/web/app/components/develop/secret-key/secret-key-modal.tsx index fd28a67e7e..dbb5cc37c7 100644 --- a/web/app/components/develop/secret-key/secret-key-modal.tsx +++ b/web/app/components/develop/secret-key/secret-key-modal.tsx @@ -41,7 +41,7 @@ const SecretKeyModal = ({ }: ISecretKeyModalProps) => { const { t } = useTranslation() const { formatTime } = useTimestamp() - const { currentWorkspace, isCurrentWorkspaceManager } = useAppContext() + const { currentWorkspace, isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() const [showConfirmDelete, setShowConfirmDelete] = useState(false) const [isVisible, setVisible] = useState(false) const [newKey, setNewKey] = useState(undefined) @@ -142,7 +142,7 @@ const SecretKeyModal = ({ ) }
- diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index 6487ac79d7..655b6efce1 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -3,7 +3,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from # Advanced Chat App API -Chat applications support session persistence, allowing previous chat history to be used as context for responses. This can be applicable for chatbots, customer service AI, etc. +Chat applications support session persistence, allowing previous chat history to be used as context for responses. This can be applicable for chatbot, customer service AI, etc.
### Base URL @@ -60,7 +60,7 @@ Chat applications support session persistence, allowing previous chat history to Should be uniquely defined by the developer within the application. - Converation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id. + Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id. File list, suitable for inputting files (images) combined with text understanding and answering questions, available only when the model supports Vision capability. @@ -239,7 +239,7 @@ Chat applications support session persistence, allowing previous chat history to "message_id": "9da23599-e713-473b-982c-4328d4f5c78a", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "mode": "chat", - "answer": "iPhone 13 Pro Max specs are listed heere:...", + "answer": "iPhone 13 Pro Max specs are listed here:...", "metadata": { "usage": { "prompt_tokens": 1033, @@ -732,7 +732,7 @@ Chat applications support session persistence, allowing previous chat history to ```bash {{ title: 'cURL' }} - curl -X DELETE '${props.appDetail.api_base_url}/conversations/{convsation_id}' \ + curl -X DELETE '${props.appDetail.api_base_url}/conversations/{conversation_id}' \ --header 'Content-Type: application/json' \ --header 'Accept: application/json' \ --header 'Authorization: Bearer {api_key}' \ diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 33551509e5..2aa42fbb19 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -250,7 +250,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' "message_id": "9da23599-e713-473b-982c-4328d4f5c78a", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "mode": "chat", - "answer": "iPhone 13 Pro Max specs are listed heere:...", + "answer": "iPhone 13 Pro Max specs are listed here:...", "metadata": { "usage": { "prompt_tokens": 1033, @@ -767,7 +767,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ```bash {{ title: 'cURL' }} - curl -X DELETE '${props.appDetail.api_base_url}/conversations/{convsation_id}' \ + curl -X DELETE '${props.appDetail.api_base_url}/conversations/{conversation_id}' \ --header 'Content-Type: application/json' \ --header 'Accept: application/json' \ --header 'Authorization: Bearer {api_key}' \ diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 07840640f4..d6dfbaaaf9 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -3,7 +3,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from # Chat App API -Chat applications support session persistence, allowing previous chat history to be used as context for responses. This can be applicable for chatbots, customer service AI, etc. +Chat applications support session persistence, allowing previous chat history to be used as context for responses. This can be applicable for chatbot, customer service AI, etc.
### Base URL @@ -61,7 +61,7 @@ Chat applications support session persistence, allowing previous chat history to Should be uniquely defined by the developer within the application. - Converation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id. + Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id. File list, suitable for inputting files (images) combined with text understanding and answering questions, available only when the model supports Vision capability. @@ -200,7 +200,7 @@ Chat applications support session persistence, allowing previous chat history to "message_id": "9da23599-e713-473b-982c-4328d4f5c78a", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "mode": "chat", - "answer": "iPhone 13 Pro Max specs are listed heere:...", + "answer": "iPhone 13 Pro Max specs are listed here:...", "metadata": { "usage": { "prompt_tokens": 1033, @@ -534,7 +534,7 @@ Chat applications support session persistence, allowing previous chat history to - `id` (string) ID - `type` (string) File type, image for images - `url` (string) Preview image URL - - `belongs_to` (string) belongs to,user orassistant + - `belongs_to` (string) belongs to,user or assistant - `agent_thoughts` (array[object]) Agent thought(Empty if it's a Basic Assistant) - `id` (string) Agent thought ID, every iteration has a unique agent thought ID - `message_id` (string) Unique message ID @@ -772,7 +772,7 @@ Chat applications support session persistence, allowing previous chat history to ```bash {{ title: 'cURL' }} - curl -X DELETE '${props.appDetail.api_base_url}/conversations/{convsation_id}' \ + curl -X DELETE '${props.appDetail.api_base_url}/conversations/{conversation_id}' \ --header 'Content-Type: application/json' \ --header 'Accept: application/json' \ --header 'Authorization: Bearer {api_key}' \ diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index 727d884c1a..a91da81a1c 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -213,7 +213,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' "message_id": "9da23599-e713-473b-982c-4328d4f5c78a", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "mode": "chat", - "answer": "iPhone 13 Pro Max specs are listed heere:...", + "answer": "iPhone 13 Pro Max specs are listed here:...", "metadata": { "usage": { "prompt_tokens": 1033, @@ -786,7 +786,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ```bash {{ title: 'cURL' }} - curl -X DELETE '${props.appDetail.api_base_url}/conversations/{convsation_id}' \ + curl -X DELETE '${props.appDetail.api_base_url}/conversations/{conversation_id}' \ --header 'Content-Type: application/json' \ --header 'Accept: application/json' \ --header 'Authorization: Bearer {api_key}' \ diff --git a/web/app/components/explore/app-card/index.tsx b/web/app/components/explore/app-card/index.tsx index 3d666fdb1a..b1ea4a95bf 100644 --- a/web/app/components/explore/app-card/index.tsx +++ b/web/app/components/explore/app-card/index.tsx @@ -5,7 +5,7 @@ import Button from '../../base/button' import cn from '@/utils/classnames' import type { App } from '@/models/explore' import AppIcon from '@/app/components/base/app-icon' -import { AiText, ChatBot, CuteRobote } 
from '@/app/components/base/icons/src/vender/solid/communication' +import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' export type AppCardProps = { app: App @@ -23,7 +23,7 @@ const AppCard = ({ const { t } = useTranslation() const { app: appBasicInfo } = app return ( -
+
)} {appBasicInfo.mode === 'agent-chat' && ( - + )} {appBasicInfo.mode === 'chat' && ( @@ -64,9 +64,13 @@ const AppCard = ({
-
{app.description}
+
+
+ {app.description} +
+
{isExplore && canCreate && ( -
+
)} {!isExplore && ( -
+
+ {/* answer icon */} + {isEditModal && (appMode === 'chat' || appMode === 'advanced-chat' || appMode === 'agent-chat') && ( +
+
+
{t('app.answerIcon.title')}
+ setUseIconAsAnswerIcon(v)} + /> +
+

{t('app.answerIcon.descriptionInExplore')}

+
+ )} {!isEditModal && isAppsFull && }
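The create-app-modal hunk above adds an answer-icon toggle for chat-family apps in edit mode. A condensed sketch of the state wiring (the Switch import path and its `defaultValue`/`onChange` props are assumptions based on the surrounding code):

```tsx
import { useState } from 'react'
import Switch from '@/app/components/base/switch' // assumed path

const AnswerIconToggle = () => {
  const [useIconAsAnswerIcon, setUseIconAsAnswerIcon] = useState(false)
  // Rendered only when isEditModal is true and the app mode is one of
  // 'chat' | 'advanced-chat' | 'agent-chat', per the condition in the diff.
  return (
    <Switch
      defaultValue={useIconAsAnswerIcon}
      onChange={v => setUseIconAsAnswerIcon(v)}
    />
  )
}
```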
diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 2298bff82d..03157ed7cb 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -22,11 +22,11 @@ import { LanguagesSupported } from '@/i18n/language' import { useProviderContext } from '@/context/provider-context' import { Plan } from '@/app/components/billing/type' -export type IAppSelecotr = { +export type IAppSelector = { isMobile: boolean } -export default function AppSelector({ isMobile }: IAppSelecotr) { +export default function AppSelector({ isMobile }: IAppSelector) { const itemClassName = ` flex items-center w-full h-9 px-3 text-gray-700 text-[14px] rounded-lg font-normal hover:bg-gray-50 cursor-pointer @@ -125,7 +125,7 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { className={classNames(itemClassName, 'group justify-between')} href='https://github.com/langgenius/dify/discussions/categories/feedbacks' target='_blank' rel='noopener noreferrer'> -
{t('common.userProfile.roadmapAndFeedback')}
+
{t('common.userProfile.communityFeedback')}
@@ -149,6 +149,15 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { + + +
{t('common.userProfile.roadmap')}
+ + +
{ document?.body?.getAttribute('data-public-site-about') !== 'hide' && ( diff --git a/web/app/components/header/account-setting/account-page/index.tsx b/web/app/components/header/account-setting/account-page/index.tsx index ce7f7e7e22..eecd275b35 100644 --- a/web/app/components/header/account-setting/account-page/index.tsx +++ b/web/app/components/header/account-setting/account-page/index.tsx @@ -90,7 +90,7 @@ export default function AccountPage() { setPassword('') setConfirmPassword('') } - const handleSavePassowrd = async () => { + const handleSavePassword = async () => { if (!valid()) return try { @@ -235,7 +235,7 @@ export default function AccountPage() { diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index 711e772684..e09e4bbc0d 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -15,7 +15,7 @@ import I18n from '@/context/i18n' import { useAppContext } from '@/context/app-context' import Avatar from '@/app/components/base/avatar' import type { InvitationResult } from '@/models/common' -import LogoEmbededChatHeader from '@/app/components/base/logo/logo-embeded-chat-header' +import LogoEmbeddedChatHeader from '@/app/components/base/logo/logo-embedded-chat-header' import { useProviderContext } from '@/context/provider-context' import { Plan } from '@/app/components/billing/type' import UpgradeBtn from '@/app/components/billing/upgrade-btn' @@ -49,7 +49,7 @@ const MembersPage = () => { <>
- +
{currentWorkspace?.name}
{enableBilling && ( diff --git a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx index 7af19b06c3..fc64d46b06 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/index.tsx @@ -11,8 +11,8 @@ import { IS_CE_EDITION } from '@/config' import type { InvitationResult } from '@/models/common' import Tooltip from '@/app/components/base/tooltip' -export type SuccessInvationResult = Extract -export type FailedInvationResult = Extract +export type SuccessInvitationResult = Extract +export type FailedInvitationResult = Extract type IInvitedModalProps = { invitationResults: InvitationResult[] @@ -24,8 +24,8 @@ const InvitedModal = ({ }: IInvitedModalProps) => { const { t } = useTranslation() - const successInvationResults = useMemo(() => invitationResults?.filter(item => item.status === 'success') as SuccessInvationResult[], [invitationResults]) - const failedInvationResults = useMemo(() => invitationResults?.filter(item => item.status !== 'success') as FailedInvationResult[], [invitationResults]) + const successInvitationResults = useMemo(() => invitationResults?.filter(item => item.status === 'success') as SuccessInvitationResult[], [invitationResults]) + const failedInvitationResults = useMemo(() => invitationResults?.filter(item => item.status !== 'success') as FailedInvitationResult[], [invitationResults]) return (
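`SuccessInvitationResult` and `FailedInvitationResult` are `Extract`-based narrowings of the `InvitationResult` union. A self-contained sketch of the pattern, with assumed member fields:

```ts
type InvitationResult =
  | { status: 'success', email: string, url: string } // fields assumed
  | { status: 'failed', email: string, message: string } // fields assumed

type SuccessInvitationResult = Extract<InvitationResult, { status: 'success' }>
type FailedInvitationResult = Extract<InvitationResult, { status: 'failed' }>

// Type predicates let filter() narrow without the `as` casts the component uses.
const isSuccess = (r: InvitationResult): r is SuccessInvitationResult =>
  r.status === 'success'

const partition = (results: InvitationResult[]) => ({
  success: results.filter(isSuccess),
  failed: results.filter((r): r is FailedInvitationResult => !isSuccess(r)),
})
```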
@@ -49,20 +49,20 @@ const InvitedModal = ({
{t('common.members.invitationSentTip')}
{ - !!successInvationResults.length + !!successInvitationResults.length && <>
{t('common.members.invitationLink')}
- {successInvationResults.map(item => + {successInvitationResults.map(item => )} } { - !!failedInvationResults.length + !!failedInvitationResults.length && <> -
{t('common.members.failedinvitationEmails')}
+
{t('common.members.failedInvitationEmails')}
{ - failedInvationResults.map(item => + failedInvitationResults.map(item =>
{ const { modelProviders: providers } = useProviderContext() const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel - const [configedProviders, notConfigedProviders] = useMemo(() => { - const configedProviders: ModelProvider[] = [] - const notConfigedProviders: ModelProvider[] = [] + const [configuredProviders, notConfiguredProviders] = useMemo(() => { + const configuredProviders: ModelProvider[] = [] + const notConfiguredProviders: ModelProvider[] = [] providers.forEach((provider) => { if ( @@ -47,12 +47,12 @@ const ModelProviderPage = () => { && provider.system_configuration.quota_configurations.find(item => item.quota_type === provider.system_configuration.current_quota_type) ) ) - configedProviders.push(provider) + configuredProviders.push(provider) else - notConfigedProviders.push(provider) + notConfiguredProviders.push(provider) }) - return [configedProviders, notConfigedProviders] + return [configuredProviders, notConfiguredProviders] }, [providers]) const handleOpenModal = ( @@ -110,10 +110,10 @@ const ModelProviderPage = () => { />
{ - !!configedProviders?.length && ( + !!configuredProviders?.length && (
{ - configedProviders?.map(provider => ( + configuredProviders?.map(provider => ( { ) } { - !!notConfigedProviders?.length && ( + !!notConfiguredProviders?.length && ( <>
+ {t('common.modelProvider.addMoreModelProvider')} @@ -133,7 +133,7 @@ const ModelProviderPage = () => {
{ - notConfigedProviders?.map(provider => ( + notConfiguredProviders?.map(provider => ( = ({ onChange({ ...value, [key]: val, ...shouldClearVariable }) } - // convert tooltip '\n' to
- const renderTooltipContent = (content: string) => { - return content.split('\n').map((line, index, array) => ( - - {line} - {index < array.length - 1 &&
} -
- )) - } - const renderField = (formSchema: CredentialFormSchema) => { const tooltip = formSchema.tooltip const tooltipContent = (tooltip && ( - - - {renderTooltipContent(tooltip[language] || tooltip.en_US)} -
- } > - - - )) + + {tooltip[language] || tooltip.en_US} +
} + triggerClassName='ml-1 w-4 h-4' + asChild={false} + /> + )) if (formSchema.type === FormTypeEnum.textInput || formSchema.type === FormTypeEnum.secretInput || formSchema.type === FormTypeEnum.textNumber) { const { variable, @@ -103,10 +91,10 @@ const Form: FC = ({ if (show_on.length && !show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) return null - const disabed = readonly || (isEditMode && (variable === '__model_type' || variable === '__model_name')) + const disabled = readonly || (isEditMode && (variable === '__model_type' || variable === '__model_name')) return (
-
+
{label[language] || label.en_US} { required && ( @@ -116,12 +104,12 @@ const Form: FC = ({ {tooltipContent}
handleFormChange(variable, val)} validated={validatedSuccess} placeholder={placeholder?.[language] || placeholder?.en_US} - disabled={disabed} + disabled={disabled} type={formSchema.type === FormTypeEnum.textNumber ? 'number' : 'text'} {...(formSchema.type === FormTypeEnum.textNumber ? { min: (formSchema as CredentialFormSchemaNumberInput).min, max: (formSchema as CredentialFormSchemaNumberInput).max } : {})} /> @@ -143,11 +131,11 @@ const Form: FC = ({ if (show_on.length && !show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) return null - const disabed = isEditMode && (variable === '__model_type' || variable === '__model_name') + const disabled = isEditMode && (variable === '__model_type' || variable === '__model_name') return (
-
+
{label[language] || label.en_US} { required && ( @@ -168,7 +156,7 @@ const Form: FC = ({ className={` flex items-center px-3 py-2 rounded-lg border border-gray-100 bg-gray-25 cursor-pointer ${value[variable] === option.value && 'bg-white border-[1.5px] border-primary-400 shadow-sm'} - ${disabed && '!cursor-not-allowed opacity-60'} + ${disabled && '!cursor-not-allowed opacity-60'} `} onClick={() => handleFormChange(variable, option.value)} key={`${variable}-${option.value}`} @@ -203,7 +191,7 @@ const Form: FC = ({ return (
-
+
{label[language] || label.en_US} { @@ -247,7 +235,7 @@ const Form: FC = ({
- {label[language] || label.en_US} + {label[language] || label.en_US} { required && ( * diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index e60ef418ed..376a08c120 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -127,12 +127,10 @@ const ParameterItem: FC = ({ && !isNullOrUndefined(parameterRule.min) && !isNullOrUndefined(parameterRule.max) - if (parameterRule.type === 'int' || parameterRule.type === 'float') { + if (parameterRule.type === 'int') { let step = 100 if (parameterRule.max) { - if (parameterRule.max < 10) - step = 0.1 - else if (parameterRule.max < 100) + if (parameterRule.max < 100) step = 1 else if (parameterRule.max < 1000) step = 10 @@ -164,6 +162,31 @@ const ParameterItem: FC = ({ ) } + if (parameterRule.type === 'float') { + return ( + <> + {numberInputWithSlide && } + + + ) + } + if (parameterRule.type === 'boolean') { return ( = ({ const customConfig = provider.custom_configuration const systemConfig = provider.system_configuration const priorityUseType = provider.preferred_provider_type - const customConfiged = customConfig.status === CustomConfigurationStatusEnum.active + const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active const configurateMethods = provider.configurate_methods const handleChangePriority = async (key: PreferredProviderTypeEnum) => { @@ -69,7 +69,7 @@ const CredentialPanel: FC = ({
API-KEY - +
{ - systemConfig.enabled && customConfiged && ( + systemConfig.enabled && isCustomConfigured && ( = ({ ) } { - systemConfig.enabled && customConfiged && !provider.provider_credential_schema && ( + systemConfig.enabled && isCustomConfigured && !provider.provider_credential_schema && (
= ({ {t('common.modelProvider.systemReasoningModel.tip')}
} - triggerClassName='ml-0.5' + triggerClassName='ml-0.5 w-4 h-4 shrink-0' />
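The parameter-item hunk above splits integer and float parameters into separate branches, with integers deriving the slider step from the rule's max. That selection as a pure function (the float step here is an assumption; the hunk does not show that branch's value):

```ts
// int params: coarser steps for wider ranges, per the thresholds in the diff.
// float params: a fixed fine-grained step (assumed, not shown in the hunk).
const pickStep = (type: 'int' | 'float', max?: number): number => {
  if (type === 'float')
    return 0.1
  let step = 100
  if (max != null) {
    if (max < 100)
      step = 1
    else if (max < 1000)
      step = 10
  }
  return step
}
```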
@@ -168,8 +168,7 @@ const SystemModel: FC = ({ {t('common.modelProvider.embeddingModel.tip')}
} - needsDelay={false} - triggerClassName='ml-0.5' + triggerClassName='ml-0.5 w-4 h-4 shrink-0' />
@@ -189,8 +188,7 @@ const SystemModel: FC = ({ {t('common.modelProvider.rerankModel.tip')}
} - needsDelay={false} - triggerClassName='ml-0.5' + triggerClassName='ml-0.5 w-4 h-4 shrink-0' />
@@ -210,8 +208,7 @@ const SystemModel: FC = ({ {t('common.modelProvider.speechToTextModel.tip')}
} - needsDelay={false} - triggerClassName='ml-0.5' + triggerClassName='ml-0.5 w-4 h-4 shrink-0' />
@@ -231,7 +228,7 @@ const SystemModel: FC = ({ {t('common.modelProvider.ttsModel.tip')}
} - triggerClassName='ml-0.5' + triggerClassName='ml-0.5 w-4 h-4 shrink-0' />
diff --git a/web/app/components/header/explore-nav/index.tsx b/web/app/components/header/explore-nav/index.tsx index 032b886613..0046fc293d 100644 --- a/web/app/components/header/explore-nav/index.tsx +++ b/web/app/components/header/explore-nav/index.tsx @@ -17,16 +17,16 @@ const ExploreNav = ({ }: ExploreNavProps) => { const { t } = useTranslation() const selectedSegment = useSelectedLayoutSegment() - const actived = selectedSegment === 'explore' + const activated = selectedSegment === 'explore' return ( { - actived + activated ? : } diff --git a/web/app/components/header/nav/index.tsx b/web/app/components/header/nav/index.tsx index 1ea36c5123..bfb4324320 100644 --- a/web/app/components/header/nav/index.tsx +++ b/web/app/components/header/nav/index.tsx @@ -34,21 +34,21 @@ const Nav = ({ const setAppDetail = useAppStore(state => state.setAppDetail) const [hovered, setHovered] = useState(false) const segment = useSelectedLayoutSegment() - const isActived = Array.isArray(activeSegment) ? activeSegment.includes(segment!) : segment === activeSegment + const isActivated = Array.isArray(activeSegment) ? activeSegment.includes(segment!) : segment === activeSegment return (
setAppDetail()} className={classNames(` flex items-center h-7 px-2.5 cursor-pointer rounded-[10px] - ${isActived ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text'} - ${curNav && isActived && 'hover:bg-components-main-nav-nav-button-bg-active-hover'} + ${isActivated ? 'text-components-main-nav-nav-button-text-active' : 'text-components-main-nav-nav-button-text'} + ${curNav && isActivated && 'hover:bg-components-main-nav-nav-button-bg-active-hover'} `)} onMouseEnter={() => setHovered(true)} onMouseLeave={() => setHovered(false)} @@ -57,7 +57,7 @@ const Nav = ({ { (hovered && curNav) ? - : isActived + : isActivated ? activeIcon : icon } @@ -66,7 +66,7 @@ const Nav = ({
{ - curNav && isActived && ( + curNav && isActivated && ( <>
/
)} {nav.mode === 'agent-chat' && ( - + )} {nav.mode === 'chat' && ( diff --git a/web/app/components/header/tools-nav/index.tsx b/web/app/components/header/tools-nav/index.tsx index 5184f5e5ce..096a552229 100644 --- a/web/app/components/header/tools-nav/index.tsx +++ b/web/app/components/header/tools-nav/index.tsx @@ -17,16 +17,16 @@ const ToolsNav = ({ }: ToolsNavProps) => { const { t } = useTranslation() const selectedSegment = useSelectedLayoutSegment() - const actived = selectedSegment === 'tools' + const activated = selectedSegment === 'tools' return ( { - actived + activated ? : } diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index dd6efa86fa..a2f6864242 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -163,8 +163,8 @@ const TextGeneration: FC = ({ } const allSuccessTaskList = allTaskList.filter(task => task.status === TaskStatus.completed) const allFailedTaskList = allTaskList.filter(task => task.status === TaskStatus.failed) - const allTaskFinished = allTaskList.every(task => task.status === TaskStatus.completed) - const allTaskRuned = allTaskList.every(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)) + const allTasksFinished = allTaskList.every(task => task.status === TaskStatus.completed) + const allTasksRun = allTaskList.every(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)) const [batchCompletionRes, doSetBatchCompletionRes] = useState>({}) const batchCompletionResRef = useRef>({}) const setBatchCompletionRes = (res: Record) => { @@ -286,7 +286,7 @@ const TextGeneration: FC = ({ const handleRunBatch = (data: string[][]) => { if (!checkBatchInputs(data)) return - if (!allTaskFinished) { + if (!allTasksFinished) { notify({ type: 'info', message: t('appDebug.errorMessage.waitForBatchResponse') }) return } @@ -318,17 +318,17 @@ const TextGeneration: FC = ({ showResSidebar() } const handleCompleted = (completionRes: string, taskId?: number, isSuccess?: boolean) => { - const allTasklistLatest = getLatestTaskList() + const allTaskListLatest = getLatestTaskList() const batchCompletionResLatest = getBatchCompletionRes() - const pendingTaskList = allTasklistLatest.filter(task => task.status === TaskStatus.pending) - const hadRunedTaskNum = 1 + allTasklistLatest.filter(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)).length - const needToAddNextGroupTask = (getCurrGroupNum() !== hadRunedTaskNum) && pendingTaskList.length > 0 && (hadRunedTaskNum % GROUP_SIZE === 0 || (allTasklistLatest.length - hadRunedTaskNum < GROUP_SIZE)) + const pendingTaskList = allTaskListLatest.filter(task => task.status === TaskStatus.pending) + const runTasksCount = 1 + allTaskListLatest.filter(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)).length + const needToAddNextGroupTask = (getCurrGroupNum() !== runTasksCount) && pendingTaskList.length > 0 && (runTasksCount % GROUP_SIZE === 0 || (allTaskListLatest.length - runTasksCount < GROUP_SIZE)) // avoid add many task at the same time if (needToAddNextGroupTask) - setCurrGroupNum(hadRunedTaskNum) + setCurrGroupNum(runTasksCount) const nextPendingTaskIds = needToAddNextGroupTask ? 
pendingTaskList.slice(0, GROUP_SIZE).map(item => item.id) : [] - const newAllTaskList = allTasklistLatest.map((item) => { + const newAllTaskList = allTaskListLatest.map((item) => { if (item.id === taskId) { return { ...item, @@ -393,7 +393,7 @@ const TextGeneration: FC = ({ }) const prompt_variables = userInputsFormToPromptVariables(user_input_form) setPromptConfig({ - prompt_template: '', // placeholder for feture + prompt_template: '', // placeholder for future prompt_variables, } as PromptConfig) setMoreLikeThisConfig(more_like_this) @@ -587,7 +587,7 @@ const TextGeneration: FC = ({ isRight: true, extra: savedMessages.length > 0 ? ( -
+
{savedMessages.length}
) @@ -614,7 +614,7 @@ const TextGeneration: FC = ({
diff --git a/web/app/components/share/text-generation/result/content.tsx b/web/app/components/share/text-generation/result/content.tsx index 17cce0fae5..4e39db42c8 100644 --- a/web/app/components/share/text-generation/result/content.tsx +++ b/web/app/components/share/text-generation/result/content.tsx @@ -1,14 +1,14 @@ import type { FC } from 'react' import React from 'react' import Header from './header' -import type { Feedbacktype } from '@/app/components/base/chat/chat/type' +import type { FeedbackType } from '@/app/components/base/chat/chat/type' import { format } from '@/service/base' export type IResultProps = { content: string showFeedback: boolean - feedback: Feedbacktype - onFeedback: (feedback: Feedbacktype) => void + feedback: FeedbackType + onFeedback: (feedback: FeedbackType) => void } const Result: FC = ({ content, diff --git a/web/app/components/share/text-generation/result/header.tsx b/web/app/components/share/text-generation/result/header.tsx index bd5c317153..0233b098d0 100644 --- a/web/app/components/share/text-generation/result/header.tsx +++ b/web/app/components/share/text-generation/result/header.tsx @@ -4,7 +4,7 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { ClipboardDocumentIcon, HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline' import copy from 'copy-to-clipboard' -import type { Feedbacktype } from '@/app/components/base/chat/chat/type' +import type { FeedbackType } from '@/app/components/base/chat/chat/type' import Button from '@/app/components/base/button' import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' @@ -12,8 +12,8 @@ import Tooltip from '@/app/components/base/tooltip' type IResultHeaderProps = { result: string showFeedback: boolean - feedback: Feedbacktype - onFeedback: (feedback: Feedbacktype) => void + feedback: FeedbackType + onFeedback: (feedback: FeedbackType) => void } const Header: FC = ({ diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index caa2f9183e..2d5546f9b4 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -9,7 +9,7 @@ import TextGenerationRes from '@/app/components/app/text-generate/item' import NoData from '@/app/components/share/text-generation/no-data' import Toast from '@/app/components/base/toast' import { sendCompletionMessage, sendWorkflowMessage, updateFeedback } from '@/service/share' -import type { Feedbacktype } from '@/app/components/base/chat/chat/type' +import type { FeedbackType } from '@/app/components/base/chat/chat/type' import Loading from '@/app/components/base/loading' import type { PromptConfig } from '@/models/debug' import type { InstalledApp } from '@/models/explore' @@ -83,23 +83,23 @@ const Result: FC = ({ doSetCompletionRes(res) } const getCompletionRes = () => completionResRef.current - const [workflowProcessData, doSetWorkflowProccessData] = useState() + const [workflowProcessData, doSetWorkflowProcessData] = useState() const workflowProcessDataRef = useRef() - const setWorkflowProccessData = (data: WorkflowProcess) => { + const setWorkflowProcessData = (data: WorkflowProcess) => { workflowProcessDataRef.current = data - doSetWorkflowProccessData(data) + doSetWorkflowProcessData(data) } - const getWorkflowProccessData = () => workflowProcessDataRef.current + const getWorkflowProcessData = () => workflowProcessDataRef.current const { 
notify } = Toast const isNoData = !completionRes const [messageId, setMessageId] = useState(null) - const [feedback, setFeedback] = useState({ + const [feedback, setFeedback] = useState({ rating: null, }) - const handleFeedback = async (feedback: Feedbacktype) => { + const handleFeedback = async (feedback: FeedbackType) => { await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating } }, isInstalledApp, installedAppInfo?.id) setFeedback(feedback) } @@ -196,14 +196,12 @@ const Result: FC = ({ })() if (isWorkflow) { - let isInIteration = false - sendWorkflowMessage( data, { onWorkflowStarted: ({ workflow_run_id }) => { tempMessageId = workflow_run_id - setWorkflowProccessData({ + setWorkflowProcessData({ status: WorkflowRunningStatus.Running, tracing: [], expand: false, @@ -211,7 +209,7 @@ const Result: FC = ({ }) }, onIterationStart: ({ data }) => { - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true draft.tracing!.push({ ...data, @@ -219,26 +217,31 @@ const Result: FC = ({ expand: true, } as any) })) - isInIteration = true }, onIterationNext: () => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { + draft.expand = true + const iterations = draft.tracing.find(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! + iterations?.details!.push([]) + })) }, onIterationFinish: ({ data }) => { - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true - // const iteration = draft.tracing![draft.tracing!.length - 1] - draft.tracing![draft.tracing!.length - 1] = { + const iterationsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id + && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))! 
+ draft.tracing[iterationsIndex] = { ...data, expand: !!data.error, } as any })) - isInIteration = false }, onNodeStarted: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.expand = true draft.tracing!.push({ ...data, @@ -248,11 +251,12 @@ const Result: FC = ({ })) }, onNodeFinished: ({ data }) => { - if (isInIteration) + if (data.iteration_id) return - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { - const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id) + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { + const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id + && (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || trace.parallel_id === data.execution_metadata?.parallel_id)) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { ...(draft.tracing[currentIndex].extras @@ -269,7 +273,7 @@ const Result: FC = ({ return if (data.error) { notify({ type: 'error', message: data.error }) - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.status = WorkflowRunningStatus.Failed })) setRespondingFalse() @@ -277,7 +281,7 @@ const Result: FC = ({ isEnd = true return } - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.status = WorkflowRunningStatus.Succeeded })) if (!data.outputs) { @@ -287,7 +291,7 @@ const Result: FC = ({ setCompletionRes(data.outputs) const isStringOutput = Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string' if (isStringOutput) { - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.resultText = data.outputs[Object.keys(data.outputs)[0]] })) } @@ -299,13 +303,13 @@ const Result: FC = ({ }, onTextChunk: (params) => { const { data: { text } } = params - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.resultText += text })) }, onTextReplace: (params) => { const { data: { text } } = params - setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { + setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { draft.resultText = text })) }, diff --git a/web/app/components/tools/add-tool-modal/category.tsx b/web/app/components/tools/add-tool-modal/category.tsx index cc2c325410..a18c30ad54 100644 --- a/web/app/components/tools/add-tool-modal/category.tsx +++ b/web/app/components/tools/add-tool-modal/category.tsx @@ -17,7 +17,7 @@ type Props = { const Icon = ({ svgString, active }: { svgString: string; active: boolean }) => { const svgRef = useRef(null) - const SVGParsor = (svg: string) => { + const SVGParser = (svg: string) => { if (!svg) return null const parser = new DOMParser() @@ -25,7 +25,7 @@ const Icon = ({ svgString, active }: { svgString: string; active: boolean }) => return doc.documentElement } useMount(() => { - const svgElement = SVGParsor(svgString) + const svgElement = SVGParser(svgString) if (svgRef.current && svgElement) svgRef.current.appendChild(svgElement) }) 
diff --git a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx index c7685d496d..d580c00102 100644 --- a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx @@ -114,7 +114,7 @@ const ConfigCredential: FC = ({ {t('tools.createTool.authMethod.keyTooltip')}
} - triggerClassName='ml-0.5' + triggerClassName='ml-0.5 w-4 h-4' />
{ const store = useStoreApi() @@ -52,6 +54,8 @@ const CandidateNode = () => { y, }, }) + if (candidateNode.data.type === BlockEnum.Iteration) + draft.push(getIterationStartNode(candidateNode.id)) }) setNodes(newNodes) if (candidateNode.type === CUSTOM_NOTE_NODE) diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 070748bab0..6a4629e9c8 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -15,6 +15,7 @@ import VariableAssignerDefault from './nodes/variable-assigner/default' import AssignerDefault from './nodes/assigner/default' import EndNodeDefault from './nodes/end/default' import IterationDefault from './nodes/iteration/default' +import IterationStartDefault from './nodes/iteration-start/default' type NodesExtraData = { author: string @@ -89,6 +90,15 @@ export const NODES_EXTRA_DATA: Record = { getAvailableNextNodes: IterationDefault.getAvailableNextNodes, checkValid: IterationDefault.checkValid, }, + [BlockEnum.IterationStart]: { + author: 'Dify', + about: '', + availablePrevNodes: [], + availableNextNodes: [], + getAvailablePrevNodes: IterationStartDefault.getAvailablePrevNodes, + getAvailableNextNodes: IterationStartDefault.getAvailableNextNodes, + checkValid: IterationStartDefault.checkValid, + }, [BlockEnum.Code]: { author: 'Dify', about: '', @@ -222,6 +232,12 @@ export const NODES_INITIAL_DATA = { desc: '', ...IterationDefault.defaultValue, }, + [BlockEnum.IterationStart]: { + type: BlockEnum.IterationStart, + title: '', + desc: '', + ...IterationStartDefault.defaultValue, + }, [BlockEnum.Code]: { type: BlockEnum.Code, title: '', @@ -305,11 +321,13 @@ export const AUTO_LAYOUT_OFFSET = { export const ITERATION_NODE_Z_INDEX = 1 export const ITERATION_CHILDREN_Z_INDEX = 1002 export const ITERATION_PADDING = { - top: 85, + top: 65, right: 16, bottom: 20, left: 16, } +export const PARALLEL_LIMIT = 10 +export const PARALLEL_DEPTH_LIMIT = 3 export const RETRIEVAL_OUTPUT_STRUCT = `{ "content": "", @@ -412,4 +430,5 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE' export const CUSTOM_NODE = 'custom' +export const CUSTOM_EDGE = 'custom' export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK' diff --git a/web/app/components/workflow/custom-edge.tsx b/web/app/components/workflow/custom-edge.tsx index 5e945790d8..68e2ef945e 100644 --- a/web/app/components/workflow/custom-edge.tsx +++ b/web/app/components/workflow/custom-edge.tsx @@ -79,7 +79,7 @@ const CustomEdge = ({ id={id} path={edgePath} style={{ - stroke: (selected || data?._connectedNodeIsHovering || data?._runned) ? '#2970FF' : '#D0D5DD', + stroke: (selected || data?._connectedNodeIsHovering || data?._run) ? '#2970FF' : '#D0D5DD', strokeWidth: 2, }} /> diff --git a/web/app/components/workflow/header/checklist.tsx b/web/app/components/workflow/header/checklist.tsx index 7de9cfa2f4..6a9a6a6b9f 100644 --- a/web/app/components/workflow/header/checklist.tsx +++ b/web/app/components/workflow/header/checklist.tsx @@ -125,7 +125,7 @@ const WorkflowChecklist = ({
- {t('workflow.common.needConnecttip')} + {t('workflow.common.needConnectTip')}
) diff --git a/web/app/components/workflow/header/view-history.tsx b/web/app/components/workflow/header/view-history.tsx index 06eebfd329..a18ddad65d 100644 --- a/web/app/components/workflow/header/view-history.tsx +++ b/web/app/components/workflow/header/view-history.tsx @@ -32,7 +32,7 @@ import { } from '@/app/components/base/icons/src/vender/line/time' import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' import { - fetcChatRunHistory, + fetchChatRunHistory, fetchWorkflowRunHistory, } from '@/service/workflow' import Loading from '@/app/components/base/loading' @@ -67,7 +67,7 @@ const ViewHistory = ({ const historyWorkflowData = useStore(s => s.historyWorkflowData) const { handleBackupDraft } = useWorkflowRun() const { data: runList, isLoading: runListLoading } = useSWR((appDetail && !isChatMode && open) ? `/apps/${appDetail.id}/workflow-runs` : null, fetchWorkflowRunHistory) - const { data: chatList, isLoading: chatListLoading } = useSWR((appDetail && isChatMode && open) ? `/apps/${appDetail.id}/advanced-chat/workflow-runs` : null, fetcChatRunHistory) + const { data: chatList, isLoading: chatListLoading } = useSWR((appDetail && isChatMode && open) ? `/apps/${appDetail.id}/advanced-chat/workflow-runs` : null, fetchChatRunHistory) const data = isChatMode ? chatList : runList const isLoading = isChatMode ? chatListLoading : runListLoading diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 7f45769acd..36201ddfef 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -138,7 +138,7 @@ export const useChecklistBeforePublish = () => { } if (!validNodes.find(n => n.id === node.id)) { - notify({ type: 'error', message: `[${node.data.title}] ${t('workflow.common.needConnecttip')}` }) + notify({ type: 'error', message: `[${node.data.title}] ${t('workflow.common.needConnectTip')}` }) return false } } diff --git a/web/app/components/workflow/hooks/use-edges-interactions.ts b/web/app/components/workflow/hooks/use-edges-interactions.ts index bc3fb0e8bf..a97b65134f 100644 --- a/web/app/components/workflow/hooks/use-edges-interactions.ts +++ b/web/app/components/workflow/hooks/use-edges-interactions.ts @@ -155,7 +155,7 @@ export const useEdgesInteractions = () => { const newEdges = produce(edges, (draft) => { draft.forEach((edge) => { - edge.data._runned = false + edge.data._run = false }) }) setEdges(newEdges) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 3645e18449..af2a1500ba 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -16,6 +16,7 @@ import { useReactFlow, useStoreApi, } from 'reactflow' +import { unionBy } from 'lodash-es' import type { ToolDefaultValue } from '../block-selector/types' import type { Edge, @@ -25,6 +26,7 @@ import type { import { BlockEnum } from '../types' import { useWorkflowStore } from '../store' import { + CUSTOM_EDGE, ITERATION_CHILDREN_Z_INDEX, ITERATION_PADDING, NODES_INITIAL_DATA, @@ -40,6 +42,7 @@ import { } from '../utils' import { CUSTOM_NOTE_NODE } from '../note-node/constants' import type { IterationNodeType } from '../nodes/iteration/types' +import { CUSTOM_ITERATION_START_NODE } from '../nodes/iteration-start/constants' import type { VariableAssignerNodeType } from '../nodes/variable-assigner/types' 
import { useNodeIterationInteractions } from '../nodes/iteration/use-interactions' import { useWorkflowHistoryStore } from '../workflow-history-store' @@ -60,6 +63,7 @@ export const useNodesInteractions = () => { const { store: workflowHistoryStore } = useWorkflowHistoryStore() const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { + checkNestedParallelLimit, getAfterNodesInSameBranch, } = useWorkflow() const { getNodesReadOnly } = useNodesReadOnly() @@ -79,7 +83,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.data.isIterationStart || node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_ITERATION_START_NODE || node.type === CUSTOM_NOTE_NODE) return dragNodeStartPosition.current = { x: node.position.x, y: node.position.y } @@ -89,7 +93,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.data.isIterationStart) + if (node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -156,7 +160,7 @@ export const useNodesInteractions = () => { if (getNodesReadOnly()) return - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -207,13 +211,30 @@ export const useNodesInteractions = () => { }) }) setEdges(newEdges) + const connectedEdges = getConnectedEdges([node], edges).filter(edge => edge.target === node.id) + + const targetNodes: Node[] = [] + for (let i = 0; i < connectedEdges.length; i++) { + const sourceConnectedEdges = getConnectedEdges([{ id: connectedEdges[i].source } as Node], edges).filter(edge => edge.source === connectedEdges[i].source && edge.sourceHandle === connectedEdges[i].sourceHandle) + targetNodes.push(...sourceConnectedEdges.map(edge => nodes.find(n => n.id === edge.target)!)) + } + const uniqTargetNodes = unionBy(targetNodes, 'id') + if (uniqTargetNodes.length > 1) { + const newNodes = produce(nodes, (draft) => { + draft.forEach((n) => { + if (uniqTargetNodes.some(targetNode => n.id === targetNode.id)) + n.data._inParallelHovering = true + }) + }) + setNodes(newNodes) + } }, [store, workflowStore, getNodesReadOnly]) const handleNodeLeave = useCallback((_, node) => { if (getNodesReadOnly()) return - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return const { @@ -229,6 +250,7 @@ export const useNodesInteractions = () => { const newNodes = produce(getNodes(), (draft) => { draft.forEach((node) => { node.data._isEntering = false + node.data._inParallelHovering = false }) }) setNodes(newNodes) @@ -287,6 +309,8 @@ export const useNodesInteractions = () => { }, [store, handleSyncWorkflowDraft]) const handleNodeClick = useCallback((_, node) => { + if (node.type === CUSTOM_ITERATION_START_NODE) + return handleNodeSelect(node.id) }, [handleNodeSelect]) @@ -314,25 +338,15 @@ export const useNodesInteractions = () => { if (targetNode?.parentId !== sourceNode?.parentId) return - if (targetNode?.data.isIterationStart) - return - if (sourceNode?.type === CUSTOM_NOTE_NODE || targetNode?.type === CUSTOM_NOTE_NODE) return - const needDeleteEdges = edges.filter((edge) => { - if ( - (edge.source === source && edge.sourceHandle === sourceHandle) - || (edge.target === target && edge.targetHandle === targetHandle && targetNode?.data.type !== BlockEnum.VariableAssigner && targetNode?.data.type !== BlockEnum.VariableAggregator) - ) - return true + if (edges.find(edge => edge.source === source && edge.sourceHandle === sourceHandle && edge.target === 
target && edge.targetHandle === targetHandle)) + return - return false - }) - const needDeleteEdgesIds = needDeleteEdges.map(edge => edge.id) const newEdge = { id: `${source}-${sourceHandle}-${target}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: source!, target: target!, sourceHandle, @@ -347,7 +361,6 @@ export const useNodesInteractions = () => { } const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap( [ - ...needDeleteEdges.map(edge => ({ type: 'remove', edge })), { type: 'add', edge: newEdge }, ], nodes, @@ -362,19 +375,26 @@ export const useNodesInteractions = () => { } }) }) - setNodes(newNodes) const newEdges = produce(edges, (draft) => { - const filtered = draft.filter(edge => !needDeleteEdgesIds.includes(edge.id)) - - filtered.push(newEdge) - - return filtered + draft.push(newEdge) }) - setEdges(newEdges) - handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeConnect) - }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) + if (checkNestedParallelLimit(newNodes, newEdges, targetNode?.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + } + else { + const { + setConnectingNodePayload, + setEnteringNodePayload, + } = workflowStore.getState() + setConnectingNodePayload(undefined) + setEnteringNodePayload(undefined) + } + }, [getNodesReadOnly, store, workflowStore, handleSyncWorkflowDraft, saveStateToHistory, checkNestedParallelLimit]) const handleNodeConnectStart = useCallback((_, { nodeId, handleType, handleId }) => { if (getNodesReadOnly()) @@ -393,14 +413,12 @@ export const useNodesInteractions = () => { return } - if (!node.data.isIterationStart) { - setConnectingNodePayload({ - nodeId, - nodeType: node.data.type, - handleType, - handleId, - }) - } + setConnectingNodePayload({ + nodeId, + nodeType: node.data.type, + handleType, + handleId, + }) } }, [store, workflowStore, getNodesReadOnly]) @@ -510,6 +528,12 @@ export const useNodesInteractions = () => { return handleNodeDelete(nodeId) } else { + if (iterationChildren.length === 1) { + handleNodeDelete(iterationChildren[0].id) + handleNodeDelete(nodeId) + + return + } const { setShowConfirm, showConfirm } = workflowStore.getState() if (!showConfirm) { @@ -541,14 +565,8 @@ export const useNodesInteractions = () => { } } - if (node.id === currentNode.parentId) { + if (node.id === currentNode.parentId) node.data._children = node.data._children?.filter(child => child !== nodeId) - - if (currentNode.id === (node as Node).data.start_node_id) { - (node as Node).data.start_node_id = ''; - (node as Node).data.startNodeType = undefined - } - } }) draft.splice(currentNodeIndex, 1) }) @@ -559,7 +577,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - if (currentNode.type === 'custom-note') + if (currentNode.type === CUSTOM_NOTE_NODE) saveStateToHistory(WorkflowHistoryEvent.NoteDelete) else @@ -591,7 +609,10 @@ export const useNodesInteractions = () => { } = store.getState() const nodes = getNodes() const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) - const newNode = generateNewNode({ + const { + newNode, + newIterationStartNode, + } = generateNewNode({ data: { ...NODES_INITIAL_DATA[nodeType], title: nodesWithSameType.length > 0 ? 
`${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), @@ -627,7 +648,7 @@ export const useNodesInteractions = () => { const newEdge: Edge = { id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: prevNodeId, sourceHandle: prevNodeSourceHandle, target: newNode.id, @@ -662,8 +683,10 @@ export const useNodesInteractions = () => { node.data._children?.push(newNode.id) }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) - setNodes(newNodes) + if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { const { setShowAssignVariablePopup } = workflowStore.getState() @@ -687,7 +710,14 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, prevNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } } if (!prevNodeId && nextNodeId) { const nextNodeIndex = nodes.findIndex(node => node.id === nextNodeId) @@ -706,15 +736,13 @@ export const useNodesInteractions = () => { newNode.data.iteration_id = nextNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX } - if (nextNode.data.isIterationStart) - newNode.data.isIterationStart = true let newEdge if ((nodeType !== BlockEnum.IfElse) && (nodeType !== BlockEnum.QuestionClassifier)) { newEdge = { id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: newNode.id, sourceHandle, target: nextNodeId, @@ -763,13 +791,11 @@ export const useNodesInteractions = () => { node.data.start_node_id = newNode.id node.data.startNodeType = newNode.data.type } - - if (node.id === nextNodeId && node.data.isIterationStart) - node.data.isIterationStart = false }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) - setNodes(newNodes) if (newEdge) { const newEdges = produce(edges, (draft) => { draft.forEach((item) => { @@ -780,7 +806,21 @@ export const useNodesInteractions = () => { }) draft.push(newEdge) }) - setEdges(newEdges) + + if (checkNestedParallelLimit(newNodes, newEdges, nextNode.parentId)) { + setNodes(newNodes) + setEdges(newEdges) + } + else { + return false + } + } + else { + if (checkNestedParallelLimit(newNodes, edges)) + setNodes(newNodes) + + else + return false } } if (prevNodeId && nextNodeId) { @@ -804,7 +844,7 @@ export const useNodesInteractions = () => { const currentEdgeIndex = edges.findIndex(edge => edge.source === prevNodeId && edge.target === nextNodeId) const newPrevEdge = { id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: prevNodeId, sourceHandle: prevNodeSourceHandle, target: newNode.id, @@ -822,7 +862,7 @@ export const useNodesInteractions = () => { if (nodeType !== BlockEnum.IfElse && nodeType !== BlockEnum.QuestionClassifier) { newNextEdge = { id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, - type: 'custom', + type: CUSTOM_EDGE, source: newNode.id, sourceHandle, target: nextNodeId, @@ -865,6 +905,8 @@ export const useNodesInteractions = () => { node.data._children?.push(newNode.id) }) draft.push(newNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) setNodes(newNodes) if (newNode.data.type === BlockEnum.VariableAssigner || newNode.data.type === BlockEnum.VariableAggregator) { @@ -898,7 
+940,7 @@ export const useNodesInteractions = () => { } handleSyncWorkflowDraft() saveStateToHistory(WorkflowHistoryEvent.NodeAdd) - }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch]) + }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch, checkNestedParallelLimit]) const handleNodeChange = useCallback(( currentNodeId: string, @@ -919,7 +961,10 @@ export const useNodesInteractions = () => { const currentNode = nodes.find(node => node.id === currentNodeId)! const connectedEdges = getConnectedEdges([currentNode], edges) const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) - const newCurrentNode = generateNewNode({ + const { + newNode: newCurrentNode, + newIterationStartNode, + } = generateNewNode({ data: { ...NODES_INITIAL_DATA[nodeType], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), @@ -929,7 +974,6 @@ export const useNodesInteractions = () => { selected: currentNode.data.selected, isInIteration: currentNode.data.isInIteration, iteration_id: currentNode.data.iteration_id, - isIterationStart: currentNode.data.isIterationStart, }, position: { x: currentNode.position.x, @@ -955,18 +999,12 @@ export const useNodesInteractions = () => { ...nodesConnectedSourceOrTargetHandleIdsMap[node.id], } } - if (node.id === currentNode.parentId && currentNode.data.isIterationStart) { - node.data._children = [ - newCurrentNode.id, - ...(node.data._children || []), - ].filter(child => child !== currentNodeId) - node.data.start_node_id = newCurrentNode.id - node.data.startNodeType = newCurrentNode.data.type - } }) const index = draft.findIndex(node => node.id === currentNodeId) draft.splice(index, 1, newCurrentNode) + if (newIterationStartNode) + draft.push(newIterationStartNode) }) setNodes(newNodes) const newEdges = produce(edges, (draft) => { @@ -1011,7 +1049,7 @@ export const useNodesInteractions = () => { }, [store]) const handleNodeContextMenu = useCallback((e: MouseEvent, node: Node) => { - if (node.type === CUSTOM_NOTE_NODE) + if (node.type === CUSTOM_NOTE_NODE || node.type === CUSTOM_ITERATION_START_NODE) return e.preventDefault() @@ -1041,7 +1079,7 @@ export const useNodesInteractions = () => { if (nodeId) { // If nodeId is provided, copy that specific node - const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start) + const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start && node.type !== CUSTOM_ITERATION_START_NODE) if (nodeToCopy) setClipboardElements([nodeToCopy]) } @@ -1087,7 +1125,10 @@ export const useNodesInteractions = () => { clipboardElements.forEach((nodeToPaste, index) => { const nodeType = nodeToPaste.data.type - const newNode = generateNewNode({ + const { + newNode, + newIterationStartNode, + } = generateNewNode({ type: nodeToPaste.type, data: { ...NODES_INITIAL_DATA[nodeType], @@ -1106,24 +1147,17 @@ export const useNodesInteractions = () => { zIndex: nodeToPaste.zIndex, }) newNode.id = newNode.id + index - - // If only the iteration start node is copied, remove the isIterationStart flag // This new node is movable and can be placed anywhere - if (clipboardElements.length === 1 && newNode.data.isIterationStart) - newNode.data.isIterationStart = false - let newChildren: Node[] = [] if (nodeToPaste.data.type === BlockEnum.Iteration) { - newNode.data._children = 
[]; - (newNode.data as IterationNodeType).start_node_id = '' + newIterationStartNode!.parentId = newNode.id; + (newNode.data as IterationNodeType).start_node_id = newIterationStartNode!.id newChildren = handleNodeIterationChildrenCopy(nodeToPaste.id, newNode.id) - newChildren.forEach((child) => { newNode.data._children?.push(child.id) - if (child.data.isIterationStart) - (newNode.data as IterationNodeType).start_node_id = child.id }) + newChildren.push(newIterationStartNode!) } nodesToPaste.push(newNode) @@ -1230,6 +1264,42 @@ export const useNodesInteractions = () => { saveStateToHistory(WorkflowHistoryEvent.NodeResize) }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) + const handleNodeDisconnect = useCallback((nodeId: string) => { + if (getNodesReadOnly()) + return + + const { + getNodes, + setNodes, + edges, + setEdges, + } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId)! + const connectedEdges = getConnectedEdges([currentNode], edges) + const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap( + connectedEdges.map(edge => ({ type: 'remove', edge })), + nodes, + ) + const newNodes = produce(nodes, (draft: Node[]) => { + draft.forEach((node) => { + if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) { + node.data = { + ...node.data, + ...nodesConnectedSourceOrTargetHandleIdsMap[node.id], + } + } + }) + }) + setNodes(newNodes) + const newEdges = produce(edges, (draft) => { + return draft.filter(edge => !connectedEdges.find(connectedEdge => connectedEdge.id === edge.id)) + }) + setEdges(newEdges) + handleSyncWorkflowDraft() + saveStateToHistory(WorkflowHistoryEvent.EdgeDelete) + }, [store, getNodesReadOnly, handleSyncWorkflowDraft, saveStateToHistory]) + const handleHistoryBack = useCallback(() => { if (getNodesReadOnly() || getWorkflowReadOnly()) return @@ -1282,6 +1352,7 @@ export const useNodesInteractions = () => { handleNodesDuplicate, handleNodesDelete, handleNodeResize, + handleNodeDisconnect, handleHistoryBack, handleHistoryForward, } diff --git a/web/app/components/workflow/hooks/use-shortcuts.ts b/web/app/components/workflow/hooks/use-shortcuts.ts index 666c3a45ba..439b521a30 100644 --- a/web/app/components/workflow/hooks/use-shortcuts.ts +++ b/web/app/components/workflow/hooks/use-shortcuts.ts @@ -70,7 +70,8 @@ export const useShortcuts = (): void => { }) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => { - if (shouldHandleShortcut(e)) { + const { showDebugAndPreviewPanel, showInputsPanel } = workflowStore.getState() + if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel && !showInputsPanel) { e.preventDefault() handleNodesCopy() } diff --git a/web/app/components/workflow/hooks/use-workflow-interactions.ts b/web/app/components/workflow/hooks/use-workflow-interactions.ts index 47b8a30a5a..b39a3d8014 100644 --- a/web/app/components/workflow/hooks/use-workflow-interactions.ts +++ b/web/app/components/workflow/hooks/use-workflow-interactions.ts @@ -10,7 +10,7 @@ import { CUSTOM_NODE, DSL_EXPORT_CHECK, WORKFLOW_DATA_UPDATE, } from '../constants' -import type { Node, WorkflowDataUpdator } from '../types' +import type { Node, WorkflowDataUpdater } from '../types' import { ControlMode } from '../types' import { getLayoutByDagre, @@ -208,7 +208,7 @@ export const useWorkflowUpdate = () => { const workflowStore = useWorkflowStore() const { eventEmitter } = useEventEmitterContextContext() - const handleUpdateWorkflowCanvas = useCallback((payload: 
WorkflowDataUpdator) => { + const handleUpdateWorkflowCanvas = useCallback((payload: WorkflowDataUpdater) => { const { nodes, edges, @@ -236,7 +236,7 @@ export const useWorkflowUpdate = () => { } = workflowStore.getState() setIsSyncingWorkflowDraft(true) fetchWorkflowDraft(`/apps/${appId}/workflows/draft`).then((response) => { - handleUpdateWorkflowCanvas(response.graph as WorkflowDataUpdator) + handleUpdateWorkflowCanvas(response.graph as WorkflowDataUpdater) setSyncWorkflowDraftHash(response.hash) setEnvSecrets((response.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { acc[env.id] = env.value diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 96f6557fe0..e1da503f38 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -1,5 +1,6 @@ import { useCallback } from 'react' import { + getIncomers, useReactFlow, useStoreApi, } from 'reactflow' @@ -8,6 +9,7 @@ import { v4 as uuidV4 } from 'uuid' import { usePathname } from 'next/navigation' import { useWorkflowStore } from '../store' import { useNodesSyncDraft } from '../hooks' +import type { Node } from '../types' import { NodeRunningStatus, WorkflowRunningStatus, @@ -140,9 +142,6 @@ export const useWorkflowRun = () => { resultText: '', }) - let isInIteration = false - let iterationLength = 0 - let ttsUrl = '' let ttsIsPublic = false if (params.token) { @@ -249,19 +248,20 @@ export const useWorkflowRun = () => { setEdges, transform, } = store.getState() - if (isInIteration) { + const nodes = getNodes() + const node = nodes.find(node => node.id === data.node_id) + if (node?.parentId) { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - currIteration.push({ + const iterations = tracing.find(trace => trace.node_id === node?.parentId) + const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] + currIteration?.push({ ...data, status: NodeRunningStatus.Running, } as any) })) } else { - const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -288,11 +288,12 @@ export const useWorkflowRun = () => { draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running }) setNodes(newNodes) + const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id) const newEdges = produce(edges, (draft) => { - const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId) - - if (edge) - edge.data = { ...edge.data, _runned: true } as any + draft.forEach((edge) => { + if (edge.target === data.node_id && incomeNodesId.includes(edge.source)) + edge.data = { ...edge.data, _runned: true } as any + }) }) setEdges(newEdges) } @@ -309,25 +310,46 @@ export const useWorkflowRun = () => { getNodes, setNodes, } = store.getState() - if (isInIteration) { + const nodes = getNodes() + const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId + if (nodeParentId) { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! 
- const iterations = tracing[tracing.length - 1] - const currIteration = iterations.details![iterations.details!.length - 1] - const nodeInfo = currIteration[currIteration.length - 1] + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node - currIteration[currIteration.length - 1] = { - ...nodeInfo, - ...data, - status: NodeRunningStatus.Succeeded, - } as any + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] + + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) + if (data.status === NodeRunningStatus.Succeeded) { + if (nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + } })) } else { - const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id) - + const currentIndex = draft.tracing!.findIndex((trace) => { + if (!trace.execution_metadata?.parallel_id) + return trace.node_id === data.node_id + return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id + }) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { ...(draft.tracing[currentIndex].extras @@ -337,16 +359,14 @@ export const useWorkflowRun = () => { } as any } })) - const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - currentNode.data._runningStatus = data.status as any }) setNodes(newNodes) - prevNodeId = data.node_id } + if (onNodeFinished) onNodeFinished(params) }, @@ -371,8 +391,6 @@ export const useWorkflowRun = () => { details: [], } as any) })) - isInIteration = true - iterationLength = data.metadata.iterator_length const { setViewport, @@ -418,13 +436,13 @@ export const useWorkflowRun = () => { } = store.getState() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const iteration = draft.tracing![draft.tracing!.length - 1] - if (iteration.details!.length >= iterationLength) - return - - iteration.details!.push([]) + const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id) + if (iteration) { + if (iteration.details!.length >= iteration.metadata.iterator_length!) + return + } + iteration?.details!.push([]) })) - const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! @@ -450,13 +468,14 @@ export const useWorkflowRun = () => { const nodes = getNodes() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const tracing = draft.tracing! - tracing[tracing.length - 1] = { - ...tracing[tracing.length - 1], - ...data, - status: NodeRunningStatus.Succeeded, - } as any + const currIterationNode = tracing.find(trace => trace.node_id === data.node_id) + if (currIterationNode) { + Object.assign(currIterationNode, { + ...data, + status: NodeRunningStatus.Succeeded, + }) + } })) - isInIteration = false const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! 
@@ -470,6 +489,12 @@ export const useWorkflowRun = () => { if (onIterationFinish) onIterationFinish(params) }, + onParallelBranchStarted: (params) => { + // console.log(params, 'parallel start') + }, + onParallelBranchFinished: (params) => { + // console.log(params, 'finished') + }, onTextChunk: (params) => { const { data: { text } } = params const { diff --git a/web/app/components/workflow/hooks/use-workflow-template.ts b/web/app/components/workflow/hooks/use-workflow-template.ts index 3af3f733f1..e36f0b61f9 100644 --- a/web/app/components/workflow/hooks/use-workflow-template.ts +++ b/web/app/components/workflow/hooks/use-workflow-template.ts @@ -10,13 +10,13 @@ export const useWorkflowTemplate = () => { const isChatMode = useIsChatMode() const nodesInitialData = useNodesInitialData() - const startNode = generateNewNode({ + const { newNode: startNode } = generateNewNode({ data: nodesInitialData.start, position: START_INITIAL_POSITION, }) if (isChatMode) { - const llmNode = generateNewNode({ + const { newNode: llmNode } = generateNewNode({ id: 'llm', data: { ...nodesInitialData.llm, @@ -31,7 +31,7 @@ export const useWorkflowTemplate = () => { }, } as any) - const answerNode = generateNewNode({ + const { newNode: answerNode } = generateNewNode({ id: 'answer', data: { ...nodesInitialData.answer, diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index cfff4220fa..b201b28b88 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -6,6 +6,7 @@ import { } from 'react' import dayjs from 'dayjs' import { uniqBy } from 'lodash-es' +import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { getIncomers, @@ -29,6 +30,11 @@ import { useWorkflowStore, } from '../store' import { + getParallelInfo, +} from '../utils' +import { + PARALLEL_DEPTH_LIMIT, + PARALLEL_LIMIT, SUPPORT_OUTPUT_VARS_NODE, } from '../constants' import { CUSTOM_NOTE_NODE } from '../note-node/constants' @@ -50,6 +56,7 @@ import { } from '@/service/tools' import I18n from '@/context/i18n' import { CollectionType } from '@/app/components/tools/types' +import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants' export const useIsChatMode = () => { const appDetail = useAppStore(s => s.appDetail) @@ -58,6 +65,7 @@ export const useIsChatMode = () => { } export const useWorkflow = () => { + const { t } = useTranslation() const { locale } = useContext(I18n) const store = useStoreApi() const workflowStore = useWorkflowStore() @@ -77,7 +85,7 @@ export const useWorkflow = () => { const currentNode = nodes.find(node => node.id === nodeId) if (currentNode?.parentId) - startNode = nodes.find(node => node.parentId === currentNode.parentId && node.data.isIterationStart) + startNode = nodes.find(node => node.parentId === currentNode.parentId && node.type === CUSTOM_ITERATION_START_NODE) if (!startNode) return [] @@ -275,7 +283,43 @@ export const useWorkflow = () => { return isUsed }, [isVarUsedInNodes]) - const isValidConnection = useCallback(({ source, target }: Connection) => { + const checkParallelLimit = useCallback((nodeId: string, nodeHandle = 'source') => { + const { + edges, + } = store.getState() + const connectedEdges = edges.filter(edge => edge.source === nodeId && edge.sourceHandle === nodeHandle) + if (connectedEdges.length > PARALLEL_LIMIT - 1) { + const { setShowTips } = workflowStore.getState() + 
setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT })) + return false + } + + return true + }, [store, workflowStore, t]) + + const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { + const { + parallelList, + hasAbnormalEdges, + } = getParallelInfo(nodes, edges, parentNodeId) + + if (hasAbnormalEdges) + return false + + for (let i = 0; i < parallelList.length; i++) { + const parallel = parallelList[i] + + if (parallel.depth > PARALLEL_DEPTH_LIMIT) { + const { setShowTips } = workflowStore.getState() + setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT })) + return false + } + } + + return true + }, [t, workflowStore]) + + const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { edges, getNodes, @@ -284,12 +328,15 @@ export const useWorkflow = () => { const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! - if (targetNode.data.isIterationStart) + if (!checkParallelLimit(source!, sourceHandle || 'source')) return false if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE) return false + if (sourceNode.parentId !== targetNode.parentId) + return false + if (sourceNode && targetNode) { const sourceNodeAvailableNextNodes = nodesExtraData[sourceNode.data.type].availableNextNodes const targetNodeAvailablePrevNodes = [...nodesExtraData[targetNode.data.type].availablePrevNodes, BlockEnum.Start] @@ -316,7 +363,7 @@ export const useWorkflow = () => { } return !hasCycle(targetNode) - }, [store, nodesExtraData]) + }, [store, nodesExtraData, checkParallelLimit]) const formatTimeFromNow = useCallback((time: number) => { return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow() @@ -339,6 +386,8 @@ export const useWorkflow = () => { isVarUsedInNodes, removeUsedVarInNodes, isNodeVarsUsedInNodes, + checkParallelLimit, + checkNestedParallelLimit, isValidConnection, formatTimeFromNow, getNode, diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index d96faa8677..cdccd60a3b 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -55,6 +55,8 @@ import Header from './header' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' +import CustomIterationStartNode from './nodes/iteration-start' +import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants' import Operator from './operator' import CustomEdge from './custom-edge' import CustomConnectionLine from './custom-connection-line' @@ -67,6 +69,7 @@ import NodeContextmenu from './node-contextmenu' import SyncingDataModal from './syncing-data-modal' import UpdateDSLModal from './update-dsl-modal' import DSLExportConfirmModal from './dsl-export-confirm-modal' +import LimitTips from './limit-tips' import { useStore, useWorkflowStore, @@ -92,6 +95,7 @@ import Confirm from '@/app/components/base/confirm' const nodeTypes = { [CUSTOM_NODE]: CustomNode, [CUSTOM_NOTE_NODE]: CustomNoteNode, + [CUSTOM_ITERATION_START_NODE]: CustomIterationStartNode, } const edgeTypes = { [CUSTOM_NODE]: CustomEdge, @@ -317,6 +321,7 @@ const Workflow: FC = memo(({ /> ) } + { + const showTips = useStore(s => s.showTips) + const setShowTips = useStore(s => s.setShowTips) + + if (!showTips) + return null + + return ( +
+
+
+ +
+
+ {showTips} +
+ setShowTips('')} + > + + +
+ ) +} + +export default LimitTips diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx index cde437c4c5..603368a4b3 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx @@ -104,7 +104,7 @@ const FormItem: FC = ({ type="text" value={value || ''} onChange={e => onChange(e.target.value)} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} autoFocus={autoFocus} /> ) @@ -117,7 +117,7 @@ const FormItem: FC = ({ type="number" value={value || ''} onChange={e => onChange(e.target.value)} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} autoFocus={autoFocus} /> ) @@ -129,7 +129,7 @@ const FormItem: FC = ({ className="w-full px-3 py-1 text-sm leading-[18px] text-gray-900 border-0 rounded-lg grow h-[120px] bg-gray-50 focus:outline-none focus:ring-1 focus:ring-inset focus:ring-gray-200" value={value || ''} onChange={e => onChange(e.target.value)} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} autoFocus={autoFocus} /> ) @@ -207,7 +207,7 @@ const FormItem: FC = ({ key={index} isInNode value={item} - title={{t('appDebug.variableConig.content')} {index + 1} } + title={{t('appDebug.variableConfig.content')} {index + 1} } onChange={handleArrayItemChange(index)} headerRight={ (value as any).length > 1 diff --git a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx index 0ab0c8e39e..75694983cd 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/add.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/add.tsx @@ -1,6 +1,7 @@ import { memo, useCallback, + useState, } from 'react' import { useTranslation } from 'react-i18next' import { @@ -10,6 +11,7 @@ import { useAvailableBlocks, useNodesInteractions, useNodesReadOnly, + useWorkflow, } from '@/app/components/workflow/hooks' import BlockSelector from '@/app/components/workflow/block-selector' import type { @@ -21,18 +23,20 @@ type AddProps = { nodeId: string nodeData: CommonNodeType sourceHandle: string - branchName?: string + isParallel?: boolean } const Add = ({ nodeId, nodeData, sourceHandle, - branchName, + isParallel, }: AddProps) => { const { t } = useTranslation() + const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() const { nodesReadOnly } = useNodesReadOnly() const { availableNextBlocks } = useAvailableBlocks(nodeData.type, nodeData.isInIteration) + const { checkParallelLimit } = useWorkflow() const handleSelect = useCallback((type, toolDefaultValue) => { handleNodeAdd( @@ -47,6 +51,13 @@ const Add = ({ ) }, [nodeId, sourceHandle, handleNodeAdd]) + const handleOpenChange = useCallback((newOpen: boolean) => { + if (newOpen && !checkParallelLimit(nodeId, sourceHandle)) + return + + setOpen(newOpen) + }, [checkParallelLimit, nodeId, sourceHandle]) + const renderTrigger = useCallback((open: boolean) => { return (
- { - branchName && ( -
-
{branchName.toLocaleUpperCase()}
-
- ) - }
- {t('workflow.panel.selectNextStep')} +
+ { + isParallel + ? t('workflow.common.addParallelNode') + : t('workflow.panel.selectNextStep') + } +
) - }, [branchName, t, nodesReadOnly]) + }, [t, nodesReadOnly, isParallel]) return ( { + return ( +
+ { + branchName && ( +
+ {branchName} +
+ ) + } + { + nextNodes.map(nextNode => ( + + )) + } + +
+ ) +} + +export default Container diff --git a/web/app/components/workflow/nodes/_base/components/next-step/index.tsx b/web/app/components/workflow/nodes/_base/components/next-step/index.tsx index 261eb3fac7..d980eb284e 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/index.tsx @@ -1,4 +1,5 @@ -import { memo } from 'react' +import { memo, useMemo } from 'react' +import { useTranslation } from 'react-i18next' import { getConnectedEdges, getOutgoers, @@ -8,13 +9,11 @@ import { import { useToolIcon } from '../../../../hooks' import BlockIcon from '../../../../block-icon' import type { - Branch, Node, } from '../../../../types' import { BlockEnum } from '../../../../types' -import Add from './add' -import Item from './item' import Line from './line' +import Container from './container' type NextStepProps = { selectedNode: Node @@ -22,15 +21,33 @@ type NextStepProps = { const NextStep = ({ selectedNode, }: NextStepProps) => { + const { t } = useTranslation() const data = selectedNode.data const toolIcon = useToolIcon(data) const store = useStoreApi() - const branches = data._targetBranches || [] + const branches = useMemo(() => { + return data._targetBranches || [] + }, [data]) const nodeWithBranches = data.type === BlockEnum.IfElse || data.type === BlockEnum.QuestionClassifier const edges = useEdges() const outgoers = getOutgoers(selectedNode as Node, store.getState().getNodes(), edges) const connectedEdges = getConnectedEdges([selectedNode] as Node[], edges).filter(edge => edge.source === selectedNode!.id) + const branchesOutgoers = useMemo(() => { + if (!branches?.length) + return [] + + return branches.map((branch) => { + const connected = connectedEdges.filter(edge => edge.sourceHandle === branch.id) + const nextNodes = connected.map(edge => outgoers.find(outgoer => outgoer.id === edge.target)!) + + return { + branch, + nextNodes, + } + }) + }, [branches, connectedEdges, outgoers]) + return (
@@ -39,59 +56,32 @@ const NextStep = ({ toolIcon={toolIcon} />
- -
+ item.nextNodes.length + 1) : [1]} + /> +
{ - !nodeWithBranches && !!outgoers.length && ( - - ) - } - { - !nodeWithBranches && !outgoers.length && ( - ) } { - !!branches?.length && nodeWithBranches && ( - branches.map((branch: Branch) => { - const connected = connectedEdges.find(edge => edge.sourceHandle === branch.id) - const target = outgoers.find(outgoer => outgoer.id === connected?.target) - + nodeWithBranches && ( + branchesOutgoers.map((item, index) => { return ( -
- { - connected && ( - - ) - } - { - !connected && ( - - ) - } -
+ ) }) ) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/item.tsx b/web/app/components/workflow/nodes/_base/components/next-step/item.tsx index b806de5684..db3748abd9 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/item.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/item.tsx @@ -1,94 +1,82 @@ import { memo, useCallback, + useState, } from 'react' import { useTranslation } from 'react-i18next' -import { intersection } from 'lodash-es' +import Operator from './operator' import type { CommonNodeType, - OnSelectBlock, } from '@/app/components/workflow/types' import BlockIcon from '@/app/components/workflow/block-icon' -import BlockSelector from '@/app/components/workflow/block-selector' import { - useAvailableBlocks, useNodesInteractions, useNodesReadOnly, useToolIcon, } from '@/app/components/workflow/hooks' import Button from '@/app/components/base/button' +import cn from '@/utils/classnames' type ItemProps = { nodeId: string sourceHandle: string - branchName?: string data: CommonNodeType } const Item = ({ nodeId, sourceHandle, - branchName, data, }: ItemProps) => { const { t } = useTranslation() - const { handleNodeChange } = useNodesInteractions() + const [open, setOpen] = useState(false) const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeSelect } = useNodesInteractions() const toolIcon = useToolIcon(data) - const { - availablePrevBlocks, - availableNextBlocks, - } = useAvailableBlocks(data.type, data.isInIteration) - const handleSelect = useCallback((type, toolDefaultValue) => { - handleNodeChange(nodeId, type, sourceHandle, toolDefaultValue) - }, [nodeId, sourceHandle, handleNodeChange]) - const renderTrigger = useCallback((open: boolean) => { - return ( - - ) - }, [t]) + const handleOpenChange = useCallback((v: boolean) => { + setOpen(v) + }, []) return (
- { - branchName && ( -
-
{branchName.toLocaleUpperCase()}
-
- ) - } -
{data.title}
+
+ {data.title} +
{ !nodesReadOnly && ( - item !== data.type)} - /> + <> + +
+ +
+ ) }
diff --git a/web/app/components/workflow/nodes/_base/components/next-step/line.tsx b/web/app/components/workflow/nodes/_base/components/next-step/line.tsx index ccf4e8d730..3a4430cb5d 100644 --- a/web/app/components/workflow/nodes/_base/components/next-step/line.tsx +++ b/web/app/components/workflow/nodes/_base/components/next-step/line.tsx @@ -1,56 +1,70 @@ import { memo } from 'react' type LineProps = { - linesNumber: number + list: number[] } const Line = ({ - linesNumber, + list, }: LineProps) => { - const svgHeight = linesNumber * 36 + (linesNumber - 1) * 12 + const listHeight = list.map((item) => { + return item * 36 + (item - 1) * 2 + 12 + 6 + }) + const processedList = listHeight.map((item, index) => { + if (index === 0) + return item + + return listHeight.slice(0, index).reduce((acc, cur) => acc + cur, 0) + item + }) + const processedListLength = processedList.length + const svgHeight = processedList[processedListLength - 1] + (processedListLength - 1) * 8 return ( { - Array(linesNumber).fill(0).map((_, index) => ( - - { - index === 0 && ( - <> + processedList.map((item, index) => { + const prevItem = index > 0 ? processedList[index - 1] : 0 + const space = prevItem + index * 8 + 16 + return ( + + { + index === 0 && ( + <> + + + + ) + } + { + index > 0 && ( - - - ) - } - { - index > 0 && ( - - ) - } - - - )) + ) + } + + + ) + }) } ) diff --git a/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx new file mode 100644 index 0000000000..ad6c7abd0c --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/next-step/operator.tsx @@ -0,0 +1,129 @@ +import { + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiMoreFill } from '@remixicon/react' +import { intersection } from 'lodash-es' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import Button from '@/app/components/base/button' +import BlockSelector from '@/app/components/workflow/block-selector' +import { + useAvailableBlocks, + useNodesInteractions, +} from '@/app/components/workflow/hooks' +import type { + CommonNodeType, + OnSelectBlock, +} from '@/app/components/workflow/types' + +type ChangeItemProps = { + data: CommonNodeType + nodeId: string + sourceHandle: string +} +const ChangeItem = ({ + data, + nodeId, + sourceHandle, +}: ChangeItemProps) => { + const { t } = useTranslation() + + const { handleNodeChange } = useNodesInteractions() + const { + availablePrevBlocks, + availableNextBlocks, + } = useAvailableBlocks(data.type, data.isInIteration) + + const handleSelect = useCallback((type, toolDefaultValue) => { + handleNodeChange(nodeId, type, sourceHandle, toolDefaultValue) + }, [nodeId, sourceHandle, handleNodeChange]) + + const renderTrigger = useCallback(() => { + return ( +
+ {t('workflow.panel.change')} +
+ ) + }, [t]) + + return ( + item !== data.type)} + /> + ) +} + +type OperatorProps = { + open: boolean + onOpenChange: (v: boolean) => void + data: CommonNodeType + nodeId: string + sourceHandle: string +} +const Operator = ({ + open, + onOpenChange, + data, + nodeId, + sourceHandle, +}: OperatorProps) => { + const { t } = useTranslation() + const { + handleNodeDelete, + handleNodeDisconnect, + } = useNodesInteractions() + + return ( + + onOpenChange(!open)}> + + + +
+
+ +
handleNodeDisconnect(nodeId)} + > + {t('workflow.common.disconnect')} +
+
+
+
handleNodeDelete(nodeId)} + > + {t('common.operation.delete')} +
+
+
+
+
+ ) +} + +export default Operator diff --git a/web/app/components/workflow/nodes/_base/components/node-handle.tsx b/web/app/components/workflow/nodes/_base/components/node-handle.tsx index 56870f79d6..9a662366d4 100644 --- a/web/app/components/workflow/nodes/_base/components/node-handle.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-handle.tsx @@ -9,16 +9,22 @@ import { Handle, Position, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { BlockEnum } from '../../../types' import type { Node } from '../../../types' import BlockSelector from '../../../block-selector' import type { ToolDefaultValue } from '../../../block-selector/types' import { useAvailableBlocks, + useIsChatMode, useNodesInteractions, useNodesReadOnly, + useWorkflow, } from '../../../hooks' -import { useStore } from '../../../store' +import { + useStore, +} from '../../../store' +import Tooltip from '@/app/components/base/tooltip' type NodeHandleProps = { handleId: string @@ -38,9 +44,7 @@ export const NodeTargetHandle = memo(({ const { getNodesReadOnly } = useNodesReadOnly() const connected = data._connectedTargetHandleIds?.includes(handleId) const { availablePrevBlocks } = useAvailableBlocks(data.type, data.isInIteration) - const isConnectable = !!availablePrevBlocks.length && ( - !data.isIterationStart - ) + const isConnectable = !!availablePrevBlocks.length const handleOpenChange = useCallback((v: boolean) => { setOpen(v) @@ -112,12 +116,15 @@ export const NodeSourceHandle = memo(({ handleClassName, nodeSelectorClassName, }: NodeHandleProps) => { + const { t } = useTranslation() const notInitialWorkflow = useStore(s => s.notInitialWorkflow) const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() const { getNodesReadOnly } = useNodesReadOnly() const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration) const isConnectable = !!availableNextBlocks.length + const isChatMode = useIsChatMode() + const { checkParallelLimit } = useWorkflow() const connected = data._connectedSourceHandleIds?.includes(handleId) const handleOpenChange = useCallback((v: boolean) => { @@ -125,9 +132,9 @@ export const NodeSourceHandle = memo(({ }, []) const handleHandleClick = useCallback((e: MouseEvent) => { e.stopPropagation() - if (!connected) + if (checkParallelLimit(id, handleId)) setOpen(v => !v) - }, [connected]) + }, [checkParallelLimit, id, handleId]) const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => { handleNodeAdd( { @@ -142,12 +149,25 @@ export const NodeSourceHandle = memo(({ }, [handleNodeAdd, id, handleId]) useEffect(() => { - if (notInitialWorkflow && data.type === BlockEnum.Start) + if (notInitialWorkflow && data.type === BlockEnum.Start && !isChatMode) setOpen(true) - }, [notInitialWorkflow, data.type]) + }, [notInitialWorkflow, data.type, isChatMode]) return ( - <> + +
+ {t('workflow.common.parallelTip.click.title')} + {t('workflow.common.parallelTip.click.desc')} +
+
+ {t('workflow.common.parallelTip.drag.title')} + {t('workflow.common.parallelTip.drag.desc')} +
+
+ )} + > { - !connected && isConnectable && !getNodesReadOnly() && ( + isConnectable && !getNodesReadOnly() && ( - + ) }) NodeSourceHandle.displayName = 'NodeSourceHandle' diff --git a/web/app/components/workflow/nodes/_base/components/node-resizer.tsx b/web/app/components/workflow/nodes/_base/components/node-resizer.tsx index 4c83bea8d6..a8e7a9aa11 100644 --- a/web/app/components/workflow/nodes/_base/components/node-resizer.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-resizer.tsx @@ -28,8 +28,8 @@ const NodeResizer = ({ nodeId, nodeData, icon = , - minWidth = 272, - minHeight = 176, + minWidth = 258, + minHeight = 152, maxWidth, }: NodeResizerProps) => { const { handleNodeResize } = useNodesInteractions() diff --git a/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts b/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts index 1a521cd58f..f2259a02cf 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts @@ -5,7 +5,7 @@ import { useState, } from 'react' -export type UseResizePanelPrarams = { +export type UseResizePanelParams = { direction?: 'horizontal' | 'vertical' | 'both' triggerDirection?: 'top' | 'right' | 'bottom' | 'left' | 'top-right' | 'top-left' | 'bottom-right' | 'bottom-left' minWidth?: number @@ -15,7 +15,7 @@ export type UseResizePanelPrarams = { onResized?: (width: number, height: number) => void onResize?: (width: number, height: number) => void } -export const useResizePanel = (params?: UseResizePanelPrarams) => { +export const useResizePanel = (params?: UseResizePanelParams) => { const { direction = 'both', triggerDirection = 'bottom-right', diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index 0b45c80888..bd5921c735 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -14,6 +14,7 @@ import { RiErrorWarningLine, RiLoader2Line, } from '@remixicon/react' +import { useTranslation } from 'react-i18next' import type { NodeProps } from '../../types' import { BlockEnum, @@ -43,6 +44,7 @@ const BaseNode: FC = ({ data, children, }) => { + const { t } = useTranslation() const nodeRef = useRef(null) const { nodesReadOnly } = useNodesReadOnly() const { handleNodeIterationChildSizeChange } = useNodeIterationInteractions() @@ -80,6 +82,7 @@ const BaseNode: FC = ({ className={cn( 'flex border-[2px] rounded-2xl', showSelectedBorder ? 'border-components-option-card-option-selected-border' : 'border-transparent', + !showSelectedBorder && data._inParallelHovering && 'border-workflow-block-border-highlight', )} ref={nodeRef} style={{ @@ -100,6 +103,13 @@ const BaseNode: FC = ({ data._isBundled && '!shadow-lg', )} > + { + data._inParallelHovering && ( +
+ {t('workflow.common.parallelRun')} +
+ ) + } { data._showAddVariablePopup && ( = ({ const { availableVars, availableNodes } = useAvailableVarList(nodeId, { onlyLeafNodeVar: false, filterVar: (varPayload: Var) => { - return [VarType.string, VarType.number, VarType.secret].includes(varPayload.type) + return [VarType.string, VarType.number, VarType.secret, VarType.arrayNumber, VarType.arrayString].includes(varPayload.type) }, }) diff --git a/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx b/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx index 93c2696b98..0f28db96c3 100644 --- a/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx +++ b/web/app/components/workflow/nodes/http/components/key-value/key-value-edit/item.tsx @@ -73,7 +73,7 @@ const KeyValueItem: FC = ({ handleChange('key')(e.target.value)} /> )}
diff --git a/web/app/components/workflow/nodes/iteration-start/constants.ts b/web/app/components/workflow/nodes/iteration-start/constants.ts new file mode 100644 index 0000000000..94e3ccbd90 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/constants.ts @@ -0,0 +1 @@ +export const CUSTOM_ITERATION_START_NODE = 'custom-iteration-start' diff --git a/web/app/components/workflow/nodes/iteration-start/default.ts b/web/app/components/workflow/nodes/iteration-start/default.ts new file mode 100644 index 0000000000..d98efa7ba2 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/default.ts @@ -0,0 +1,21 @@ +import type { NodeDefault } from '../../types' +import type { IterationStartNodeType } from './types' +import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' + +const nodeDefault: NodeDefault = { + defaultValue: {}, + getAvailablePrevNodes() { + return [] + }, + getAvailableNextNodes(isChatMode: boolean) { + const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + return nodes + }, + checkValid() { + return { + isValid: true, + } + }, +} + +export default nodeDefault diff --git a/web/app/components/workflow/nodes/iteration-start/index.tsx b/web/app/components/workflow/nodes/iteration-start/index.tsx new file mode 100644 index 0000000000..9d7ac1f905 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/index.tsx @@ -0,0 +1,42 @@ +import { memo } from 'react' +import { useTranslation } from 'react-i18next' +import type { NodeProps } from 'reactflow' +import { RiHome5Fill } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' +import { NodeSourceHandle } from '@/app/components/workflow/nodes/_base/components/node-handle' + +const IterationStartNode = ({ id, data }: NodeProps) => { + const { t } = useTranslation() + + return ( +
+ +
+ +
+
+ +
+ ) +} + +export const IterationStartNodeDumb = () => { + const { t } = useTranslation() + + return ( +
+ +
+ +
+
+
+ ) +} + +export default memo(IterationStartNode) diff --git a/web/app/components/workflow/nodes/iteration-start/types.ts b/web/app/components/workflow/nodes/iteration-start/types.ts new file mode 100644 index 0000000000..319cce0bc2 --- /dev/null +++ b/web/app/components/workflow/nodes/iteration-start/types.ts @@ -0,0 +1,3 @@ +import type { CommonNodeType } from '@/app/components/workflow/types' + +export type IterationStartNodeType = CommonNodeType diff --git a/web/app/components/workflow/nodes/iteration/add-block.tsx b/web/app/components/workflow/nodes/iteration/add-block.tsx index fd8480b7df..07e2b5daf0 100644 --- a/web/app/components/workflow/nodes/iteration/add-block.tsx +++ b/web/app/components/workflow/nodes/iteration/add-block.tsx @@ -2,87 +2,49 @@ import { memo, useCallback, } from 'react' -import produce from 'immer' import { RiAddLine, } from '@remixicon/react' -import { useStoreApi } from 'reactflow' import { useTranslation } from 'react-i18next' import { - generateNewNode, -} from '../../utils' -import { - WorkflowHistoryEvent, useAvailableBlocks, + useNodesInteractions, useNodesReadOnly, - useWorkflowHistory, } from '../../hooks' -import { NODES_INITIAL_DATA } from '../../constants' -import InsertBlock from './insert-block' import type { IterationNodeType } from './types' import cn from '@/utils/classnames' import BlockSelector from '@/app/components/workflow/block-selector' -import { IterationStart } from '@/app/components/base/icons/src/vender/workflow' import type { OnSelectBlock, } from '@/app/components/workflow/types' import { BlockEnum, } from '@/app/components/workflow/types' -import Tooltip from '@/app/components/base/tooltip' type AddBlockProps = { iterationNodeId: string iterationNodeData: IterationNodeType } const AddBlock = ({ - iterationNodeId, iterationNodeData, }: AddBlockProps) => { const { t } = useTranslation() - const store = useStoreApi() const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeAdd } = useNodesInteractions() const { availableNextBlocks } = useAvailableBlocks(BlockEnum.Start, true) - const { availablePrevBlocks } = useAvailableBlocks(iterationNodeData.startNodeType, true) - const { saveStateToHistory } = useWorkflowHistory() const handleSelect = useCallback((type, toolDefaultValue) => { - const { - getNodes, - setNodes, - } = store.getState() - const nodes = getNodes() - const nodesWithSameType = nodes.filter(node => node.data.type === type) - const newNode = generateNewNode({ - data: { - ...NODES_INITIAL_DATA[type], - title: nodesWithSameType.length > 0 ? 
`${t(`workflow.blocks.${type}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${type}`), - ...(toolDefaultValue || {}), - isIterationStart: true, - isInIteration: true, - iteration_id: iterationNodeId, + handleNodeAdd( + { + nodeType: type, + toolDefaultValue, }, - position: { - x: 117, - y: 85, + { + prevNodeId: iterationNodeData.start_node_id, + prevNodeSourceHandle: 'source', }, - zIndex: 1001, - parentId: iterationNodeId, - extent: 'parent', - }) - const newNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - if (node.id === iterationNodeId) { - node.data._children = [newNode.id] - node.data.start_node_id = newNode.id - node.data.startNodeType = newNode.data.type - } - }) - draft.push(newNode) - }) - setNodes(newNodes) - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) - }, [store, t, iterationNodeId, saveStateToHistory]) + ) + }, [handleNodeAdd, iterationNodeData.start_node_id]) const renderTriggerElement = useCallback((open: boolean) => { return ( @@ -98,35 +60,18 @@ const AddBlock = ({ }, [nodesReadOnly, t]) return ( -
- -
- -
-
+
- { - iterationNodeData.startNodeType && ( - - ) - }
- { - !iterationNodeData.startNodeType && ( - - ) - } +
) } diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 43f8a751ac..3afa52d06e 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -9,6 +9,7 @@ const nodeDefault: NodeDefault = { start_node_id: '', iterator_selector: [], output_selector: [], + _children: [], }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode diff --git a/web/app/components/workflow/nodes/iteration/insert-block.tsx b/web/app/components/workflow/nodes/iteration/insert-block.tsx deleted file mode 100644 index d041fe1c74..0000000000 --- a/web/app/components/workflow/nodes/iteration/insert-block.tsx +++ /dev/null @@ -1,61 +0,0 @@ -import { - memo, - useCallback, - useState, -} from 'react' -import { useNodesInteractions } from '../../hooks' -import type { - BlockEnum, - OnSelectBlock, -} from '../../types' -import BlockSelector from '../../block-selector' -import cn from '@/utils/classnames' - -type InsertBlockProps = { - startNodeId: string - availableBlocksTypes: BlockEnum[] -} -const InsertBlock = ({ - startNodeId, - availableBlocksTypes, -}: InsertBlockProps) => { - const [open, setOpen] = useState(false) - const { handleNodeAdd } = useNodesInteractions() - - const handleOpenChange = useCallback((v: boolean) => { - setOpen(v) - }, []) - const handleInsert = useCallback((nodeType, toolDefaultValue) => { - handleNodeAdd( - { - nodeType, - toolDefaultValue, - }, - { - nextNodeId: startNodeId, - nextNodeTargetHandle: 'target', - }, - ) - }, [startNodeId, handleNodeAdd]) - - return ( -
- 'hover:scale-125 transition-all'} - /> -
- ) -} - -export default memo(InsertBlock) diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index f4520402f3..48a005a261 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,6 +8,7 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' @@ -29,7 +30,7 @@ const Node: FC> = ({ return (
> = ({ size={2 / zoom} color='#E4E5E7' /> - + { + data._isCandidate && ( + + ) + } + { + data._children!.length === 1 && ( + + ) + }
) } diff --git a/web/app/components/workflow/nodes/iteration/use-interactions.ts b/web/app/components/workflow/nodes/iteration/use-interactions.ts index 219c8e731f..f8e3640cc4 100644 --- a/web/app/components/workflow/nodes/iteration/use-interactions.ts +++ b/web/app/components/workflow/nodes/iteration/use-interactions.ts @@ -11,6 +11,7 @@ import { ITERATION_PADDING, NODES_INITIAL_DATA, } from '../../constants' +import { CUSTOM_ITERATION_START_NODE } from '../iteration-start/constants' export const useNodeIterationInteractions = () => { const { t } = useTranslation() @@ -107,12 +108,12 @@ export const useNodeIterationInteractions = () => { const handleNodeIterationChildrenCopy = useCallback((nodeId: string, newNodeId: string) => { const { getNodes } = store.getState() const nodes = getNodes() - const childrenNodes = nodes.filter(n => n.parentId === nodeId) + const childrenNodes = nodes.filter(n => n.parentId === nodeId && n.type !== CUSTOM_ITERATION_START_NODE) return childrenNodes.map((child, index) => { const childNodeType = child.data.type as BlockEnum const nodesWithSameType = nodes.filter(node => node.data.type === childNodeType) - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ data: { ...NODES_INITIAL_DATA[childNodeType], ...child.data, @@ -121,6 +122,7 @@ export const useNodeIterationInteractions = () => { _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], title: nodesWithSameType.length > 0 ? `${t(`workflow.blocks.${childNodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${childNodeType}`), + iteration_id: newNodeId, }, position: child.position, positionAbsolute: child.positionAbsolute, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx index ebe9d2151e..b335b62e33 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx @@ -70,7 +70,7 @@ const RetrievalConfig: FC = ({ } onMultipleRetrievalConfigChange({ top_k: configs.top_k, - score_threshold: configs.score_threshold_enabled ? (configs.score_threshold || DATASET_DEFAULT.score_threshold) : null, + score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay ? undefined : (!configs.reranking_model?.reranking_provider_name diff --git a/web/app/components/workflow/nodes/llm/default.ts b/web/app/components/workflow/nodes/llm/default.ts index 803add6f00..d7597942e5 100644 --- a/web/app/components/workflow/nodes/llm/default.ts +++ b/web/app/components/workflow/nodes/llm/default.ts @@ -44,7 +44,7 @@ const nodeDefault: NodeDefault = { if (!errorMessages && !payload.memory) { const isChatModel = payload.model.mode === 'chat' - const isPromptyEmpty = isChatModel + const isPromptEmpty = isChatModel ? !(payload.prompt_template as PromptItem[]).some((t) => { if (t.edition_type === EditionType.jinja2) return t.jinja2_text !== '' @@ -52,7 +52,7 @@ const nodeDefault: NodeDefault = { return t.text !== '' }) : ((payload.prompt_template as PromptItem).edition_type === EditionType.jinja2 ? 
(payload.prompt_template as PromptItem).jinja2_text === '' : (payload.prompt_template as PromptItem).text === '') - if (isPromptyEmpty) + if (isPromptEmpty) errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') }) } diff --git a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx index 2ac331558b..081a683234 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx @@ -99,7 +99,7 @@ const AddExtractParameter: FC = ({ if (!param.name) errMessage = t(`${errorI18nPrefix}.fieldRequired`, { field: t(`${i18nPrefix}.addExtractParameterContent.name`) }) if (!errMessage && param.type === ParamType.select && (!param.options || param.options.length === 0)) - errMessage = t(`${errorI18nPrefix}.fieldRequired`, { field: t('appDebug.variableConig.options') }) + errMessage = t(`${errorI18nPrefix}.fieldRequired`, { field: t('appDebug.variableConfig.options') }) if (!errMessage && !param.description) errMessage = t(`${errorI18nPrefix}.fieldRequired`, { field: t(`${i18nPrefix}.addExtractParameterContent.description`) }) @@ -160,7 +160,7 @@ const AddExtractParameter: FC = ({ /> {param.type === ParamType.select && ( - + )} diff --git a/web/app/components/workflow/operator/add-block.tsx b/web/app/components/workflow/operator/add-block.tsx index 48222cc528..388fbc053f 100644 --- a/web/app/components/workflow/operator/add-block.tsx +++ b/web/app/components/workflow/operator/add-block.tsx @@ -55,7 +55,7 @@ const AddBlock = ({ } = store.getState() const nodes = getNodes() const nodesWithSameType = nodes.filter(node => node.data.type === type) - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ data: { ...NODES_INITIAL_DATA[type], title: nodesWithSameType.length > 0 ? 
`${t(`workflow.blocks.${type}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${type}`), diff --git a/web/app/components/workflow/operator/hooks.ts b/web/app/components/workflow/operator/hooks.ts index 5b14211497..edec10bda7 100644 --- a/web/app/components/workflow/operator/hooks.ts +++ b/web/app/components/workflow/operator/hooks.ts @@ -11,7 +11,7 @@ export const useOperator = () => { const { userProfile } = useAppContext() const handleAddNote = useCallback(() => { - const newNode = generateNewNode({ + const { newNode } = generateNewNode({ type: CUSTOM_NOTE_NODE, data: { title: '', diff --git a/web/app/components/workflow/panel/chat-record/index.tsx b/web/app/components/workflow/panel/chat-record/index.tsx index 2ab3165c14..afd20b7358 100644 --- a/web/app/components/workflow/panel/chat-record/index.tsx +++ b/web/app/components/workflow/panel/chat-record/index.tsx @@ -103,9 +103,9 @@ const ChatRecord = () => { } as any} chatList={chatMessageList} chatContainerClassName='px-4' - chatContainerInnerClassName='pt-6' + chatContainerInnerClassName='pt-6 w-full max-w-full mx-auto' chatFooterClassName='px-4 rounded-b-2xl' - chatFooterInnerClassName='pb-4' + chatFooterInnerClassName='pb-4 w-full max-w-full mx-auto' chatNode={} noChatInput allToolIcons={{}} diff --git a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx index 4655940037..a7dd607e22 100644 --- a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx @@ -103,9 +103,9 @@ const ChatWrapper = forwardRef(({ showConv chatList={chatList} isResponding={isResponding} chatContainerClassName='px-3' - chatContainerInnerClassName='pt-6' + chatContainerInnerClassName='pt-6 w-full max-w-full mx-auto' chatFooterClassName='px-4 rounded-bl-2xl' - chatFooterInnerClassName='pb-4' + chatFooterInnerClassName='pb-4 w-full max-w-full mx-auto' onSend={doSend} onStopResponding={handleStop} chatNode={( @@ -118,6 +118,7 @@ const ChatWrapper = forwardRef(({ showConv } )} + noSpacing suggestedQuestions={suggestedQuestions} showPromptLog chatAnswerContainerInner='!pr-2' diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 155d2a84ac..51a018bcb1 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -180,8 +180,6 @@ export const useChat = ( isAnswer: true, } - let isInIteration = false - handleResponding(true) const bodyParams = { @@ -248,11 +246,16 @@ export const useChat = ( } if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) { - const { data }: any = await onGetSuggestedQuestions( - responseItem.id, - newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, - ) - setSuggestQuestions(data) + try { + const { data }: any = await onGetSuggestedQuestions( + responseItem.id, + newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController, + ) + setSuggestQuestions(data) + } + catch (error) { + setSuggestQuestions([]) + } } }, onMessageEnd: (messageEnd) => { @@ -312,11 +315,11 @@ export const useChat = ( ...responseItem, } })) - isInIteration = true }, - onIterationNext: () => { + onIterationNext: ({ data }) => { const tracing = responseItem.workflowProcess!.tracing! 
-          const iterations = tracing[tracing.length - 1]
+          const iterations = tracing.find(item => item.node_id === data.node_id
+            && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
           iterations.details!.push([])
           handleUpdateChatList(produce(chatListRef.current, (draft) => {
@@ -326,9 +329,10 @@
         },
         onIterationFinish: ({ data }) => {
           const tracing = responseItem.workflowProcess!.tracing!
-          const iterations = tracing[tracing.length - 1]
-          tracing[tracing.length - 1] = {
-            ...iterations,
+          const iterationsIndex = tracing.findIndex(item => item.node_id === data.node_id
+            && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
+          tracing[iterationsIndex] = {
+            ...tracing[iterationsIndex],
             ...data,
             status: NodeRunningStatus.Succeeded,
           } as any
@@ -336,67 +340,45 @@
             const currentIndex = draft.length - 1
             draft[currentIndex] = responseItem
           }))
-
-          isInIteration = false
         },
         onNodeStarted: ({ data }) => {
-          if (isInIteration) {
-            const tracing = responseItem.workflowProcess!.tracing!
-            const iterations = tracing[tracing.length - 1]
-            const currIteration = iterations.details![iterations.details!.length - 1]
-            currIteration.push({
-              ...data,
-              status: NodeRunningStatus.Running,
-            } as any)
-            handleUpdateChatList(produce(chatListRef.current, (draft) => {
-              const currentIndex = draft.length - 1
-              draft[currentIndex] = responseItem
-            }))
-          }
-          else {
-            responseItem.workflowProcess!.tracing!.push({
-              ...data,
-              status: NodeRunningStatus.Running,
-            } as any)
-            handleUpdateChatList(produce(chatListRef.current, (draft) => {
-              const currentIndex = draft.findIndex(item => item.id === responseItem.id)
-              draft[currentIndex] = {
-                ...draft[currentIndex],
-                ...responseItem,
-              }
-            }))
-          }
+          if (data.iteration_id)
+            return
+
+          responseItem.workflowProcess!.tracing!.push({
+            ...data,
+            status: NodeRunningStatus.Running,
+          } as any)
+          handleUpdateChatList(produce(chatListRef.current, (draft) => {
+            const currentIndex = draft.findIndex(item => item.id === responseItem.id)
+            draft[currentIndex] = {
+              ...draft[currentIndex],
+              ...responseItem,
+            }
+          }))
         },
         onNodeFinished: ({ data }) => {
-          if (isInIteration) {
-            const tracing = responseItem.workflowProcess!.tracing!
-            const iterations = tracing[tracing.length - 1]
-            const currIteration = iterations.details![iterations.details!.length - 1]
-            currIteration[currIteration.length - 1] = {
-              ...data,
-              status: NodeRunningStatus.Succeeded,
-            } as any
-            handleUpdateChatList(produce(chatListRef.current, (draft) => {
-              const currentIndex = draft.length - 1
-              draft[currentIndex] = responseItem
-            }))
-          }
-          else {
-            const currentIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
-            responseItem.workflowProcess!.tracing[currentIndex] = {
-              ...(responseItem.workflowProcess!.tracing[currentIndex].extras
-                ?
{ extras: responseItem.workflowProcess!.tracing[currentIndex].extras } - : {}), - ...data, - } as any - handleUpdateChatList(produce(chatListRef.current, (draft) => { - const currentIndex = draft.findIndex(item => item.id === responseItem.id) - draft[currentIndex] = { - ...draft[currentIndex], - ...responseItem, - } - })) - } + if (data.iteration_id) + return + + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id) + }) + responseItem.workflowProcess!.tracing[currentIndex] = { + ...(responseItem.workflowProcess!.tracing[currentIndex]?.extras + ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } + : {}), + ...data, + } as any + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) }, }, ) diff --git a/web/app/components/workflow/panel/debug-and-preview/index.tsx b/web/app/components/workflow/panel/debug-and-preview/index.tsx index 29fc48f896..8e1c489098 100644 --- a/web/app/components/workflow/panel/debug-and-preview/index.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/index.tsx @@ -88,7 +88,6 @@ const DebugAndPreview = () => { - {expanded &&
}
)}
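The hooks rework above drops the `isInIteration` flag entirely: once branches run in parallel, streamed events interleave and "the last tracing entry" stops being meaningful, so entries are matched by `node_id` plus `parallel_id`, and iteration children are skipped (`if (data.iteration_id) return`) because the iteration node owns their rendering. Roughly, as a standalone helper; the `NodeTracing` shape is abbreviated here, the real one lives in `@/types/workflow`:

```ts
// Abbreviated tracing shape for illustration only.
type TracingEntry = {
  node_id: string
  parallel_id?: string
  execution_metadata?: { parallel_id?: string }
}

// Find the tracing entry a streamed event belongs to.
function findTracingIndex(tracing: TracingEntry[], data: TracingEntry): number {
  const eventParallelId = data.execution_metadata?.parallel_id
  return tracing.findIndex((item) => {
    if (item.node_id !== data.node_id)
      return false
    // Sequential workflows carry no parallel_id and degrade to a
    // plain node_id match, so existing runs render unchanged.
    if (!eventParallelId)
      return true
    return item.execution_metadata?.parallel_id === eventParallelId
      || item.parallel_id === eventParallelId
  })
}
```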
diff --git a/web/app/components/workflow/panel/record.tsx b/web/app/components/workflow/panel/record.tsx index 079dd2cc86..d79f1a9439 100644 --- a/web/app/components/workflow/panel/record.tsx +++ b/web/app/components/workflow/panel/record.tsx @@ -1,5 +1,5 @@ import { memo, useCallback } from 'react' -import type { WorkflowDataUpdator } from '../types' +import type { WorkflowDataUpdater } from '../types' import Run from '../run' import { useStore } from '../store' import { useWorkflowUpdate } from '../hooks' @@ -9,7 +9,7 @@ const Record = () => { const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() const handleResultCallback = useCallback((res: any) => { - const graph: WorkflowDataUpdator = res.graph + const graph: WorkflowDataUpdater = res.graph handleUpdateWorkflowCanvas({ nodes: graph.nodes, edges: graph.edges, diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 702ce06e1c..331ef1c2f5 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -63,26 +63,22 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const formatNodeList = useCallback((list: NodeTracing[]) => { const allItems = list.reverse() const result: NodeTracing[] = [] - let iterationIndex = 0 allItems.forEach((item) => { const { node_type, execution_metadata } = item if (node_type !== BlockEnum.Iteration) { const isInIteration = !!execution_metadata?.iteration_id if (isInIteration) { - const iterationDetails = result[result.length - 1].details! - const currentIterationIndex = execution_metadata?.iteration_index - const isIterationFirstNode = iterationIndex !== currentIterationIndex || iterationDetails.length === 0 + const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) + const iterationDetails = iterationNode?.details + const currentIterationIndex = execution_metadata?.iteration_index ?? 0 - if (isIterationFirstNode) { - iterationDetails!.push([item]) - iterationIndex = currentIterationIndex! + if (Array.isArray(iterationDetails)) { + if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) + iterationDetails[currentIterationIndex] = [item] + else + iterationDetails[currentIterationIndex].push(item) } - - else { - iterationDetails[iterationDetails.length - 1].push(item) - } - return } // not in iteration @@ -90,7 +86,6 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe return } - result.push({ ...item, details: [], @@ -134,12 +129,12 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe getData(appDetail.id, runID) }, [appDetail, runID]) - const [height, setHieght] = useState(0) + const [height, setHeight] = useState(0) const ref = useRef(null) const adjustResultHeight = () => { if (ref.current) - setHieght(ref.current?.clientHeight - 16 - 16 - 2 - 1) + setHeight(ref.current?.clientHeight - 16 - 16 - 2 - 1) } useEffect(() => { @@ -197,7 +192,7 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe onClick={() => switchTab('TRACING')} >{t('runLog.tracing')}
-        {/* panel detal */}
+        {/* panel detail */}
{loading && (
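`formatNodeList` above makes the matching change for the run log: an iteration child is filed into its parent node's `details` bucket by its own `iteration_index`, rather than appended to whichever iteration happened to stream last. A condensed sketch of that regrouping, with the tracing type abbreviated:

```ts
// Abbreviated shape; details is a list of per-round node lists.
type IterTrace = {
  node_id: string
  details?: IterTrace[][]
  execution_metadata?: { iteration_id?: string; iteration_index?: number }
}

// File an iteration child under the round named by iteration_index,
// creating the bucket on demand. Parallel iterations can report
// rounds out of order, which the old "append to last" logic broke on.
function attachToIteration(result: IterTrace[], item: IterTrace) {
  const iterationNode = result.find(n => n.node_id === item.execution_metadata?.iteration_id)
  const details = iterationNode?.details
  const index = item.execution_metadata?.iteration_index ?? 0

  if (Array.isArray(details)) {
    if (!details[index])
      details[index] = [item]
    else
      details[index].push(item)
  }
}
```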
diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index c833ea0342..4fc30f03df 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -1,10 +1,14 @@ 'use client' import type { FC } from 'react' -import React, { useCallback } from 'react' +import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { RiCloseLine } from '@remixicon/react' +import { + RiArrowRightSLine, + RiCloseLine, +} from '@remixicon/react' import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' -import NodePanel from './node' +import TracingPanel from './tracing-panel' +import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import type { NodeTracing } from '@/types/workflow' const i18nPrefix = 'workflow.singleRun' @@ -23,43 +27,67 @@ const IterationResultPanel: FC = ({ noWrap, }) => { const { t } = useTranslation() + const [expandedIterations, setExpandedIterations] = useState>([]) + + const toggleIteration = useCallback((index: number) => { + setExpandedIterations(prev => ({ + ...prev, + [index]: !prev[index], + })) + }, []) const main = ( <> -
+
-
+
{t(`${i18nPrefix}.testRunIteration`)}
- +
-
+
-
{t(`${i18nPrefix}.back`)}
+
{t(`${i18nPrefix}.back`)}
{/* List */} -
+
{list.map((iteration, index) => ( -
-
-
{t(`${i18nPrefix}.iteration`)} {index + 1}
-
+
+
toggleIteration(index)} + > +
+
+ +
+ + {t(`${i18nPrefix}.iteration`)} {index + 1} + + +
-
- {iteration.map(node => ( - - ))} + {expandedIterations[index] &&
} +
+
))} diff --git a/web/app/components/workflow/run/meta.tsx b/web/app/components/workflow/run/meta.tsx index 86eb221ad9..b2d7269a51 100644 --- a/web/app/components/workflow/run/meta.tsx +++ b/web/app/components/workflow/run/meta.tsx @@ -16,7 +16,7 @@ type Props = { const MetaData: FC = ({ status, executor, - startTime = 0, + startTime, time, tokens, steps = 1, @@ -64,7 +64,7 @@ const MetaData: FC = ({
        )}
        {status !== 'running' && (
-          {formatTime(startTime, t('appLog.dateTimeFormat') as string)}
+          {startTime ? formatTime(startTime, t('appLog.dateTimeFormat') as string) : '-'}
        )}
@@ -75,7 +75,7 @@ const MetaData: FC = ({
        )}
        {status !== 'running' && (
-          {`${time?.toFixed(3)}s`}
+          {time ? `${time.toFixed(3)}s` : '-'}
        )}
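`MetaData` above stops defaulting `startTime` to 0 and guards both timestamps, so a run that has not reported them yet shows a '-' placeholder instead of a bogus epoch date or a raw `undefined`. The pattern, as a pair of hedged helpers; the epoch-seconds assumption is mine, the component itself delegates to its existing `formatTime` util:

```ts
// Render '-' until the backend has actually reported the metric.
const formatElapsed = (time?: number): string =>
  time ? `${time.toFixed(3)}s` : '-'

// Assumes a unix-seconds timestamp for illustration; the component
// uses its own formatTime(startTime, format) helper instead.
const formatStarted = (startTime?: number): string =>
  startTime ? new Date(startTime * 1000).toLocaleString() : '-'
```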
diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index 66f996f13b..2e45290ddf 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -4,15 +4,17 @@ import type { FC } from 'react' import { useCallback, useEffect, useState } from 'react' import { RiArrowRightSLine, - RiCheckboxCircleLine, + RiCheckboxCircleFill, RiErrorWarningLine, RiLoader2Line, } from '@remixicon/react' import BlockIcon from '../block-icon' import { BlockEnum } from '../types' import Split from '../nodes/_base/components/split' +import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import Button from '@/app/components/base/button' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' import type { NodeTracing } from '@/types/workflow' @@ -61,31 +63,38 @@ const NodePanel: FC = ({ return `${parseFloat((tokens / 1000000).toFixed(3))}M` } + const getCount = (iteration_curr_length: number | undefined, iteration_length: number) => { + if ((iteration_curr_length && iteration_curr_length < iteration_length) || !iteration_length) + return iteration_curr_length + + return iteration_length + } + useEffect(() => { setCollapseState(!nodeInfo.expand) }, [nodeInfo.expand, setCollapseState]) const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration - const handleOnShowIterationDetail = (e: React.MouseEvent) => { + const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() onShowIterationDetail?.(nodeInfo.details || []) } return ( -
-
+
+
setCollapseState(!collapseState)} > {!hideProcessDetail && ( @@ -93,23 +102,23 @@ const NodePanel: FC = ({
{nodeInfo.title}
{nodeInfo.status !== 'running' && !hideInfo && ( -
{`${getTime(nodeInfo.elapsed_time || 0)} · ${getTokenCount(nodeInfo.execution_metadata?.total_tokens || 0)} tokens`}
+
{nodeInfo.execution_metadata?.total_tokens ? `${getTokenCount(nodeInfo.execution_metadata?.total_tokens || 0)} tokens · ` : ''}{`${getTime(nodeInfo.elapsed_time || 0)}`}
)} {nodeInfo.status === 'succeeded' && ( - + )} {nodeInfo.status === 'failed' && ( - + )} {nodeInfo.status === 'stopped' && ( )} {nodeInfo.status === 'running' && ( -
+
Running
@@ -120,25 +129,27 @@ const NodePanel: FC = ({ {/* The nav to the iteration detail */} {isIterationNode && !notShowIterationNav && (
-
-
{t('workflow.nodes.iteration.iteration', { count: nodeInfo.metadata?.iterator_length || nodeInfo.details?.length })}
+
+
)} -
+
{nodeInfo.status === 'stopped' && (
{t('workflow.tracing.stopBy', { user: nodeInfo.created_by ? nodeInfo.created_by.name : 'N/A' })}
)} diff --git a/web/app/components/workflow/run/status.tsx b/web/app/components/workflow/run/status.tsx index 2eeafca95d..0677e43401 100644 --- a/web/app/components/workflow/run/status.tsx +++ b/web/app/components/workflow/run/status.tsx @@ -71,7 +71,7 @@ const StatusPanel: FC = ({
        )}
        {status !== 'running' && (
-          {`${time?.toFixed(3)}s`}
+          {time ? `${time?.toFixed(3)}s` : '-'}
        )}
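The tracing-panel rewrite that follows replaces the flat `list.map(...)` rendering with `buildLogTree`, which folds the time-ordered trace into nested parallel groups and lettered branches. Stripped of titles and nesting, the core grouping idea looks roughly like this (types abbreviated; `parent_parallel_id` and branch handling omitted):

```ts
// Abbreviated tracing shape for illustration only.
type RunNode = {
  id: string
  node_id: string
  parallel_id?: string
  execution_metadata?: { parallel_id?: string }
}

type LogGroup = { parallelId: string | null; nodes: RunNode[] }

function groupByParallel(nodes: RunNode[]): LogGroup[] {
  const groups: LogGroup[] = []
  const byParallelId = new Map<string, LogGroup>()

  for (const node of nodes) {
    const pid = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null
    if (!pid) {
      // Sequential nodes stay as their own top-level entries.
      groups.push({ parallelId: null, nodes: [node] })
      continue
    }
    // Siblings of one parallel fan-out collapse into a shared group,
    // kept in first-seen order, so they render together in the log.
    let group = byParallelId.get(pid)
    if (!group) {
      group = { parallelId: pid, nodes: [] }
      byParallelId.set(pid, group)
      groups.push(group)
    }
    group.nodes.push(node)
  }
  return groups
}
```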
diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index c80bcb41db..7684dc84b6 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -1,24 +1,266 @@ 'use client' import type { FC } from 'react' +import +React, +{ + useCallback, + useState, +} from 'react' +import cn from 'classnames' +import { + RiArrowDownSLine, + RiMenu4Line, +} from '@remixicon/react' import NodePanel from './node' +import { + BlockEnum, +} from '@/app/components/workflow/types' import type { NodeTracing } from '@/types/workflow' type TracingPanelProps = { list: NodeTracing[] - onShowIterationDetail: (detail: NodeTracing[][]) => void + onShowIterationDetail?: (detail: NodeTracing[][]) => void + className?: string + hideNodeInfo?: boolean + hideNodeProcessDetail?: boolean } -const TracingPanel: FC = ({ list, onShowIterationDetail }) => { +type TracingNodeProps = { + id: string + uniqueId: string + isParallel: boolean + data: NodeTracing | null + children: TracingNodeProps[] + parallelTitle?: string + branchTitle?: string + hideNodeInfo?: boolean + hideNodeProcessDetail?: boolean +} + +function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] { + const rootNodes: TracingNodeProps[] = [] + const parallelStacks: { [key: string]: TracingNodeProps } = {} + const levelCounts: { [key: string]: number } = {} + const parallelChildCounts: { [key: string]: Set } = {} + let uniqueIdCounter = 0 + const getUniqueId = () => { + uniqueIdCounter++ + return `unique-${uniqueIdCounter}` + } + + const getParallelTitle = (parentId: string | null): string => { + const levelKey = parentId || 'root' + if (!levelCounts[levelKey]) + levelCounts[levelKey] = 0 + + levelCounts[levelKey]++ + + const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' + const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 + const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' + return `PARALLEL-${levelNumber}${letter}` + } + + const getBranchTitle = (parentId: string | null, branchNum: number): string => { + const levelKey = parentId || 'root' + const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' + const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 + const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' + const branchLetter = String.fromCharCode(64 + branchNum) + return `BRANCH-${levelNumber}${letter}-${branchLetter}` + } + + // Count parallel children (for figuring out if we need to use letters) + for (const node of nodes) { + const parent_parallel_id = node.parent_parallel_id ?? node.execution_metadata?.parent_parallel_id ?? null + const parallel_id = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null + + if (parallel_id) { + const parentKey = parent_parallel_id || 'root' + if (!parallelChildCounts[parentKey]) + parallelChildCounts[parentKey] = new Set() + + parallelChildCounts[parentKey].add(parallel_id) + } + } + + for (const node of nodes) { + const parallel_id = node.parallel_id ?? node.execution_metadata?.parallel_id ?? null + const parent_parallel_id = node.parent_parallel_id ?? node.execution_metadata?.parent_parallel_id ?? null + const parallel_start_node_id = node.parallel_start_node_id ?? node.execution_metadata?.parallel_start_node_id ?? 
null + const parent_parallel_start_node_id = node.parent_parallel_start_node_id ?? node.execution_metadata?.parent_parallel_start_node_id ?? null + + if (!parallel_id || node.node_type === BlockEnum.End) { + rootNodes.push({ + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + }) + } + else { + if (!parallelStacks[parallel_id]) { + const newParallelGroup: TracingNodeProps = { + id: parallel_id, + uniqueId: getUniqueId(), + isParallel: true, + data: null, + children: [], + parallelTitle: '', + } + parallelStacks[parallel_id] = newParallelGroup + + if (parent_parallel_id && parallelStacks[parent_parallel_id]) { + const sameBranchIndex = parallelStacks[parent_parallel_id].children.findLastIndex(c => + c.data?.execution_metadata?.parallel_start_node_id === parent_parallel_start_node_id || c.data?.parallel_start_node_id === parent_parallel_start_node_id, + ) + parallelStacks[parent_parallel_id].children.splice(sameBranchIndex + 1, 0, newParallelGroup) + newParallelGroup.parallelTitle = getParallelTitle(parent_parallel_id) + } + else { + newParallelGroup.parallelTitle = getParallelTitle(parent_parallel_id) + rootNodes.push(newParallelGroup) + } + } + const branchTitle = parallel_start_node_id === node.node_id ? getBranchTitle(parent_parallel_id, parallelStacks[parallel_id].children.length + 1) : '' + if (branchTitle) { + parallelStacks[parallel_id].children.push({ + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + branchTitle, + }) + } + else { + let sameBranchIndex = parallelStacks[parallel_id].children.findLastIndex(c => + c.data?.execution_metadata?.parallel_start_node_id === parallel_start_node_id || c.data?.parallel_start_node_id === parallel_start_node_id, + ) + if (parallelStacks[parallel_id].children[sameBranchIndex + 1]?.isParallel) + sameBranchIndex++ + + parallelStacks[parallel_id].children.splice(sameBranchIndex + 1, 0, { + id: node.id, + uniqueId: getUniqueId(), + isParallel: false, + data: node, + children: [], + branchTitle, + }) + } + } + } + + return rootNodes +} + +const TracingPanel: FC = ({ + list, + onShowIterationDetail, + className, + hideNodeInfo = false, + hideNodeProcessDetail = false, +}) => { + const treeNodes = buildLogTree(list) + const [collapsedNodes, setCollapsedNodes] = useState>(new Set()) + const [hoveredParallel, setHoveredParallel] = useState(null) + + const toggleCollapse = (id: string) => { + setCollapsedNodes((prev) => { + const newSet = new Set(prev) + if (newSet.has(id)) + newSet.delete(id) + + else + newSet.add(id) + + return newSet + }) + } + + const handleParallelMouseEnter = useCallback((id: string) => { + setHoveredParallel(id) + }, []) + + const handleParallelMouseLeave = useCallback((e: React.MouseEvent) => { + const relatedTarget = e.relatedTarget as Element | null + if (relatedTarget && 'closest' in relatedTarget) { + const closestParallel = relatedTarget.closest('[data-parallel-id]') + if (closestParallel) + setHoveredParallel(closestParallel.getAttribute('data-parallel-id')) + + else + setHoveredParallel(null) + } + else { + setHoveredParallel(null) + } + }, []) + + const renderNode = (node: TracingNodeProps) => { + if (node.isParallel) { + const isCollapsed = collapsedNodes.has(node.id) + const isHovered = hoveredParallel === node.id + return ( +
handleParallelMouseEnter(node.id)} + onMouseLeave={handleParallelMouseLeave} + > +
+ +
+ {node.parallelTitle} +
+
+
+
+
+ {node.children.map(renderNode)} +
+
+ ) + } + else { + const isHovered = hoveredParallel === node.id + return ( +
+
+ {node.branchTitle} +
+ +
+ ) + } + } + return ( -
- {list.map(node => ( - - ))} +
+ {treeNodes.map(renderNode)}
) } diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index 2e5e774191..853d0c5934 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -162,6 +162,8 @@ type Shape = { setControlPromptEditorRerenderKey: (controlPromptEditorRerenderKey: number) => void showImportDSLModal: boolean setShowImportDSLModal: (showImportDSLModal: boolean) => void + showTips: string + setShowTips: (showTips: string) => void } export const createWorkflowStore = () => { @@ -262,6 +264,8 @@ export const createWorkflowStore = () => { setControlPromptEditorRerenderKey: controlPromptEditorRerenderKey => set(() => ({ controlPromptEditorRerenderKey })), showImportDSLModal: false, setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), + showTips: '', + setShowTips: showTips => set(() => ({ showTips })), })) } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 034376fed5..797c2dbd85 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -26,6 +26,7 @@ export enum BlockEnum { Tool = 'tool', ParameterExtractor = 'parameter-extractor', Iteration = 'iteration', + IterationStart = 'iteration-start', Assigner = 'assigner', // is now named as VariableAssigner } @@ -54,7 +55,7 @@ export type CommonNodeType = { _holdAddVariablePopup?: boolean _iterationLength?: number _iterationIndex?: number - isIterationStart?: boolean + _inParallelHovering?: boolean isInIteration?: boolean iteration_id?: string selected?: boolean @@ -69,7 +70,7 @@ export type CommonEdgeType = { _hovering?: boolean _connectedNodeIsHovering?: boolean _connectedNodeIsSelected?: boolean - _runned?: boolean + _run?: boolean _isBundled?: boolean isInIteration?: boolean iteration_id?: string @@ -86,7 +87,7 @@ export type NodePanelProps = { } export type Edge = ReactFlowEdge -export type WorkflowDataUpdator = { +export type WorkflowDataUpdater = { nodes: Node[] edges: Edge[] viewport: Viewport diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 0d07b2e568..91656e3bbc 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -1,12 +1,15 @@ import { Position, getConnectedEdges, + getIncomers, getOutgoers, } from 'reactflow' import dagre from '@dagrejs/dagre' import { v4 as uuid4 } from 'uuid' import { cloneDeep, + groupBy, + isEqual, uniqBy, } from 'lodash-es' import type { @@ -19,14 +22,17 @@ import type { import { BlockEnum } from './types' import { CUSTOM_NODE, + ITERATION_CHILDREN_Z_INDEX, ITERATION_NODE_Z_INDEX, NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION, } from './constants' +import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants' import type { QuestionClassifierNodeType } from './nodes/question-classifier/types' import type { IfElseNodeType } from './nodes/if-else/types' import { branchNameCorrect } from './nodes/if-else/utils' import type { ToolNodeType } from './nodes/tool/types' +import type { IterationNodeType } from './nodes/iteration/types' import { CollectionType } from '@/app/components/tools/types' import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' @@ -34,18 +40,18 @@ const WHITE = 'WHITE' const GRAY = 'GRAY' const BLACK = 'BLACK' -const isCyclicUtil = (nodeId: string, color: Record, adjaList: Record, stack: string[]) => { +const isCyclicUtil = (nodeId: string, color: Record, adjList: Record, stack: string[]) => { 
   color[nodeId] = GRAY
   stack.push(nodeId)

-  for (let i = 0; i < adjaList[nodeId].length; ++i) {
-    const childId = adjaList[nodeId][i]
+  for (let i = 0; i < adjList[nodeId].length; ++i) {
+    const childId = adjList[nodeId][i]

     if (color[childId] === GRAY) {
       stack.push(childId)
       return true
     }
-    if (color[childId] === WHITE && isCyclicUtil(childId, color, adjaList, stack))
+    if (color[childId] === WHITE && isCyclicUtil(childId, color, adjList, stack))
       return true
   }
   color[nodeId] = BLACK
@@ -55,21 +61,21 @@ const isCyclicUtil = (nodeId: string, color: Record<string, string>, adjaList: Record<string, string[]>, stack: string[]) => {
 }

 const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
-  const adjaList: Record<string, string[]> = {}
+  const adjList: Record<string, string[]> = {}
   const color: Record<string, string> = {}
   const stack: string[] = []

   for (const node of nodes) {
     color[node.id] = WHITE
-    adjaList[node.id] = []
+    adjList[node.id] = []
   }

   for (const edge of edges)
-    adjaList[edge.source]?.push(edge.target)
+    adjList[edge.source]?.push(edge.target)

   for (let i = 0; i < nodes.length; i++) {
     if (color[nodes[i].id] === WHITE)
-      isCyclicUtil(nodes[i].id, color, adjaList, stack)
+      isCyclicUtil(nodes[i].id, color, adjList, stack)
   }

   const cycleEdges = []
@@ -84,9 +90,130 @@ const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
   return cycleEdges
 }

+export function getIterationStartNode(iterationId: string): Node {
+  return generateNewNode({
+    id: `${iterationId}start`,
+    type: CUSTOM_ITERATION_START_NODE,
+    data: {
+      title: '',
+      desc: '',
+      type: BlockEnum.IterationStart,
+      isInIteration: true,
+    },
+    position: {
+      x: 24,
+      y: 68,
+    },
+    zIndex: ITERATION_CHILDREN_Z_INDEX,
+    parentId: iterationId,
+    selectable: false,
+    draggable: false,
+  }).newNode
+}
+
+export function generateNewNode({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }): {
+  newNode: Node
+  newIterationStartNode?: Node
+} {
+  const newNode = {
+    id: id || `${Date.now()}`,
+    type: type || CUSTOM_NODE,
+    data,
+    position,
+    targetPosition: Position.Left,
+    sourcePosition: Position.Right,
+    zIndex: data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex,
+    ...rest,
+  } as Node
+
+  if (data.type === BlockEnum.Iteration) {
+    const newIterationStartNode = getIterationStartNode(newNode.id);
+    (newNode.data as IterationNodeType).start_node_id = newIterationStartNode.id;
+    (newNode.data as IterationNodeType)._children = [newIterationStartNode.id]
+    return {
+      newNode,
+      newIterationStartNode,
+    }
+  }
+
+  return {
+    newNode,
+  }
+}
+
+export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
+  const hasIterationNode = nodes.some(node => node.data.type === BlockEnum.Iteration)
+
+  if (!hasIterationNode) {
+    return {
+      nodes,
+      edges,
+    }
+  }
+  const nodesMap = nodes.reduce((prev, next) => {
+    prev[next.id] = next
+    return prev
+  }, {} as Record<string, Node>)
+  const iterationNodesWithStartNode = []
+  const iterationNodesWithoutStartNode = []
+
+  for (let i = 0; i < nodes.length; i++) {
+    const currentNode = nodes[i] as Node<IterationNodeType>
+
+    if (currentNode.data.type === BlockEnum.Iteration) {
+      if (currentNode.data.start_node_id) {
+        if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE)
+          iterationNodesWithStartNode.push(currentNode)
+      }
+      else {
+        iterationNodesWithoutStartNode.push(currentNode)
+      }
+    }
+  }
+  const newIterationStartNodesMap = {} as Record<string, Node>
+  const newIterationStartNodes = [...iterationNodesWithStartNode, ...iterationNodesWithoutStartNode].map((iterationNode, index) => {
+    const newNode = getIterationStartNode(iterationNode.id)
+    newNode.id = newNode.id + index
+    newIterationStartNodesMap[iterationNode.id] = newNode
+    return newNode
+  })
+  const newEdges = iterationNodesWithStartNode.map((iterationNode) => {
+    const newNode = newIterationStartNodesMap[iterationNode.id]
+    const startNode = nodesMap[iterationNode.data.start_node_id]
+    const source = newNode.id
+    const sourceHandle = 'source'
+    const target = startNode.id
+    const targetHandle = 'target'
+    return {
+      id: `${source}-${sourceHandle}-${target}-${targetHandle}`,
+      type: 'custom',
+      source,
+      sourceHandle,
+      target,
+      targetHandle,
+      data: {
+        sourceType: newNode.data.type,
+        targetType: startNode.data.type,
+        isInIteration: true,
+        iteration_id: startNode.parentId,
+        _connectedNodeIsSelected: true,
+      },
+      zIndex: ITERATION_CHILDREN_Z_INDEX,
+    }
+  })
+  nodes.forEach((node) => {
+    if (node.data.type === BlockEnum.Iteration && newIterationStartNodesMap[node.id])
+      (node.data as IterationNodeType).start_node_id = newIterationStartNodesMap[node.id].id
+  })
+
+  return {
+    nodes: [...nodes, ...newIterationStartNodes],
+    edges: [...edges, ...newEdges],
+  }
+}
+
 export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
-  const nodes = cloneDeep(originNodes)
-  const edges = cloneDeep(originEdges)
+  const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))
   const firstNode = nodes[0]

   if (!firstNode?.position) {
@@ -148,8 +275,7 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
 }

 export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
-  const nodes = cloneDeep(originNodes)
-  const edges = cloneDeep(originEdges)
+  const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))
   let selectedNode: Node | null = null
   const nodesMap = nodes.reduce((acc, node) => {
     acc[node.id] = node
@@ -291,19 +417,6 @@ export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSo
   return nodesConnectedSourceOrTargetHandleIdsMap
 }

-export const generateNewNode = ({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }) => {
-  return {
-    id: id || `${Date.now()}`,
-    type: type || CUSTOM_NODE,
-    data,
-    position,
-    targetPosition: Position.Left,
-    sourcePosition: Position.Right,
-    zIndex: data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex,
-    ...rest,
-  } as Node
-}
-
 export const genNewNodeTitleFromOld = (oldTitle: string) => {
   const regex = /^(.+?)\s*\((\d+)\)\s*$/
   const match = oldTitle.match(regex)
@@ -479,3 +592,167 @@ export const variableTransformer = (v: ValueSelector | string) => {

   return `{{#${v.join('.')}#}}`
 }
+
+type ParallelInfoItem = {
+  parallelNodeId: string
+  depth: number
+  isBranch?: boolean
+}
+type NodeParallelInfo = {
+  parallelNodeId: string
+  edgeHandleId: string
+  depth: number
+}
+type NodeHandle = {
+  node: Node
+  handle: string
+}
+type NodeStreamInfo = {
+  upstreamNodes: Set<string>
+  downstreamEdges: Set<string>
+}
+export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => {
+  let startNode
+
+  if (parentNodeId) {
+    const parentNode = nodes.find(node => node.id === parentNodeId)
+    if (!parentNode)
+      throw new Error('Parent node not found')
+
+    startNode = nodes.find(node => node.id === (parentNode.data as IterationNodeType).start_node_id)
+  }
+  else {
+    startNode = nodes.find(node => node.data.type === BlockEnum.Start)
+  }
+  if (!startNode)
+    throw new Error('Start node not found')
+
+  const parallelList = [] as ParallelInfoItem[]
+  const nextNodeHandles = [{ node: startNode, handle: 'source' }]
+  let hasAbnormalEdges = false
+
+  const traverse = (firstNodeHandle: NodeHandle) => {
+    const nodeEdgesSet = {} as Record<string, Set<string>>
+    const totalEdgesSet = new Set<string>()
+    const nextHandles = [firstNodeHandle]
+    const streamInfo = {} as Record<string, NodeStreamInfo>
+    const parallelListItem = {
+      parallelNodeId: '',
+      depth: 0,
+    } as ParallelInfoItem
+    const nodeParallelInfoMap = {} as Record<string, NodeParallelInfo>
+    nodeParallelInfoMap[firstNodeHandle.node.id] = {
+      parallelNodeId: '',
+      edgeHandleId: '',
+      depth: 0,
+    }
+
+    while (nextHandles.length) {
+      const currentNodeHandle = nextHandles.shift()!
+      const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle
+      const currentNodeHandleKey = currentNode.id
+      const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle)
+      const connectedEdgesLength = connectedEdges.length
+      const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id))
+      const incomers = getIncomers(currentNode, nodes, edges)
+
+      if (!streamInfo[currentNodeHandleKey]) {
+        streamInfo[currentNodeHandleKey] = {
+          upstreamNodes: new Set<string>(),
+          downstreamEdges: new Set<string>(),
+        }
+      }
+
+      if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) {
+        const newSet = new Set<string>()
+        for (const item of totalEdgesSet) {
+          if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item))
+            newSet.add(item)
+        }
+        if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) {
+          parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
+          nextNodeHandles.push({ node: currentNode, handle: currentHandle })
+          break
+        }
+      }
+
+      if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth)
+        parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
+
+      outgoers.forEach((outgoer) => {
+        const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
+        const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
+        const incomers = getIncomers(outgoer, nodes, edges)
+
+        if (outgoers.length > 1 && incomers.length > 1)
+          hasAbnormalEdges = true
+
+        Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
+          nextHandles.push({ node: outgoer, handle: sourceHandle })
+        })
+        if (!outgoerConnectedEdges.length)
+          nextHandles.push({ node: outgoer, handle: 'source' })
+
+        const outgoerKey = outgoer.id
+        if (!nodeEdgesSet[outgoerKey])
+          nodeEdgesSet[outgoerKey] = new Set<string>()
+
+        if (nodeEdgesSet[currentNodeHandleKey]) {
+          for (const item of nodeEdgesSet[currentNodeHandleKey])
+            nodeEdgesSet[outgoerKey].add(item)
+        }
+
+        if (!streamInfo[outgoerKey]) {
+          streamInfo[outgoerKey] = {
+            upstreamNodes: new Set<string>(),
+            downstreamEdges: new Set<string>(),
+          }
+        }
+
+        if (!nodeParallelInfoMap[outgoer.id]) {
+          nodeParallelInfoMap[outgoer.id] = {
+            ...nodeParallelInfoMap[currentNode.id],
+          }
+        }
+
+        if (connectedEdgesLength > 1) {
+          const edge = connectedEdges.find(edge => edge.target === outgoer.id)!
+          nodeEdgesSet[outgoerKey].add(edge.id)
+          totalEdgesSet.add(edge.id)
+
+          streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id)
+          streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey)
+
+          for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
+            streamInfo[item].downstreamEdges.add(edge.id)
+
+          if (!parallelListItem.parallelNodeId)
+            parallelListItem.parallelNodeId = currentNode.id
+
+          const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1
+          const currentDepth = nodeParallelInfoMap[outgoer.id].depth
+
+          nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth)
+        }
+        else {
+          for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
+            streamInfo[outgoerKey].upstreamNodes.add(item)
+
+          nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth
+        }
+      })
+    }

+    parallelList.push(parallelListItem)
+  }
+
+  while (nextNodeHandles.length) {
+    const nodeHandle = nextNodeHandles.shift()!
+    traverse(nodeHandle)
+  }
+
+  return {
+    parallelList,
+    hasAbnormalEdges,
+  }
+}
diff --git a/web/app/layout.tsx b/web/app/layout.tsx
index 9acc131029..008fdefa0b 100644
--- a/web/app/layout.tsx
+++ b/web/app/layout.tsx
@@ -1,6 +1,6 @@
 import type { Viewport } from 'next'
 import I18nServer from './components/i18n-server'
-import BrowerInitor from './components/browser-initor'
+import BrowserInitor from './components/browser-initor'
 import SentryInitor from './components/sentry-initor'
 import Topbar from './components/base/topbar'
 import { getLocaleOnServer } from '@/i18n/server'
@@ -45,11 +45,11 @@ const LocaleLayout = ({
       data-public-site-about={process.env.NEXT_PUBLIC_SITE_ABOUT}
     >
-      <BrowerInitor>
+      <BrowserInitor>
         <SentryInitor>
           <Topbar />
           {children}
         </SentryInitor>
-      </BrowerInitor>
+      </BrowserInitor>
     </body>
   </html>
  )
diff --git a/web/app/signin/oneMoreStep.tsx b/web/app/signin/oneMoreStep.tsx
index 42a797b4d6..a4324517a5 100644
--- a/web/app/signin/oneMoreStep.tsx
+++ b/web/app/signin/oneMoreStep.tsx
@@ -97,7 +97,7 @@ const OneMoreStep = () => {
         }
         needsDelay
       >
-        {t('login.donthave')}
+        {t('login.dontHave')}
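The new `getParallelInfo` helper above is easiest to sanity-check with a small fixture. A minimal sketch follows (not part of the patch; the node/edge literals are hypothetical, the loose `as any` casts stand in for the real workflow `Node`/`Edge` types, and it assumes `BlockEnum.Start === 'start'` and the module path of the workflow utils):

import { getParallelInfo } from './utils' // assumed path: web/app/components/workflow/utils

// A start node fanning out to two branches from the same source handle
// forms a single parallel group rooted at 'start'.
const nodes = [
  { id: 'start', data: { type: 'start', title: 'Start' }, position: { x: 0, y: 0 } },
  { id: 'llm-a', data: { type: 'llm', title: 'A' }, position: { x: 200, y: -60 } },
  { id: 'llm-b', data: { type: 'llm', title: 'B' }, position: { x: 200, y: 60 } },
] as any

const edges = [
  { id: 'e1', source: 'start', sourceHandle: 'source', target: 'llm-a', targetHandle: 'target' },
  { id: 'e2', source: 'start', sourceHandle: 'source', target: 'llm-b', targetHandle: 'target' },
] as any

const { parallelList, hasAbnormalEdges } = getParallelInfo(nodes, edges)
// Two edges leaving 'start' from one handle mark it as the parallel node:
// parallelList => [{ parallelNodeId: 'start', depth: 1 }]; hasAbnormalEdges => false,
// since no node here both fans in and fans out.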
diff --git a/web/config/index.ts b/web/config/index.ts
index bb100473b1..71edf8f939 100644
--- a/web/config/index.ts
+++ b/web/config/index.ts
@@ -103,16 +103,16 @@
 export const DEFAULT_PARAGRAPH_VALUE_MAX_LEN = 1000

 export const zhRegex = /^[\u4E00-\u9FA5]$/m
 export const emojiRegex = /^[\uD800-\uDBFF][\uDC00-\uDFFF]$/m
 export const emailRegex = /^[\w.!#$%&'*+\-/=?^{|}~]+@([\w-]+\.)+[\w-]{2,}$/m
-const MAX_ZN_VAR_NAME_LENGHT = 8
-const MAX_EN_VAR_VALUE_LENGHT = 30
+const MAX_ZN_VAR_NAME_LENGTH = 8
+const MAX_EN_VAR_VALUE_LENGTH = 30
 export const getMaxVarNameLength = (value: string) => {
   if (zhRegex.test(value))
-    return MAX_ZN_VAR_NAME_LENGHT
+    return MAX_ZN_VAR_NAME_LENGTH

-  return MAX_EN_VAR_VALUE_LENGHT
+  return MAX_EN_VAR_VALUE_LENGTH
 }

-export const MAX_VAR_KEY_LENGHT = 30
+export const MAX_VAR_KEY_LENGTH = 30
 export const MAX_PROMPT_MESSAGE_LENGTH = 10
diff --git a/web/hooks/use-moderate.ts b/web/hooks/use-moderate.ts
index 0e2908ff66..11b078fbd9 100644
--- a/web/hooks/use-moderate.ts
+++ b/web/hooks/use-moderate.ts
@@ -13,14 +13,14 @@ export const useModerate = (
   content: string,
   stop: boolean,
   moderationService: (text: string) => ReturnType,
-  seperateLength = 50,
+  separateLength = 50,
 ) => {
   const moderatedContentMap = useRef<Map<number, string>>(new Map())
   const moderatingIndex = useRef([])
   const [contentArr, setContentArr] = useState([])

   const handleModerate = () => {
-    const stringArr = splitStringByLength(content, seperateLength)
+    const stringArr = splitStringByLength(content, separateLength)
     const lastIndex = stringArr.length - 1

     stringArr.forEach((item, index) => {
diff --git a/web/i18n/auto-gen-i18n.js b/web/i18n/auto-gen-i18n.js
new file mode 100644
index 0000000000..c51bc53af3
--- /dev/null
+++ b/web/i18n/auto-gen-i18n.js
@@ -0,0 +1,82 @@
+/* eslint-disable no-eval */
+const fs = require('node:fs')
+const path = require('node:path')
+const transpile = require('typescript').transpile
+const magicast = require('magicast')
+const { parseModule, generateCode, loadFile } = magicast
+const bingTranslate = require('bing-translate-api')
+const { translate } = bingTranslate
+const data = require('./languages.json')
+
+const targetLanguage = 'en-US'
+// https://github.com/plainheart/bing-translate-api/blob/master/src/met/lang.json
+const languageKeyMap = data.languages.reduce((map, language) => {
+  if (language.supported) {
+    if (language.value === 'zh-Hans' || language.value === 'zh-Hant')
+      map[language.value] = language.value
+    else
+      map[language.value] = language.value.split('-')[0]
+  }
+
+  return map
+}, {})
+
+async function translateMissingKeyDeeply(sourceObj, targetObject, toLanguage) {
+  await Promise.all(Object.keys(sourceObj).map(async (key) => {
+    if (targetObject[key] === undefined) {
+      if (typeof sourceObj[key] === 'object') {
+        targetObject[key] = {}
+        await translateMissingKeyDeeply(sourceObj[key], targetObject[key], toLanguage)
+      }
+      else {
+        const { translation } = await translate(sourceObj[key], null, languageKeyMap[toLanguage])
+        targetObject[key] = translation
+        // console.log(translation)
+      }
+    }
+    else if (typeof sourceObj[key] === 'object') {
+      targetObject[key] = targetObject[key] || {}
+      await translateMissingKeyDeeply(sourceObj[key], targetObject[key], toLanguage)
+    }
+  }))
+}
+
+async function autoGenTrans(fileName, toGenLanguage) {
+  const fullKeyFilePath = path.join(__dirname, targetLanguage, `${fileName}.ts`)
+  const toGenLanguageFilePath = path.join(__dirname, toGenLanguage, `${fileName}.ts`)
+  const fullKeyContent = eval(transpile(fs.readFileSync(fullKeyFilePath,
'utf8'))) + // To keep object format and format it for magicast to work: const translation = { ... } => export default {...} + const readContent = await loadFile(toGenLanguageFilePath) + const { code: toGenContent } = generateCode(readContent) + const mod = await parseModule(`export default ${toGenContent.replace('export default translation', '').replace('const translation = ', '')}`) + const toGenOutPut = mod.exports.default + + await translateMissingKeyDeeply(fullKeyContent, toGenOutPut, toGenLanguage) + const { code } = generateCode(mod) + const res = `const translation =${code.replace('export default', '')} + +export default translation +`.replace(/,\n\n/g, ',\n').replace('};', '}') + + fs.writeFileSync(toGenLanguageFilePath, res) +} + +async function main() { + // const fileName = 'workflow' + // Promise.all(Object.keys(languageKeyMap).map(async (toLanguage) => { + // await autoGenTrans(fileName, toLanguage) + // })) + + const files = fs + .readdirSync(path.join(__dirname, targetLanguage)) + .map(file => file.replace(/\.ts/, '')) + .filter(f => f !== 'app-debug') // ast parse error in app-debug + + await Promise.all(files.map(async (file) => { + await Promise.all(Object.keys(languageKeyMap).map(async (language) => { + await autoGenTrans(file, language) + })) + })) +} + +main() diff --git a/web/i18n/script.js b/web/i18n/check-i18n.js similarity index 98% rename from web/i18n/script.js rename to web/i18n/check-i18n.js index 5af8f466ff..fbe1bbd2ef 100644 --- a/web/i18n/script.js +++ b/web/i18n/check-i18n.js @@ -27,6 +27,7 @@ async function getKeysFromLanuage(language) { // console.log(camelCaseFileName) const content = fs.readFileSync(filePath, 'utf8') const translation = eval(transpile(content)) + // console.log(translation) const keys = Object.keys(translation) const nestedKeys = [] const iterateKeys = (obj, prefix = '') => { diff --git a/web/i18n/de-DE/app-annotation.ts b/web/i18n/de-DE/app-annotation.ts index 8b96da2752..ef2fa1f236 100644 --- a/web/i18n/de-DE/app-annotation.ts +++ b/web/i18n/de-DE/app-annotation.ts @@ -1,87 +1,87 @@ -const translation = { - title: 'Anmerkungen', - name: 'Antwort Anmerkung', - editBy: 'Antwort bearbeitet von {{author}}', - noData: { - title: 'Keine Anmerkungen', - description: 'Sie können Anmerkungen während des App-Debuggings bearbeiten oder hier Anmerkungen in großen Mengen importieren für eine hochwertige Antwort.', - }, - table: { - header: { - question: 'Frage', - answer: 'Antwort', - createdAt: 'erstellt am', - hits: 'Treffer', - actions: 'Aktionen', - addAnnotation: 'Anmerkung hinzufügen', - bulkImport: 'Massenimport', - bulkExport: 'Massenexport', - clearAll: 'Alle Anmerkungen löschen', - }, - }, - editModal: { - title: 'Antwort Anmerkung bearbeiten', - queryName: 'Benutzeranfrage', - answerName: 'Geschichtenerzähler Bot', - yourAnswer: 'Ihre Antwort', - answerPlaceholder: 'Geben Sie hier Ihre Antwort ein', - yourQuery: 'Ihre Anfrage', - queryPlaceholder: 'Geben Sie hier Ihre Anfrage ein', - removeThisCache: 'Diese Anmerkung entfernen', - createdAt: 'Erstellt am', - }, - addModal: { - title: 'Antwort Anmerkung hinzufügen', - queryName: 'Frage', - answerName: 'Antwort', - answerPlaceholder: 'Antwort hier eingeben', - queryPlaceholder: 'Anfrage hier eingeben', - createNext: 'Eine weitere annotierte Antwort hinzufügen', - }, - batchModal: { - title: 'Massenimport', - csvUploadTitle: 'Ziehen Sie Ihre CSV-Datei hierher oder ', - browse: 'durchsuchen', - tip: 'Die CSV-Datei muss der folgenden Struktur entsprechen:', - question: 'Frage', - answer: 
'Antwort', - contentTitle: 'Inhaltsabschnitt', - content: 'Inhalt', - template: 'Laden Sie die Vorlage hier herunter', - cancel: 'Abbrechen', - run: 'Batch ausführen', - runError: 'Batch-Ausführung fehlgeschlagen', - processing: 'In Batch-Verarbeitung', - completed: 'Import abgeschlossen', - error: 'Importfehler', - ok: 'OK', - }, - errorMessage: { - answerRequired: 'Antwort erforderlich', - queryRequired: 'Frage erforderlich', - }, - viewModal: { - annotatedResponse: 'Antwort Anmerkung', - hitHistory: 'Trefferhistorie', - hit: 'Treffer', - hits: 'Treffer', - noHitHistory: 'Keine Trefferhistorie', - }, - hitHistoryTable: { - query: 'Anfrage', - match: 'Übereinstimmung', - response: 'Antwort', - source: 'Quelle', - score: 'Punktzahl', - time: 'Zeit', - }, - initSetup: { - title: 'Initialeinrichtung Antwort Anmerkung', - configTitle: 'Einrichtung Antwort Anmerkung', - confirmBtn: 'Speichern & Aktivieren', - configConfirmBtn: 'Speichern', - }, - embeddingModelSwitchTip: 'Anmerkungstext-Vektorisierungsmodell, das Wechseln von Modellen wird neu eingebettet, was zusätzliche Kosten verursacht.', -} - -export default translation +const translation = { + title: 'Anmerkungen', + name: 'Antwort Anmerkung', + editBy: 'Antwort bearbeitet von {{author}}', + noData: { + title: 'Keine Anmerkungen', + description: 'Sie können Anmerkungen während des App-Debuggings bearbeiten oder hier Anmerkungen in großen Mengen importieren für eine hochwertige Antwort.', + }, + table: { + header: { + question: 'Frage', + answer: 'Antwort', + createdAt: 'erstellt am', + hits: 'Treffer', + actions: 'Aktionen', + addAnnotation: 'Anmerkung hinzufügen', + bulkImport: 'Massenimport', + bulkExport: 'Massenexport', + clearAll: 'Alle Anmerkungen löschen', + }, + }, + editModal: { + title: 'Antwort Anmerkung bearbeiten', + queryName: 'Benutzeranfrage', + answerName: 'Geschichtenerzähler Bot', + yourAnswer: 'Ihre Antwort', + answerPlaceholder: 'Geben Sie hier Ihre Antwort ein', + yourQuery: 'Ihre Anfrage', + queryPlaceholder: 'Geben Sie hier Ihre Anfrage ein', + removeThisCache: 'Diese Anmerkung entfernen', + createdAt: 'Erstellt am', + }, + addModal: { + title: 'Antwort Anmerkung hinzufügen', + queryName: 'Frage', + answerName: 'Antwort', + answerPlaceholder: 'Antwort hier eingeben', + queryPlaceholder: 'Anfrage hier eingeben', + createNext: 'Eine weitere annotierte Antwort hinzufügen', + }, + batchModal: { + title: 'Massenimport', + csvUploadTitle: 'Ziehen Sie Ihre CSV-Datei hierher oder ', + browse: 'durchsuchen', + tip: 'Die CSV-Datei muss der folgenden Struktur entsprechen:', + question: 'Frage', + answer: 'Antwort', + contentTitle: 'Inhaltsabschnitt', + content: 'Inhalt', + template: 'Laden Sie die Vorlage hier herunter', + cancel: 'Abbrechen', + run: 'Batch ausführen', + runError: 'Batch-Ausführung fehlgeschlagen', + processing: 'In Batch-Verarbeitung', + completed: 'Import abgeschlossen', + error: 'Importfehler', + ok: 'OK', + }, + errorMessage: { + answerRequired: 'Antwort erforderlich', + queryRequired: 'Frage erforderlich', + }, + viewModal: { + annotatedResponse: 'Antwort Anmerkung', + hitHistory: 'Trefferhistorie', + hit: 'Treffer', + hits: 'Treffer', + noHitHistory: 'Keine Trefferhistorie', + }, + hitHistoryTable: { + query: 'Anfrage', + match: 'Übereinstimmung', + response: 'Antwort', + source: 'Quelle', + score: 'Punktzahl', + time: 'Zeit', + }, + initSetup: { + title: 'Initialeinrichtung Antwort Anmerkung', + configTitle: 'Einrichtung Antwort Anmerkung', + confirmBtn: 'Speichern & Aktivieren', + configConfirmBtn: 
'Speichern', + }, + embeddingModelSwitchTip: 'Anmerkungstext-Vektorisierungsmodell, das Wechseln von Modellen wird neu eingebettet, was zusätzliche Kosten verursacht.', +} + +export default translation diff --git a/web/i18n/de-DE/app-api.ts b/web/i18n/de-DE/app-api.ts index 91f4589963..84e0989095 100644 --- a/web/i18n/de-DE/app-api.ts +++ b/web/i18n/de-DE/app-api.ts @@ -1,82 +1,83 @@ -const translation = { - apiServer: 'API Server', - apiKey: 'API Schlüssel', - status: 'Status', - disabled: 'Deaktiviert', - ok: 'In Betrieb', - copy: 'Kopieren', - copied: 'Kopiert', - play: 'Abspielen', - pause: 'Pause', - playing: 'Wiedergabe', - merMaind: { - rerender: 'Neu rendern', - }, - never: 'Nie', - apiKeyModal: { - apiSecretKey: 'API Geheimschlüssel', - apiSecretKeyTips: 'Um Missbrauch der API zu verhindern, schützen Sie Ihren API Schlüssel. Vermeiden Sie es, ihn als Klartext im Frontend-Code zu verwenden. :)', - createNewSecretKey: 'Neuen Geheimschlüssel erstellen', - secretKey: 'Geheimschlüssel', - created: 'ERSTELLT', - lastUsed: 'ZULETZT VERWENDET', - generateTips: 'Bewahren Sie diesen Schlüssel an einem sicheren und zugänglichen Ort auf.', - }, - actionMsg: { - deleteConfirmTitle: 'Diesen Geheimschlüssel löschen?', - deleteConfirmTips: 'Diese Aktion kann nicht rückgängig gemacht werden.', - ok: 'OK', - }, - completionMode: { - title: 'Completion App API', - info: 'Für die Erzeugung von hochwertigem Text, wie z.B. Artikel, Zusammenfassungen und Übersetzungen, verwenden Sie die Completion-Messages API mit Benutzereingaben. Die Texterzeugung basiert auf den Modellparametern und Vorlagen für Aufforderungen in Dify Prompt Engineering.', - createCompletionApi: 'Completion Nachricht erstellen', - createCompletionApiTip: 'Erstellen Sie eine Completion Nachricht, um den Frage-Antwort-Modus zu unterstützen.', - inputsTips: '(Optional) Geben Sie Benutzereingabefelder als Schlüssel-Wert-Paare an, die Variablen in Prompt Eng. entsprechen. Schlüssel ist der Variablenname, Wert ist der Parameterwert. Wenn der Feldtyp Select ist, muss der übermittelte Wert eine der voreingestellten Optionen sein.', - queryTips: 'Textinhalt der Benutzereingabe.', - blocking: 'Blockierender Typ, wartet auf die Fertigstellung der Ausführung und gibt Ergebnisse zurück. (Anfragen können unterbrochen werden, wenn der Prozess lang ist)', - streaming: 'Streaming Rückgaben. Implementierung der Streaming-Rückgabe basierend auf SSE (Server-Sent Events).', - messageFeedbackApi: 'Nachrichtenfeedback (Like)', - messageFeedbackApiTip: 'Bewerten Sie empfangene Nachrichten im Namen der Endbenutzer mit Likes oder Dislikes. Diese Daten sind auf der Seite Logs & Annotations sichtbar und werden für zukünftige Modellanpassungen verwendet.', - messageIDTip: 'Nachrichten-ID', - ratingTip: 'like oder dislike, null ist rückgängig machen', - parametersApi: 'Anwendungsparameterinformationen abrufen', - parametersApiTip: 'Abrufen konfigurierter Eingabeparameter, einschließlich Variablennamen, Feldnamen, Typen und Standardwerten. Typischerweise verwendet, um diese Felder in einem Formular anzuzeigen oder Standardwerte nach dem Laden des Clients auszufüllen.', - }, - chatMode: { - title: 'Chat App API', - info: 'Für vielseitige Gesprächsanwendungen im Q&A-Format rufen Sie die chat-messages API auf, um einen Dialog zu initiieren. Führen Sie laufende Gespräche fort, indem Sie die zurückgegebene conversation_id übergeben. Antwortparameter und -vorlagen hängen von den Einstellungen in Dify Prompt Eng. 
ab.', - createChatApi: 'Chatnachricht erstellen', - createChatApiTip: 'Eine neue Konversationsnachricht erstellen oder einen bestehenden Dialog fortsetzen.', - inputsTips: '(Optional) Geben Sie Benutzereingabefelder als Schlüssel-Wert-Paare an, die Variablen in Prompt Eng. entsprechen. Schlüssel ist der Variablenname, Wert ist der Parameterwert. Wenn der Feldtyp Select ist, muss der übermittelte Wert eine der voreingestellten Optionen sein.', - queryTips: 'Inhalt der Benutzereingabe/Frage', - blocking: 'Blockierender Typ, wartet auf die Fertigstellung der Ausführung und gibt Ergebnisse zurück. (Anfragen können unterbrochen werden, wenn der Prozess lang ist)', - streaming: 'Streaming Rückgaben. Implementierung der Streaming-Rückgabe basierend auf SSE (Server-Sent Events).', - conversationIdTip: '(Optional) Konversations-ID: für erstmalige Konversation leer lassen; conversation_id aus dem Kontext übergeben, um den Dialog fortzusetzen.', - messageFeedbackApi: 'Nachrichtenfeedback des Endbenutzers, like', - messageFeedbackApiTip: 'Bewerten Sie empfangene Nachrichten im Namen der Endbenutzer mit Likes oder Dislikes. Diese Daten sind auf der Seite Logs & Annotations sichtbar und werden für zukünftige Modellanpassungen verwendet.', - messageIDTip: 'Nachrichten-ID', - ratingTip: 'like oder dislike, null ist rückgängig machen', - chatMsgHistoryApi: 'Chatverlaufsnachricht abrufen', - chatMsgHistoryApiTip: 'Die erste Seite gibt die neuesten `limit` Einträge in umgekehrter Reihenfolge zurück.', - chatMsgHistoryConversationIdTip: 'Konversations-ID', - chatMsgHistoryFirstId: 'ID des ersten Chat-Datensatzes auf der aktuellen Seite. Standardmäßig keiner.', - chatMsgHistoryLimit: 'Wie viele Chats in einer Anfrage zurückgegeben werden', - conversationsListApi: 'Konversationsliste abrufen', - conversationsListApiTip: 'Ruft die Sitzungsliste des aktuellen Benutzers ab. Standardmäßig werden die letzten 20 Sitzungen zurückgegeben.', - conversationsListFirstIdTip: 'Die ID des letzten Datensatzes auf der aktuellen Seite, standardmäßig keine.', - conversationsListLimitTip: 'Wie viele Chats in einer Anfrage zurückgegeben werden', - conversationRenamingApi: 'Konversation umbenennen', - conversationRenamingApiTip: 'Konversationen umbenennen; der Name wird in Mehrsitzungs-Client-Schnittstellen angezeigt.', - conversationRenamingNameTip: 'Neuer Name', - parametersApi: 'Anwendungsparameterinformationen abrufen', - parametersApiTip: 'Abrufen konfigurierter Eingabeparameter, einschließlich Variablennamen, Feldnamen, Typen und Standardwerten. Typischerweise verwendet, um diese Felder in einem Formular anzuzeigen oder Standardwerte nach dem Laden des Clients auszufüllen.', - }, - develop: { - requestBody: 'Anfragekörper', - pathParams: 'Pfadparameter', - query: 'Anfrage', - }, -} - +const translation = { + apiServer: 'API Server', + apiKey: 'API Schlüssel', + status: 'Status', + disabled: 'Deaktiviert', + ok: 'In Betrieb', + copy: 'Kopieren', + copied: 'Kopiert', + play: 'Abspielen', + pause: 'Pause', + playing: 'Wiedergabe', + merMaid: { + rerender: 'Neu rendern', + }, + never: 'Nie', + apiKeyModal: { + apiSecretKey: 'API Geheimschlüssel', + apiSecretKeyTips: 'Um Missbrauch der API zu verhindern, schützen Sie Ihren API Schlüssel. Vermeiden Sie es, ihn als Klartext im Frontend-Code zu verwenden. 
:)', + createNewSecretKey: 'Neuen Geheimschlüssel erstellen', + secretKey: 'Geheimschlüssel', + created: 'ERSTELLT', + lastUsed: 'ZULETZT VERWENDET', + generateTips: 'Bewahren Sie diesen Schlüssel an einem sicheren und zugänglichen Ort auf.', + }, + actionMsg: { + deleteConfirmTitle: 'Diesen Geheimschlüssel löschen?', + deleteConfirmTips: 'Diese Aktion kann nicht rückgängig gemacht werden.', + ok: 'OK', + }, + completionMode: { + title: 'Completion App API', + info: 'Für die Erzeugung von hochwertigem Text, wie z.B. Artikel, Zusammenfassungen und Übersetzungen, verwenden Sie die Completion-Messages API mit Benutzereingaben. Die Texterzeugung basiert auf den Modellparametern und Vorlagen für Aufforderungen in Dify Prompt Engineering.', + createCompletionApi: 'Completion Nachricht erstellen', + createCompletionApiTip: 'Erstellen Sie eine Completion Nachricht, um den Frage-Antwort-Modus zu unterstützen.', + inputsTips: '(Optional) Geben Sie Benutzereingabefelder als Schlüssel-Wert-Paare an, die Variablen in Prompt Eng. entsprechen. Schlüssel ist der Variablenname, Wert ist der Parameterwert. Wenn der Feldtyp Select ist, muss der übermittelte Wert eine der voreingestellten Optionen sein.', + queryTips: 'Textinhalt der Benutzereingabe.', + blocking: 'Blockierender Typ, wartet auf die Fertigstellung der Ausführung und gibt Ergebnisse zurück. (Anfragen können unterbrochen werden, wenn der Prozess lang ist)', + streaming: 'Streaming Rückgaben. Implementierung der Streaming-Rückgabe basierend auf SSE (Server-Sent Events).', + messageFeedbackApi: 'Nachrichtenfeedback (Like)', + messageFeedbackApiTip: 'Bewerten Sie empfangene Nachrichten im Namen der Endbenutzer mit Likes oder Dislikes. Diese Daten sind auf der Seite Logs & Annotations sichtbar und werden für zukünftige Modellanpassungen verwendet.', + messageIDTip: 'Nachrichten-ID', + ratingTip: 'like oder dislike, null ist rückgängig machen', + parametersApi: 'Anwendungsparameterinformationen abrufen', + parametersApiTip: 'Abrufen konfigurierter Eingabeparameter, einschließlich Variablennamen, Feldnamen, Typen und Standardwerten. Typischerweise verwendet, um diese Felder in einem Formular anzuzeigen oder Standardwerte nach dem Laden des Clients auszufüllen.', + }, + chatMode: { + title: 'Chat App API', + info: 'Für vielseitige Gesprächsanwendungen im Q&A-Format rufen Sie die chat-messages API auf, um einen Dialog zu initiieren. Führen Sie laufende Gespräche fort, indem Sie die zurückgegebene conversation_id übergeben. Antwortparameter und -vorlagen hängen von den Einstellungen in Dify Prompt Eng. ab.', + createChatApi: 'Chatnachricht erstellen', + createChatApiTip: 'Eine neue Konversationsnachricht erstellen oder einen bestehenden Dialog fortsetzen.', + inputsTips: '(Optional) Geben Sie Benutzereingabefelder als Schlüssel-Wert-Paare an, die Variablen in Prompt Eng. entsprechen. Schlüssel ist der Variablenname, Wert ist der Parameterwert. Wenn der Feldtyp Select ist, muss der übermittelte Wert eine der voreingestellten Optionen sein.', + queryTips: 'Inhalt der Benutzereingabe/Frage', + blocking: 'Blockierender Typ, wartet auf die Fertigstellung der Ausführung und gibt Ergebnisse zurück. (Anfragen können unterbrochen werden, wenn der Prozess lang ist)', + streaming: 'Streaming Rückgaben. 
Implementierung der Streaming-Rückgabe basierend auf SSE (Server-Sent Events).', + conversationIdTip: '(Optional) Konversations-ID: für erstmalige Konversation leer lassen; conversation_id aus dem Kontext übergeben, um den Dialog fortzusetzen.', + messageFeedbackApi: 'Nachrichtenfeedback des Endbenutzers, like', + messageFeedbackApiTip: 'Bewerten Sie empfangene Nachrichten im Namen der Endbenutzer mit Likes oder Dislikes. Diese Daten sind auf der Seite Logs & Annotations sichtbar und werden für zukünftige Modellanpassungen verwendet.', + messageIDTip: 'Nachrichten-ID', + ratingTip: 'like oder dislike, null ist rückgängig machen', + chatMsgHistoryApi: 'Chatverlaufsnachricht abrufen', + chatMsgHistoryApiTip: 'Die erste Seite gibt die neuesten `limit` Einträge in umgekehrter Reihenfolge zurück.', + chatMsgHistoryConversationIdTip: 'Konversations-ID', + chatMsgHistoryFirstId: 'ID des ersten Chat-Datensatzes auf der aktuellen Seite. Standardmäßig keiner.', + chatMsgHistoryLimit: 'Wie viele Chats in einer Anfrage zurückgegeben werden', + conversationsListApi: 'Konversationsliste abrufen', + conversationsListApiTip: 'Ruft die Sitzungsliste des aktuellen Benutzers ab. Standardmäßig werden die letzten 20 Sitzungen zurückgegeben.', + conversationsListFirstIdTip: 'Die ID des letzten Datensatzes auf der aktuellen Seite, standardmäßig keine.', + conversationsListLimitTip: 'Wie viele Chats in einer Anfrage zurückgegeben werden', + conversationRenamingApi: 'Konversation umbenennen', + conversationRenamingApiTip: 'Konversationen umbenennen; der Name wird in Mehrsitzungs-Client-Schnittstellen angezeigt.', + conversationRenamingNameTip: 'Neuer Name', + parametersApi: 'Anwendungsparameterinformationen abrufen', + parametersApiTip: 'Abrufen konfigurierter Eingabeparameter, einschließlich Variablennamen, Feldnamen, Typen und Standardwerten. Typischerweise verwendet, um diese Felder in einem Formular anzuzeigen oder Standardwerte nach dem Laden des Clients auszufüllen.', + }, + develop: { + requestBody: 'Anfragekörper', + pathParams: 'Pfadparameter', + query: 'Anfrage', + }, + loading: 'Laden', +} + export default translation diff --git a/web/i18n/de-DE/app-debug.ts b/web/i18n/de-DE/app-debug.ts index acb3f53904..c00c799519 100644 --- a/web/i18n/de-DE/app-debug.ts +++ b/web/i18n/de-DE/app-debug.ts @@ -248,7 +248,7 @@ const translation = { historyNoBeEmpty: 'Konversationsverlauf muss im Prompt gesetzt sein', queryNoBeEmpty: 'Anfrage muss im Prompt gesetzt sein', }, - variableConig: { + variableConfig: { modalTitle: 'Feldeinstellungen', description: 'Einstellung für Variable {{varName}}', fieldType: 'Feldtyp', diff --git a/web/i18n/de-DE/app-log.ts b/web/i18n/de-DE/app-log.ts index 667a23e8aa..0a0e740578 100644 --- a/web/i18n/de-DE/app-log.ts +++ b/web/i18n/de-DE/app-log.ts @@ -1,89 +1,95 @@ -const translation = { - title: 'Protokolle', - description: 'Die Protokolle zeichnen den Betriebsstatus der Anwendung auf, einschließlich Benutzereingaben und KI-Antworten.', - dateTimeFormat: 'MM/DD/YYYY hh:mm A', - table: { - header: { - updatedTime: 'Aktualisierungszeit', - time: 'Erstellungszeit', - endUser: 'Endbenutzer oder Konto', - input: 'Eingabe', - output: 'Ausgabe', - summary: 'Titel', - messageCount: 'Nachrichtenzahl', - userRate: 'Benutzerbewertung', - adminRate: 'Op. 
Bewertung', - }, - pagination: { - previous: 'Vorherige', - next: 'Nächste', - }, - empty: { - noChat: 'Noch keine Konversation', - noOutput: 'Keine Ausgabe', - element: { - title: 'Ist da jemand?', - content: 'Beobachten und annotieren Sie hier die Interaktionen zwischen Endbenutzern und KI-Anwendungen, um die Genauigkeit der KI kontinuierlich zu verbessern. Sie können versuchen, die Web-App selbst zu teilen oder zu testen, und dann zu dieser Seite zurückkehren.', - }, - }, - }, - detail: { - time: 'Zeit', - conversationId: 'Konversations-ID', - promptTemplate: 'Prompt-Vorlage', - promptTemplateBeforeChat: 'Prompt-Vorlage vor dem Chat · Als Systemnachricht', - annotationTip: 'Verbesserungen markiert von {{user}}', - timeConsuming: '', - second: 's', - tokenCost: 'Verbrauchte Token', - loading: 'lädt', - operation: { - like: 'gefällt mir', - dislike: 'gefällt mir nicht', - addAnnotation: 'Verbesserung hinzufügen', - editAnnotation: 'Verbesserung bearbeiten', - annotationPlaceholder: 'Geben Sie die erwartete Antwort ein, die Sie möchten, dass die KI antwortet, welche für die Feinabstimmung des Modells und die kontinuierliche Verbesserung der Qualität der Textgenerierung in Zukunft verwendet werden kann.', - }, - variables: 'Variablen', - uploadImages: 'Hochgeladene Bilder', - }, - filter: { - period: { - today: 'Heute', - last7days: 'Letzte 7 Tage', - last4weeks: 'Letzte 4 Wochen', - last3months: 'Letzte 3 Monate', - last12months: 'Letzte 12 Monate', - monthToDate: 'Monat bis heute', - quarterToDate: 'Quartal bis heute', - yearToDate: 'Jahr bis heute', - allTime: 'Gesamte Zeit', - }, - annotation: { - all: 'Alle', - annotated: 'Markierte Verbesserungen ({{count}} Elemente)', - not_annotated: 'Nicht annotiert', - }, - sortBy: 'Sortieren nach:', - descending: 'absteigend', - ascending: 'aufsteigend', - }, - workflowTitle: 'Workflow-Protokolle', - workflowSubtitle: 'Das Protokoll hat den Vorgang von Automate aufgezeichnet.', - runDetail: { - title: 'Konversationsprotokoll', - workflowTitle: 'Protokolldetail', - }, - promptLog: 'Prompt-Protokoll', - agentLog: 'Agentenprotokoll', - viewLog: 'Protokoll anzeigen', - agentLogDetail: { - agentMode: 'Agentenmodus', - toolUsed: 'Verwendetes Werkzeug', - iterations: 'Iterationen', - iteration: 'Iteration', - finalProcessing: 'Endverarbeitung', - }, -} - -export default translation +const translation = { + title: 'Protokolle', + description: 'Die Protokolle zeichnen den Betriebsstatus der Anwendung auf, einschließlich Benutzereingaben und KI-Antworten.', + dateTimeFormat: 'MM/DD/YYYY hh:mm A', + table: { + header: { + updatedTime: 'Aktualisierungszeit', + time: 'Erstellungszeit', + endUser: 'Endbenutzer oder Konto', + input: 'Eingabe', + output: 'Ausgabe', + summary: 'Titel', + messageCount: 'Nachrichtenzahl', + userRate: 'Benutzerbewertung', + adminRate: 'Op. Bewertung', + user: 'Endbenutzer oder Konto', + status: 'STATUS', + runtime: 'LAUFZEIT', + version: 'VERSION', + tokens: 'TOKEN', + startTime: 'STARTZEIT', + }, + pagination: { + previous: 'Vorherige', + next: 'Nächste', + }, + empty: { + noChat: 'Noch keine Konversation', + noOutput: 'Keine Ausgabe', + element: { + title: 'Ist da jemand?', + content: 'Beobachten und annotieren Sie hier die Interaktionen zwischen Endbenutzern und KI-Anwendungen, um die Genauigkeit der KI kontinuierlich zu verbessern. 
Sie können versuchen, die Web-App selbst zu teilen oder zu testen, und dann zu dieser Seite zurückkehren.', + }, + }, + }, + detail: { + time: 'Zeit', + conversationId: 'Konversations-ID', + promptTemplate: 'Prompt-Vorlage', + promptTemplateBeforeChat: 'Prompt-Vorlage vor dem Chat · Als Systemnachricht', + annotationTip: 'Verbesserungen markiert von {{user}}', + timeConsuming: '', + second: 's', + tokenCost: 'Verbrauchte Token', + loading: 'lädt', + operation: { + like: 'gefällt mir', + dislike: 'gefällt mir nicht', + addAnnotation: 'Verbesserung hinzufügen', + editAnnotation: 'Verbesserung bearbeiten', + annotationPlaceholder: 'Geben Sie die erwartete Antwort ein, die Sie möchten, dass die KI antwortet, welche für die Feinabstimmung des Modells und die kontinuierliche Verbesserung der Qualität der Textgenerierung in Zukunft verwendet werden kann.', + }, + variables: 'Variablen', + uploadImages: 'Hochgeladene Bilder', + }, + filter: { + period: { + today: 'Heute', + last7days: 'Letzte 7 Tage', + last4weeks: 'Letzte 4 Wochen', + last3months: 'Letzte 3 Monate', + last12months: 'Letzte 12 Monate', + monthToDate: 'Monat bis heute', + quarterToDate: 'Quartal bis heute', + yearToDate: 'Jahr bis heute', + allTime: 'Gesamte Zeit', + }, + annotation: { + all: 'Alle', + annotated: 'Markierte Verbesserungen ({{count}} Elemente)', + not_annotated: 'Nicht annotiert', + }, + sortBy: 'Sortieren nach:', + descending: 'absteigend', + ascending: 'aufsteigend', + }, + workflowTitle: 'Workflow-Protokolle', + workflowSubtitle: 'Das Protokoll hat den Vorgang von Automate aufgezeichnet.', + runDetail: { + title: 'Konversationsprotokoll', + workflowTitle: 'Protokolldetail', + }, + promptLog: 'Prompt-Protokoll', + agentLog: 'Agentenprotokoll', + viewLog: 'Protokoll anzeigen', + agentLogDetail: { + agentMode: 'Agentenmodus', + toolUsed: 'Verwendetes Werkzeug', + iterations: 'Iterationen', + iteration: 'Iteration', + finalProcessing: 'Endverarbeitung', + }, +} + +export default translation diff --git a/web/i18n/de-DE/app-overview.ts b/web/i18n/de-DE/app-overview.ts index 99100cf868..a44baec82e 100644 --- a/web/i18n/de-DE/app-overview.ts +++ b/web/i18n/de-DE/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Workflow-Schritte', show: 'Anzeigen', hide: 'Verbergen', + subTitle: 'Details zum Arbeitsablauf', + showDesc: 'Ein- oder Ausblenden von Workflow-Details in der WebApp', }, chatColorTheme: 'Chat-Farbschema', chatColorThemeDesc: 'Legen Sie das Farbschema des Chatbots fest', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Geben Sie den benutzerdefinierten Haftungsausschluss-Text ein', customDisclaimerTip: 'Der ben userdefinierte Haftungsausschluss-Text wird auf der Clientseite angezeigt und bietet zusätzliche Informationen über die Anwendung', }, + sso: { + title: 'WebApp-SSO', + description: 'Alle Benutzer müssen sich mit SSO anmelden, bevor sie WebApp verwenden können', + label: 'SSO-Authentifizierung', + tooltip: 'Wenden Sie sich an den Administrator, um WebApp-SSO zu aktivieren', + }, }, embedded: { entry: 'Eingebettet', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Gesamtnachrichten', - explanation: 'Tägliche AI-Interaktionszählung; Prompt-Engineering/Debugging ausgenommen.', + explanation: 'Tägliche Anzahl der KI-Interaktionen.', + }, + totalConversations: { + title: 'Gesamte Konversationen', + explanation: 'Tägliche Anzahl der KI-Konversationen; Prompt-Engineering/Debugging ausgeschlossen.', }, activeUsers: { title: 'Aktive 
Benutzer', @@ -146,6 +158,10 @@ const translation = { title: 'Token-Ausgabegeschwindigkeit', explanation: 'Misst die Leistung des LLM. Zählt die Token-Ausgabegeschwindigkeit des LLM vom Beginn der Anfrage bis zum Abschluss der Ausgabe.', }, + avgUserInteractions: { + explanation: 'Spiegelt die tägliche Nutzungshäufigkeit der Benutzer wider. Diese Metrik spiegelt die Bindung der Benutzer wider.', + title: 'Durchschnittliche Benutzerinteraktionen', + }, }, } diff --git a/web/i18n/de-DE/app.ts b/web/i18n/de-DE/app.ts index 55caff2284..4fc61f8ee0 100644 --- a/web/i18n/de-DE/app.ts +++ b/web/i18n/de-DE/app.ts @@ -1,107 +1,145 @@ -const translation = { - createApp: 'Neue App erstellen', - types: { - all: 'Alle', - assistant: 'Assistent', - completion: 'Vervollständigung', - }, - modes: { - completion: 'Textgenerator', - chat: 'Basisassistent', - }, - createFromConfigFile: 'App aus Konfigurationsdatei erstellen', - deleteAppConfirmTitle: 'Diese App löschen?', - deleteAppConfirmContent: - 'Das Löschen der App ist unwiderruflich. Nutzer werden keinen Zugang mehr zu Ihrer App haben, und alle Prompt-Konfigurationen und Logs werden dauerhaft gelöscht.', - appDeleted: 'App gelöscht', - appDeleteFailed: 'Löschen der App fehlgeschlagen', - join: 'Treten Sie der Gemeinschaft bei', - communityIntro: - 'Diskutieren Sie mit Teammitgliedern, Mitwirkenden und Entwicklern auf verschiedenen Kanälen.', - roadmap: 'Sehen Sie unseren Fahrplan', - appNamePlaceholder: 'Bitte geben Sie den Namen der App ein', - newApp: { - startToCreate: 'Lassen Sie uns mit Ihrer neuen App beginnen', - captionName: 'App-Symbol & Name', - captionAppType: 'Welchen Typ von App möchten Sie erstellen?', - previewDemo: 'Vorschau-Demo', - chatApp: 'Assistent', - chatAppIntro: - 'Ich möchte eine Chat-basierte Anwendung bauen. Diese App verwendet ein Frage-Antwort-Format und ermöglicht mehrere Runden kontinuierlicher Konversation.', - agentAssistant: 'Neuer Agentenassistent', - completeApp: 'Textgenerator', - completeAppIntro: - 'Ich möchte eine Anwendung erstellen, die hochwertigen Text basierend auf Aufforderungen generiert, wie z.B. das Erstellen von Artikeln, Zusammenfassungen, Übersetzungen und mehr.', - showTemplates: 'Ich möchte aus einer Vorlage wählen', - hideTemplates: 'Zurück zur Modusauswahl', - Create: 'Erstellen', - Cancel: 'Abbrechen', - nameNotEmpty: 'Name darf nicht leer sein', - appTemplateNotSelected: 'Bitte wählen Sie eine Vorlage', - appTypeRequired: 'Bitte wählen Sie einen App-Typ', - appCreated: 'App erstellt', - appCreateFailed: 'Erstellen der App fehlgeschlagen', - }, - editApp: 'App bearbeiten', - editAppTitle: 'App-Informationen bearbeiten', - editDone: 'App-Informationen wurden aktualisiert', - editFailed: 'Aktualisierung der App-Informationen fehlgeschlagen', - iconPicker: { - ok: 'OK', - cancel: 'Abbrechen', - emoji: 'Emoji', - image: 'Bild', - }, - switch: 'Zu Workflow-Orchestrierung wechseln', - switchTipStart: 'Eine neue App-Kopie wird für Sie erstellt, und die neue Kopie wird zur Workflow-Orchestrierung wechseln. 
Die neue Kopie wird ', - switchTip: 'nicht erlauben', - switchTipEnd: ' zur Basis-Orchestrierung zurückzuwechseln.', - switchLabel: 'Die zu erstellende App-Kopie', - removeOriginal: 'Ursprüngliche App löschen', - switchStart: 'Wechsel starten', - typeSelector: { - all: 'ALLE Typen', - chatbot: 'Chatbot', - agent: 'Agent', - workflow: 'Workflow', - completion: 'Vervollständigung', - }, - tracing: { - title: 'Anwendungsleistung nachverfolgen', - description: 'Konfiguration eines Drittanbieter-LLMOps-Anbieters und Nachverfolgung der Anwendungsleistung.', - config: 'Konfigurieren', - collapse: 'Einklappen', - expand: 'Ausklappen', - tracing: 'Nachverfolgung', - disabled: 'Deaktiviert', - disabledTip: 'Bitte zuerst den Anbieter konfigurieren', - enabled: 'In Betrieb', - tracingDescription: 'Erfassung des vollständigen Kontexts der Anwendungsausführung, einschließlich LLM-Aufrufe, Kontext, Prompts, HTTP-Anfragen und mehr, auf einer Nachverfolgungsplattform von Drittanbietern.', - configProviderTitle: { - configured: 'Konfiguriert', - notConfigured: 'Anbieter konfigurieren, um Nachverfolgung zu aktivieren', - moreProvider: 'Weitere Anbieter', - }, - langsmith: { - title: 'LangSmith', - description: 'Eine All-in-One-Entwicklerplattform für jeden Schritt des LLM-gesteuerten Anwendungslebenszyklus.', - }, - langfuse: { - title: 'Langfuse', - description: 'Traces, Bewertungen, Prompt-Management und Metriken zum Debuggen und Verbessern Ihrer LLM-Anwendung.', - }, - inUse: 'In Verwendung', - configProvider: { - title: 'Konfigurieren ', - placeholder: 'Geben Sie Ihren {{key}} ein', - project: 'Projekt', - publicKey: 'Öffentlicher Schlüssel', - secretKey: 'Geheimer Schlüssel', - viewDocsLink: '{{key}}-Dokumentation ansehen', - removeConfirmTitle: '{{key}}-Konfiguration entfernen?', - removeConfirmContent: 'Die aktuelle Konfiguration wird verwendet. Das Entfernen wird die Nachverfolgungsfunktion ausschalten.', - }, - }, -} - -export default translation +const translation = { + createApp: 'Neue App erstellen', + types: { + all: 'Alle', + assistant: 'Assistent', + completion: 'Vervollständigung', + workflow: 'Arbeitsablauf', + agent: 'Agent', + chatbot: 'Chatbot', + }, + modes: { + completion: 'Textgenerator', + chat: 'Basisassistent', + }, + createFromConfigFile: 'App aus Konfigurationsdatei erstellen', + deleteAppConfirmTitle: 'Diese App löschen?', + deleteAppConfirmContent: + 'Das Löschen der App ist unwiderruflich. Nutzer werden keinen Zugang mehr zu Ihrer App haben, und alle Prompt-Konfigurationen und Logs werden dauerhaft gelöscht.', + appDeleted: 'App gelöscht', + appDeleteFailed: 'Löschen der App fehlgeschlagen', + join: 'Treten Sie der Gemeinschaft bei', + communityIntro: + 'Diskutieren Sie mit Teammitgliedern, Mitwirkenden und Entwicklern auf verschiedenen Kanälen.', + roadmap: 'Sehen Sie unseren Fahrplan', + appNamePlaceholder: 'Bitte geben Sie den Namen der App ein', + newApp: { + startToCreate: 'Lassen Sie uns mit Ihrer neuen App beginnen', + captionName: 'App-Symbol & Name', + captionAppType: 'Welchen Typ von App möchten Sie erstellen?', + previewDemo: 'Vorschau-Demo', + chatApp: 'Assistent', + chatAppIntro: + 'Ich möchte eine Chat-basierte Anwendung bauen. Diese App verwendet ein Frage-Antwort-Format und ermöglicht mehrere Runden kontinuierlicher Konversation.', + agentAssistant: 'Neuer Agentenassistent', + completeApp: 'Textgenerator', + completeAppIntro: + 'Ich möchte eine Anwendung erstellen, die hochwertigen Text basierend auf Aufforderungen generiert, wie z.B. 
das Erstellen von Artikeln, Zusammenfassungen, Übersetzungen und mehr.', + showTemplates: 'Ich möchte aus einer Vorlage wählen', + hideTemplates: 'Zurück zur Modusauswahl', + Create: 'Erstellen', + Cancel: 'Abbrechen', + nameNotEmpty: 'Name darf nicht leer sein', + appTemplateNotSelected: 'Bitte wählen Sie eine Vorlage', + appTypeRequired: 'Bitte wählen Sie einen App-Typ', + appCreated: 'App erstellt', + appCreateFailed: 'Erstellen der App fehlgeschlagen', + basic: 'Grundlegend', + chatbotType: 'Chatbot-Orchestrierungsmethode', + workflowDescription: 'Erstellen Sie eine Anwendung, die qualitativ hochwertigen Text auf der Grundlage von Workflow-Orchestrierungen mit einem hohen Maß an Anpassung generiert. Es ist für erfahrene Benutzer geeignet.', + advancedFor: 'Für Fortgeschrittene', + startFromTemplate: 'Aus Vorlage erstellen', + appNamePlaceholder: 'Geben Sie Ihrer App einen Namen', + startFromBlank: 'Aus Leer erstellen', + basicTip: 'Für Anfänger können Sie später zu Chatflow wechseln', + basicDescription: 'Basic Orchestrate ermöglicht die Orchestrierung einer Chatbot-App mit einfachen Einstellungen, ohne die Möglichkeit, integrierte Eingabeaufforderungen zu ändern. Es ist für Anfänger geeignet.', + workflowWarning: 'Derzeit in der Beta-Phase', + advancedDescription: 'Workflow Orchestrate orchestriert Chatbots in Form von Workflows und bietet ein hohes Maß an Individualisierung, einschließlich der Möglichkeit, integrierte Eingabeaufforderungen zu bearbeiten. Es ist für erfahrene Benutzer geeignet.', + basicFor: 'FÜR ANFÄNGER', + completionWarning: 'Diese Art von App wird nicht mehr unterstützt.', + chatbotDescription: 'Erstellen Sie eine chatbasierte Anwendung. Diese App verwendet ein Frage-und-Antwort-Format, das mehrere Runden kontinuierlicher Konversation ermöglicht.', + captionDescription: 'Beschreibung', + advanced: 'Chatflow', + useTemplate: 'Diese Vorlage verwenden', + agentDescription: 'Erstellen Sie einen intelligenten Agenten, der autonom Werkzeuge auswählen kann, um die Aufgaben zu erledigen', + completionDescription: 'Erstellen Sie eine Anwendung, die qualitativ hochwertigen Text auf der Grundlage von Eingabeaufforderungen generiert, z. B. zum Generieren von Artikeln, Zusammenfassungen, Übersetzungen und mehr.', + appDescriptionPlaceholder: 'Geben Sie die Beschreibung der App ein', + }, + editApp: 'App bearbeiten', + editAppTitle: 'App-Informationen bearbeiten', + editDone: 'App-Informationen wurden aktualisiert', + editFailed: 'Aktualisierung der App-Informationen fehlgeschlagen', + iconPicker: { + ok: 'OK', + cancel: 'Abbrechen', + emoji: 'Emoji', + image: 'Bild', + }, + switch: 'Zu Workflow-Orchestrierung wechseln', + switchTipStart: 'Eine neue App-Kopie wird für Sie erstellt, und die neue Kopie wird zur Workflow-Orchestrierung wechseln. 
Die neue Kopie wird ', + switchTip: 'nicht erlauben', + switchTipEnd: ' zur Basis-Orchestrierung zurückzuwechseln.', + switchLabel: 'Die zu erstellende App-Kopie', + removeOriginal: 'Ursprüngliche App löschen', + switchStart: 'Wechsel starten', + typeSelector: { + all: 'ALLE Typen', + chatbot: 'Chatbot', + agent: 'Agent', + workflow: 'Workflow', + completion: 'Vervollständigung', + }, + tracing: { + title: 'Anwendungsleistung nachverfolgen', + description: 'Konfiguration eines Drittanbieter-LLMOps-Anbieters und Nachverfolgung der Anwendungsleistung.', + config: 'Konfigurieren', + collapse: 'Einklappen', + expand: 'Ausklappen', + tracing: 'Nachverfolgung', + disabled: 'Deaktiviert', + disabledTip: 'Bitte zuerst den Anbieter konfigurieren', + enabled: 'In Betrieb', + tracingDescription: 'Erfassung des vollständigen Kontexts der Anwendungsausführung, einschließlich LLM-Aufrufe, Kontext, Prompts, HTTP-Anfragen und mehr, auf einer Nachverfolgungsplattform von Drittanbietern.', + configProviderTitle: { + configured: 'Konfiguriert', + notConfigured: 'Anbieter konfigurieren, um Nachverfolgung zu aktivieren', + moreProvider: 'Weitere Anbieter', + }, + langsmith: { + title: 'LangSmith', + description: 'Eine All-in-One-Entwicklerplattform für jeden Schritt des LLM-gesteuerten Anwendungslebenszyklus.', + }, + langfuse: { + title: 'Langfuse', + description: 'Traces, Bewertungen, Prompt-Management und Metriken zum Debuggen und Verbessern Ihrer LLM-Anwendung.', + }, + inUse: 'In Verwendung', + configProvider: { + title: 'Konfigurieren ', + placeholder: 'Geben Sie Ihren {{key}} ein', + project: 'Projekt', + publicKey: 'Öffentlicher Schlüssel', + secretKey: 'Geheimer Schlüssel', + viewDocsLink: '{{key}}-Dokumentation ansehen', + removeConfirmTitle: '{{key}}-Konfiguration entfernen?', + removeConfirmContent: 'Die aktuelle Konfiguration wird verwendet. 
Das Entfernen wird die Nachverfolgungsfunktion ausschalten.', + }, + view: 'Ansehen', + }, + answerIcon: { + descriptionInExplore: 'Gibt an, ob das WebApp-Symbol zum Ersetzen 🤖 in Explore verwendet werden soll', + title: 'Verwenden Sie das WebApp-Symbol, um es zu ersetzen 🤖', + description: 'Gibt an, ob das WebApp-Symbol zum Ersetzen 🤖 in der freigegebenen Anwendung verwendet werden soll', + }, + importFromDSLUrlPlaceholder: 'DSL-Link hier einfügen', + duplicate: 'Duplikat', + importFromDSL: 'Import von DSL', + importDSL: 'DSL-Datei importieren', + importFromDSLUrl: 'Von URL', + exportFailed: 'Fehler beim Exportieren von DSL.', + importFromDSLFile: 'Aus DSL-Datei', + export: 'DSL exportieren', + duplicateTitle: 'App duplizieren', +} + +export default translation diff --git a/web/i18n/de-DE/billing.ts b/web/i18n/de-DE/billing.ts index ccb35675c8..7eae078ad2 100644 --- a/web/i18n/de-DE/billing.ts +++ b/web/i18n/de-DE/billing.ts @@ -1,115 +1,118 @@ -const translation = { - currentPlan: 'Aktueller Tarif', - upgradeBtn: { - plain: 'Tarif Upgraden', - encourage: 'Jetzt Upgraden', - encourageShort: 'Upgraden', - }, - viewBilling: 'Abrechnung und Abonnements verwalten', - buyPermissionDeniedTip: 'Bitte kontaktieren Sie Ihren Unternehmensadministrator, um zu abonnieren', - plansCommon: { - title: 'Wählen Sie einen Tarif, der zu Ihnen passt', - yearlyTip: 'Erhalten Sie 2 Monate kostenlos durch jährliches Abonnieren!', - mostPopular: 'Am beliebtesten', - planRange: { - monthly: 'Monatlich', - yearly: 'Jährlich', - }, - month: 'Monat', - year: 'Jahr', - save: 'Sparen ', - free: 'Kostenlos', - currentPlan: 'Aktueller Tarif', - contractSales: 'Vertrieb kontaktieren', - contractOwner: 'Teammanager kontaktieren', - startForFree: 'Kostenlos starten', - getStartedWith: 'Beginnen Sie mit ', - contactSales: 'Vertrieb kontaktieren', - talkToSales: 'Mit dem Vertrieb sprechen', - modelProviders: 'Modellanbieter', - teamMembers: 'Teammitglieder', - buildApps: 'Apps bauen', - vectorSpace: 'Vektorraum', - vectorSpaceBillingTooltip: 'Jedes 1MB kann ungefähr 1,2 Millionen Zeichen an vektorisierten Daten speichern (geschätzt mit OpenAI Embeddings, variiert je nach Modell).', - vectorSpaceTooltip: 'Vektorraum ist das Langzeitspeichersystem, das erforderlich ist, damit LLMs Ihre Daten verstehen können.', - documentsUploadQuota: 'Dokumenten-Upload-Kontingent', - documentProcessingPriority: 'Priorität der Dokumentenverarbeitung', - documentProcessingPriorityTip: 'Für eine höhere Dokumentenverarbeitungspriorität, bitte Ihren Tarif upgraden.', - documentProcessingPriorityUpgrade: 'Mehr Daten mit höherer Genauigkeit bei schnelleren Geschwindigkeiten verarbeiten.', - priority: { - 'standard': 'Standard', - 'priority': 'Priorität', - 'top-priority': 'Höchste Priorität', - }, - logsHistory: 'Protokollverlauf', - customTools: 'Benutzerdefinierte Werkzeuge', - unavailable: 'Nicht verfügbar', - days: 'Tage', - unlimited: 'Unbegrenzt', - support: 'Support', - supportItems: { - communityForums: 'Community-Foren', - emailSupport: 'E-Mail-Support', - priorityEmail: 'Priorisierter E-Mail- und Chat-Support', - logoChange: 'Logo-Änderung', - SSOAuthentication: 'SSO-Authentifizierung', - personalizedSupport: 'Persönlicher Support', - dedicatedAPISupport: 'Dedizierter API-Support', - customIntegration: 'Benutzerdefinierte Integration und Support', - ragAPIRequest: 'RAG-API-Anfragen', - bulkUpload: 'Massenupload von Dokumenten', - agentMode: 'Agentenmodus', - workflow: 'Workflow', - }, - comingSoon: 'Demnächst', - member: 'Mitglied', - 
memberAfter: 'Mitglied', - messageRequest: { - title: 'Nachrichtenguthaben', - tooltip: 'Nachrichtenaufrufkontingente für verschiedene Tarife unter Verwendung von OpenAI-Modellen (außer gpt4).Nachrichten über dem Limit verwenden Ihren OpenAI-API-Schlüssel.', - }, - annotatedResponse: { - title: 'Kontingentgrenzen für Annotationen', - tooltip: 'Manuelle Bearbeitung und Annotation von Antworten bieten anpassbare, hochwertige Frage-Antwort-Fähigkeiten für Apps. (Nur anwendbar in Chat-Apps)', - }, - ragAPIRequestTooltip: 'Bezieht sich auf die Anzahl der API-Aufrufe, die nur die Wissensdatenbankverarbeitungsfähigkeiten von Dify aufrufen.', - receiptInfo: 'Nur der Teaminhaber und der Teamadministrator können abonnieren und Abrechnungsinformationen einsehen', - }, - plans: { - sandbox: { - name: 'Sandbox', - description: '200 mal GPT kostenlos testen', - includesTitle: 'Beinhaltet:', - }, - professional: { - name: 'Professionell', - description: 'Für Einzelpersonen und kleine Teams, um mehr Leistung erschwinglich freizuschalten.', - includesTitle: 'Alles im kostenlosen Tarif, plus:', - }, - team: { - name: 'Team', - description: 'Zusammenarbeiten ohne Grenzen und Top-Leistung genießen.', - includesTitle: 'Alles im Professionell-Tarif, plus:', - }, - enterprise: { - name: 'Unternehmen', - description: 'Erhalten Sie volle Fähigkeiten und Unterstützung für großangelegte, missionskritische Systeme.', - includesTitle: 'Alles im Team-Tarif, plus:', - }, - }, - vectorSpace: { - fullTip: 'Vektorraum ist voll.', - fullSolution: 'Upgraden Sie Ihren Tarif, um mehr Speicherplatz zu erhalten.', - }, - apps: { - fullTipLine1: 'Upgraden Sie Ihren Tarif, um', - fullTipLine2: 'mehr Apps zu bauen.', - }, - annotatedResponse: { - fullTipLine1: 'Upgraden Sie Ihren Tarif, um', - fullTipLine2: 'mehr Konversationen zu annotieren.', - quotaTitle: 'Kontingent für Annotation-Antworten', - }, -} - -export default translation +const translation = { + currentPlan: 'Aktueller Tarif', + upgradeBtn: { + plain: 'Tarif Upgraden', + encourage: 'Jetzt Upgraden', + encourageShort: 'Upgraden', + }, + viewBilling: 'Abrechnung und Abonnements verwalten', + buyPermissionDeniedTip: 'Bitte kontaktieren Sie Ihren Unternehmensadministrator, um zu abonnieren', + plansCommon: { + title: 'Wählen Sie einen Tarif, der zu Ihnen passt', + yearlyTip: 'Erhalten Sie 2 Monate kostenlos durch jährliches Abonnieren!', + mostPopular: 'Am beliebtesten', + planRange: { + monthly: 'Monatlich', + yearly: 'Jährlich', + }, + month: 'Monat', + year: 'Jahr', + save: 'Sparen ', + free: 'Kostenlos', + currentPlan: 'Aktueller Tarif', + contractSales: 'Vertrieb kontaktieren', + contractOwner: 'Teammanager kontaktieren', + startForFree: 'Kostenlos starten', + getStartedWith: 'Beginnen Sie mit ', + contactSales: 'Vertrieb kontaktieren', + talkToSales: 'Mit dem Vertrieb sprechen', + modelProviders: 'Modellanbieter', + teamMembers: 'Teammitglieder', + buildApps: 'Apps bauen', + vectorSpace: 'Vektorraum', + vectorSpaceBillingTooltip: 'Jedes 1MB kann ungefähr 1,2 Millionen Zeichen an vektorisierten Daten speichern (geschätzt mit OpenAI Embeddings, variiert je nach Modell).', + vectorSpaceTooltip: 'Vektorraum ist das Langzeitspeichersystem, das erforderlich ist, damit LLMs Ihre Daten verstehen können.', + documentsUploadQuota: 'Dokumenten-Upload-Kontingent', + documentProcessingPriority: 'Priorität der Dokumentenverarbeitung', + documentProcessingPriorityTip: 'Für eine höhere Dokumentenverarbeitungspriorität, bitte Ihren Tarif upgraden.', + 
documentProcessingPriorityUpgrade: 'Mehr Daten mit höherer Genauigkeit bei schnelleren Geschwindigkeiten verarbeiten.', + priority: { + 'standard': 'Standard', + 'priority': 'Priorität', + 'top-priority': 'Höchste Priorität', + }, + logsHistory: 'Protokollverlauf', + customTools: 'Benutzerdefinierte Werkzeuge', + unavailable: 'Nicht verfügbar', + days: 'Tage', + unlimited: 'Unbegrenzt', + support: 'Support', + supportItems: { + communityForums: 'Community-Foren', + emailSupport: 'E-Mail-Support', + priorityEmail: 'Priorisierter E-Mail- und Chat-Support', + logoChange: 'Logo-Änderung', + SSOAuthentication: 'SSO-Authentifizierung', + personalizedSupport: 'Persönlicher Support', + dedicatedAPISupport: 'Dedizierter API-Support', + customIntegration: 'Benutzerdefinierte Integration und Support', + ragAPIRequest: 'RAG-API-Anfragen', + bulkUpload: 'Massenupload von Dokumenten', + agentMode: 'Agentenmodus', + workflow: 'Workflow', + llmLoadingBalancing: 'LLM-Lastausgleich', + llmLoadingBalancingTooltip: 'Fügen Sie Modellen mehrere API-Schlüssel hinzu, um die API-Ratenlimits effektiv zu umgehen.', + }, + comingSoon: 'Demnächst', + member: 'Mitglied', + memberAfter: 'Mitglied', + messageRequest: { + title: 'Nachrichtenguthaben', + tooltip: 'Nachrichtenaufrufkontingente für verschiedene Tarife unter Verwendung von OpenAI-Modellen (außer gpt4).Nachrichten über dem Limit verwenden Ihren OpenAI-API-Schlüssel.', + }, + annotatedResponse: { + title: 'Kontingentgrenzen für Annotationen', + tooltip: 'Manuelle Bearbeitung und Annotation von Antworten bieten anpassbare, hochwertige Frage-Antwort-Fähigkeiten für Apps. (Nur anwendbar in Chat-Apps)', + }, + ragAPIRequestTooltip: 'Bezieht sich auf die Anzahl der API-Aufrufe, die nur die Wissensdatenbankverarbeitungsfähigkeiten von Dify aufrufen.', + receiptInfo: 'Nur der Teaminhaber und der Teamadministrator können abonnieren und Abrechnungsinformationen einsehen', + annotationQuota: 'Kontingent für Anmerkungen', + }, + plans: { + sandbox: { + name: 'Sandbox', + description: '200 mal GPT kostenlos testen', + includesTitle: 'Beinhaltet:', + }, + professional: { + name: 'Professionell', + description: 'Für Einzelpersonen und kleine Teams, um mehr Leistung erschwinglich freizuschalten.', + includesTitle: 'Alles im kostenlosen Tarif, plus:', + }, + team: { + name: 'Team', + description: 'Zusammenarbeiten ohne Grenzen und Top-Leistung genießen.', + includesTitle: 'Alles im Professionell-Tarif, plus:', + }, + enterprise: { + name: 'Unternehmen', + description: 'Erhalten Sie volle Fähigkeiten und Unterstützung für großangelegte, missionskritische Systeme.', + includesTitle: 'Alles im Team-Tarif, plus:', + }, + }, + vectorSpace: { + fullTip: 'Vektorraum ist voll.', + fullSolution: 'Upgraden Sie Ihren Tarif, um mehr Speicherplatz zu erhalten.', + }, + apps: { + fullTipLine1: 'Upgraden Sie Ihren Tarif, um', + fullTipLine2: 'mehr Apps zu bauen.', + }, + annotatedResponse: { + fullTipLine1: 'Upgraden Sie Ihren Tarif, um', + fullTipLine2: 'mehr Konversationen zu annotieren.', + quotaTitle: 'Kontingent für Annotation-Antworten', + }, +} + +export default translation diff --git a/web/i18n/de-DE/common.ts b/web/i18n/de-DE/common.ts index 9a66d5b175..6ea06bc8b1 100644 --- a/web/i18n/de-DE/common.ts +++ b/web/i18n/de-DE/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parameter', duplicate: 'Duplikat', rename: 'Umbenennen', + audioSourceUnavailable: 'AudioSource ist nicht verfügbar', }, placeholder: { input: 'Bitte eingeben', @@ -128,7 +129,8 @@ const translation = { 
diff --git a/web/i18n/de-DE/common.ts b/web/i18n/de-DE/common.ts
index 9a66d5b175..6ea06bc8b1 100644
--- a/web/i18n/de-DE/common.ts
+++ b/web/i18n/de-DE/common.ts
@@ -37,6 +37,7 @@ const translation = {
     params: 'Parameter',
     duplicate: 'Duplikat',
     rename: 'Umbenennen',
+    audioSourceUnavailable: 'AudioSource ist nicht verfügbar',
   },
   placeholder: {
     input: 'Bitte eingeben',
@@ -128,7 +129,8 @@
     workspace: 'Arbeitsbereich',
     createWorkspace: 'Arbeitsbereich erstellen',
     helpCenter: 'Hilfe',
-    roadmapAndFeedback: 'Feedback',
+    communityFeedback: 'Rückmeldung',
+    roadmap: 'Roadmap',
     community: 'Gemeinschaft',
     about: 'Über',
     logout: 'Abmelden',
@@ -190,16 +192,21 @@ const translation = {
     invitationSent: 'Einladung gesendet',
     invitationSentTip: 'Einladung gesendet, und sie können sich bei Dify anmelden, um auf Ihre Teamdaten zuzugreifen.',
     invitationLink: 'Einladungslink',
-    failedinvitationEmails: 'Die folgenden Benutzer wurden nicht erfolgreich eingeladen',
+    failedInvitationEmails: 'Die folgenden Benutzer wurden nicht erfolgreich eingeladen',
     ok: 'OK',
     removeFromTeam: 'Vom Team entfernen',
     removeFromTeamTip: 'Wird den Teamzugang entfernen',
     setAdmin: 'Als Administrator einstellen',
     setMember: 'Als normales Mitglied einstellen',
     setEditor: 'Als Editor einstellen',
-    disinvite: 'Einladung widerrufen',
+    disInvite: 'Einladung widerrufen',
     deleteMember: 'Mitglied löschen',
     you: '(Du)',
+    setBuilder: 'Als Builder festlegen',
+    datasetOperator: 'Wissensadministrator',
+    datasetOperatorTip: 'Kann die Wissensdatenbank nur verwalten',
+    builder: 'Builder',
+    builderTip: 'Kann eigene Apps erstellen und bearbeiten',
   },
   integrations: {
     connected: 'Verbunden',
@@ -346,6 +353,22 @@ const translation = {
     quotaTip: 'Verbleibende verfügbare kostenlose Token',
     loadPresets: 'Voreinstellungen laden',
     parameters: 'PARAMETER',
+    loadBalancingHeadline: 'Lastenausgleich',
+    apiKey: 'API-SCHLÜSSEL',
+    editConfig: 'Konfiguration bearbeiten',
+    loadBalancing: 'Lastenausgleich',
+    addConfig: 'Konfiguration hinzufügen',
+    configLoadBalancing: 'Lastenausgleich konfigurieren',
+    providerManagedDescription: 'Verwenden Sie den einzelnen Satz von Anmeldeinformationen, der vom Modellanbieter bereitgestellt wird.',
+    loadBalancingDescription: 'Reduzieren Sie den Druck mit mehreren Sätzen von Anmeldeinformationen.',
+    modelHasBeenDeprecated: 'Dieses Modell ist veraltet',
+    loadBalancingLeastKeyWarning: 'Um den Lastenausgleich zu aktivieren, müssen mindestens 2 Schlüssel aktiviert sein.',
+    providerManaged: 'Vom Anbieter verwaltet',
+    apiKeyStatusNormal: 'API-Schlüssel-Status ist normal',
+    upgradeForLoadBalancing: 'Aktualisieren Sie Ihren Plan, um den Lastenausgleich zu aktivieren.',
+    defaultConfig: 'Standardkonfiguration',
+    apiKeyRateLimit: 'Ratenlimit wurde erreicht, verfügbar nach {{seconds}}s',
+    loadBalancingInfo: 'Standardmäßig wird für den Lastenausgleich die Round-Robin-Strategie verwendet. Wenn die Ratenbegrenzung ausgelöst wird, wird eine Abklingzeit von 1 Minute angewendet.',
   },
   dataSource: {
     add: 'Eine Datenquelle hinzufügen',
@@ -369,6 +392,15 @@ const translation = {
       preview: 'VORSCHAU',
     },
   },
+  website: {
+    inactive: 'Inaktiv',
+    description: 'Importieren Sie Inhalte von Websites mit dem Webcrawler.',
+    title: 'Website',
+    configuredCrawlers: 'Konfigurierte Crawler',
+    active: 'Aktiv',
+    with: 'Mit',
+  },
+  configure: 'Konfigurieren',
 },
 plugin: {
   serpapi: {
@@ -417,6 +449,7 @@ const translation = {
   promptEng: 'Orchestrieren',
   apiAccess: 'API-Zugriff',
   logAndAnn: 'Protokolle & Ank.',
+  logs: 'Protokolle',
 },
 environment: {
   testing: 'TESTEN',
@@ -498,6 +531,10 @@ const translation = {
     add: 'Neue Variable',
     addTool: 'Neues Werkzeug',
   },
+  outputToolDisabledItem: {
+    desc: 'Variablen einfügen',
+    title: 'Variablen',
+  },
 },
 query: {
   item: {
@@ -532,6 +569,10 @@ const translation = {
     created: 'Tag erfolgreich erstellt',
     failed: 'Tag-Erstellung fehlgeschlagen',
   },
+  errorMsg: {
+    fieldRequired: '{{field}} ist erforderlich',
+    urlError: 'Die URL sollte mit http:// oder https:// beginnen',
+  },
 }
 
 export default translation
diff --git a/web/i18n/de-DE/custom.ts b/web/i18n/de-DE/custom.ts
index c24ebca696..2f4cabd67d 100644
--- a/web/i18n/de-DE/custom.ts
+++ b/web/i18n/de-DE/custom.ts
@@ -1,30 +1,30 @@
-const translation = {
-  custom: 'Anpassung',
-  upgradeTip: {
-    prefix: 'Erweitere deinen Plan auf',
-    suffix: 'um deine Marke anzupassen.',
-  },
-  webapp: {
-    title: 'WebApp Marke anpassen',
-    removeBrand: 'Entferne Powered by Dify',
-    changeLogo: 'Ändere Powered by Markenbild',
-    changeLogoTip: 'SVG oder PNG Format mit einer Mindestgröße von 40x40px',
-  },
-  app: {
-    title: 'App Kopfzeilen Marke anpassen',
-    changeLogoTip: 'SVG oder PNG Format mit einer Mindestgröße von 80x80px',
-  },
-  upload: 'Hochladen',
-  uploading: 'Lade hoch',
-  uploadedFail: 'Bild-Upload fehlgeschlagen, bitte erneut hochladen.',
-  change: 'Ändern',
-  apply: 'Anwenden',
-  restore: 'Standardeinstellungen wiederherstellen',
-  customize: {
-    contactUs: ' kontaktiere uns ',
-    prefix: 'Um das Markenlogo innerhalb der App anzupassen, bitte',
-    suffix: 'um auf die Enterprise-Edition zu upgraden.',
-  },
-}
-
-export default translation
+const translation = {
+  custom: 'Anpassung',
+  upgradeTip: {
+    prefix: 'Erweitere deinen Plan auf',
+    suffix: 'um deine Marke anzupassen.',
+  },
+  webapp: {
+    title: 'WebApp Marke anpassen',
+    removeBrand: 'Entferne Powered by Dify',
+    changeLogo: 'Ändere Powered by Markenbild',
+    changeLogoTip: 'SVG oder PNG Format mit einer Mindestgröße von 40x40px',
+  },
+  app: {
+    title: 'App Kopfzeilen Marke anpassen',
+    changeLogoTip: 'SVG oder PNG Format mit einer Mindestgröße von 80x80px',
+  },
+  upload: 'Hochladen',
+  uploading: 'Lade hoch',
+  uploadedFail: 'Bild-Upload fehlgeschlagen, bitte erneut hochladen.',
+  change: 'Ändern',
+  apply: 'Anwenden',
+  restore: 'Standardeinstellungen wiederherstellen',
+  customize: {
+    contactUs: ' kontaktiere uns ',
+    prefix: 'Um das Markenlogo innerhalb der App anzupassen, bitte',
+    suffix: 'um auf die Enterprise-Edition zu upgraden.',
+  },
+}
+
+export default translation
diff --git a/web/i18n/de-DE/dataset-creation.ts b/web/i18n/de-DE/dataset-creation.ts
index 8727457759..8b27395049 100644
--- a/web/i18n/de-DE/dataset-creation.ts
+++ b/web/i18n/de-DE/dataset-creation.ts
@@ -1,130 +1,161 @@
-const translation = {
-  steps: {
-    header: {
-      creation: 'Wissen erstellen',
-      update: 'Daten hinzufügen',
-    },
-    one: 'Datenquelle wählen',
-    two:
'Textvorverarbeitung und Bereinigung', - three: 'Ausführen und beenden', - }, - error: { - unavailable: 'Dieses Wissen ist nicht verfügbar', - }, - stepOne: { - filePreview: 'Dateivorschau', - pagePreview: 'Seitenvorschau', - dataSourceType: { - file: 'Import aus Textdatei', - notion: 'Synchronisation aus Notion', - web: 'Synchronisation von Webseite', - }, - uploader: { - title: 'Textdatei hochladen', - button: 'Datei hierher ziehen oder', - browse: 'Durchsuchen', - tip: 'Unterstützt {{supportTypes}}. Maximal {{size}}MB pro Datei.', - validation: { - typeError: 'Dateityp nicht unterstützt', - size: 'Datei zu groß. Maximum ist {{size}}MB', - count: 'Mehrere Dateien nicht unterstützt', - filesNumber: 'Sie haben das Limit für die Stapelverarbeitung von {{filesNumber}} erreicht.', - }, - cancel: 'Abbrechen', - change: 'Ändern', - failed: 'Hochladen fehlgeschlagen', - }, - notionSyncTitle: 'Notion ist nicht verbunden', - notionSyncTip: 'Um mit Notion zu synchronisieren, muss zuerst eine Verbindung zu Notion hergestellt werden.', - connect: 'Verbinden gehen', - button: 'weiter', - emptyDatasetCreation: 'Ich möchte ein leeres Wissen erstellen', - modal: { - title: 'Ein leeres Wissen erstellen', - tip: 'Ein leeres Wissen enthält keine Dokumente, und Sie können jederzeit Dokumente hochladen.', - input: 'Wissensname', - placeholder: 'Bitte eingeben', - nameNotEmpty: 'Name darf nicht leer sein', - nameLengthInvaild: 'Name muss zwischen 1 bis 40 Zeichen lang sein', - cancelButton: 'Abbrechen', - confirmButton: 'Erstellen', - failed: 'Erstellung fehlgeschlagen', - }, - }, - stepTwo: { - segmentation: 'Chunk-Einstellungen', - auto: 'Automatisch', - autoDescription: 'Stellt Chunk- und Vorverarbeitungsregeln automatisch ein. Unbekannten Benutzern wird dies empfohlen.', - custom: 'Benutzerdefiniert', - customDescription: 'Chunk-Regeln, Chunk-Länge und Vorverarbeitungsregeln usw. anpassen.', - separator: 'Segmentidentifikator', - separatorPlaceholder: 'Zum Beispiel Neuer Absatz (\\\\n) oder spezieller Separator (wie "***")', - maxLength: 'Maximale Chunk-Länge', - overlap: 'Chunk-Überlappung', - overlapTip: 'Die Einstellung der Chunk-Überlappung kann die semantische Relevanz zwischen ihnen aufrechterhalten und so die Abrufeffekt verbessern. 
Es wird empfohlen, 10%-25% der maximalen Chunk-Größe einzustellen.', - overlapCheck: 'Chunk-Überlappung sollte nicht größer als maximale Chunk-Länge sein', - rules: 'Textvorverarbeitungsregeln', - removeExtraSpaces: 'Mehrfache Leerzeichen, Zeilenumbrüche und Tabulatoren ersetzen', - removeUrlEmails: 'Alle URLs und E-Mail-Adressen löschen', - removeStopwords: 'Stopwörter wie "ein", "eine", "der" entfernen', - preview: 'Bestätigen & Vorschau', - reset: 'Zurücksetzen', - indexMode: 'Indexmodus', - qualified: 'Hohe Qualität', - recommend: 'Empfehlen', - qualifiedTip: 'Ruft standardmäßige Systemeinbettungsschnittstelle für die Verarbeitung auf, um höhere Genauigkeit bei Benutzerabfragen zu bieten.', - warning: 'Bitte zuerst den API-Schlüssel des Modellanbieters einrichten.', - click: 'Zu den Einstellungen gehen', - economical: 'Ökonomisch', - economicalTip: 'Verwendet Offline-Vektor-Engines, Schlagwortindizes usw., um die Genauigkeit ohne Tokenverbrauch zu reduzieren', - QATitle: 'Segmentierung im Frage-und-Antwort-Format', - QATip: 'Diese Option zu aktivieren, wird mehr Tokens verbrauchen', - QALanguage: 'Segmentierung verwenden', - emstimateCost: 'Schätzung', - emstimateSegment: 'Geschätzte Chunks', - segmentCount: 'Chunks', - calculating: 'Berechnung...', - fileSource: 'Dokumente vorverarbeiten', - notionSource: 'Seiten vorverarbeiten', - other: 'und weitere ', - fileUnit: ' Dateien', - notionUnit: ' Seiten', - previousStep: 'Vorheriger Schritt', - nextStep: 'Speichern & Verarbeiten', - save: 'Speichern & Verarbeiten', - cancel: 'Abbrechen', - sideTipTitle: 'Warum segmentieren und vorverarbeiten?', - sideTipP1: 'Bei der Verarbeitung von Textdaten sind Segmentierung und Bereinigung zwei wichtige Vorverarbeitungsschritte.', - sideTipP2: 'Segmentierung teilt langen Text in Absätze, damit Modelle ihn besser verstehen können. Dies verbessert die Qualität und Relevanz der Modellergebnisse.', - sideTipP3: 'Bereinigung entfernt unnötige Zeichen und Formate, macht das Wissen sauberer und leichter zu parsen.', - sideTipP4: 'Richtige Segmentierung und Bereinigung verbessern die Modellleistung und liefern genauere und wertvollere Ergebnisse.', - previewTitle: 'Vorschau', - previewTitleButton: 'Vorschau', - previewButton: 'Umschalten zum Frage-und-Antwort-Format', - previewSwitchTipStart: 'Die aktuelle Chunk-Vorschau ist im Textformat, ein Wechsel zur Vorschau im Frage-und-Antwort-Format wird', - previewSwitchTipEnd: ' zusätzliche Tokens verbrauchen', - characters: 'Zeichen', - indexSettedTip: 'Um die Indexmethode zu ändern, bitte gehen Sie zu den ', - retrivalSettedTip: 'Um die Indexmethode zu ändern, bitte gehen Sie zu den ', - datasetSettingLink: 'Wissenseinstellungen.', - }, - stepThree: { - creationTitle: '🎉 Wissen erstellt', - creationContent: 'Wir haben das Wissen automatisch benannt, Sie können es jederzeit ändern', - label: 'Wissensname', - additionTitle: '🎉 Dokument hochgeladen', - additionP1: 'Das Dokument wurde zum Wissen hinzugefügt', - additionP2: ', Sie können es in der Dokumentenliste des Wissens finden.', - stop: 'Verarbeitung stoppen', - resume: 'Verarbeitung fortsetzen', - navTo: 'Zum Dokument gehen', - sideTipTitle: 'Was kommt als Nächstes', - sideTipContent: 'Nachdem das Dokument indiziert wurde, kann das Wissen in die Anwendung als Kontext integriert werden, Sie finden die Kontexteinstellung auf der Seite zur Eingabeaufforderungen-Orchestrierung. 
Sie können es auch als unabhängiges ChatGPT-Indexierungsplugin zur Veröffentlichung erstellen.',
-    modelTitle: 'Sind Sie sicher, dass Sie die Einbettung stoppen möchten?',
-    modelContent: 'Wenn Sie die Verarbeitung später fortsetzen möchten, werden Sie dort weitermachen, wo Sie aufgehört haben.',
-    modelButtonConfirm: 'Bestätigen',
-    modelButtonCancel: 'Abbrechen',
-  },
-}
-
-export default translation
+const translation = {
+  steps: {
+    header: {
+      creation: 'Wissen erstellen',
+      update: 'Daten hinzufügen',
+    },
+    one: 'Datenquelle wählen',
+    two: 'Textvorverarbeitung und Bereinigung',
+    three: 'Ausführen und beenden',
+  },
+  error: {
+    unavailable: 'Dieses Wissen ist nicht verfügbar',
+  },
+  stepOne: {
+    filePreview: 'Dateivorschau',
+    pagePreview: 'Seitenvorschau',
+    dataSourceType: {
+      file: 'Import aus Textdatei',
+      notion: 'Synchronisation aus Notion',
+      web: 'Synchronisation von Webseite',
+    },
+    uploader: {
+      title: 'Textdatei hochladen',
+      button: 'Datei hierher ziehen oder',
+      browse: 'Durchsuchen',
+      tip: 'Unterstützt {{supportTypes}}. Maximal {{size}}MB pro Datei.',
+      validation: {
+        typeError: 'Dateityp nicht unterstützt',
+        size: 'Datei zu groß. Maximum ist {{size}}MB',
+        count: 'Mehrere Dateien nicht unterstützt',
+        filesNumber: 'Sie haben das Limit für die Stapelverarbeitung von {{filesNumber}} erreicht.',
+      },
+      cancel: 'Abbrechen',
+      change: 'Ändern',
+      failed: 'Hochladen fehlgeschlagen',
+    },
+    notionSyncTitle: 'Notion ist nicht verbunden',
+    notionSyncTip: 'Um mit Notion zu synchronisieren, muss zuerst eine Verbindung zu Notion hergestellt werden.',
+    connect: 'Verbinden gehen',
+    button: 'weiter',
+    emptyDatasetCreation: 'Ich möchte ein leeres Wissen erstellen',
+    modal: {
+      title: 'Ein leeres Wissen erstellen',
+      tip: 'Ein leeres Wissen enthält keine Dokumente, und Sie können jederzeit Dokumente hochladen.',
+      input: 'Wissensname',
+      placeholder: 'Bitte eingeben',
+      nameNotEmpty: 'Name darf nicht leer sein',
+      nameLengthInvalid: 'Name muss zwischen 1 bis 40 Zeichen lang sein',
+      cancelButton: 'Abbrechen',
+      confirmButton: 'Erstellen',
+      failed: 'Erstellung fehlgeschlagen',
+    },
+    website: {
+      preview: 'Vorschau',
+      totalPageScraped: 'Gesamtzahl der gescrapten Seiten:',
+      fireCrawlNotConfigured: 'Firecrawl ist nicht konfiguriert',
+      options: 'Optionen',
+      excludePaths: 'Pfade ausschließen',
+      limit: 'Limit',
+      exceptionErrorTitle: 'Beim Ausführen des Firecrawl-Auftrags ist eine Ausnahme aufgetreten:',
+      selectAll: 'Alles auswählen',
+      includeOnlyPaths: 'Nur Pfade einschließen',
+      run: 'Ausführen',
+      firecrawlDoc: 'Firecrawl-Dokumente',
+      configure: 'Konfigurieren',
+      fireCrawlNotConfiguredDescription: 'Konfigurieren Sie Firecrawl mit dem API-Schlüssel, um es zu verwenden.',
+      maxDepth: 'Maximale Tiefe',
+      unknownError: 'Unbekannter Fehler',
+      resetAll: 'Alles zurücksetzen',
+      extractOnlyMainContent: 'Extrahieren Sie nur den Hauptinhalt (keine Kopf-, Navigations- und Fußzeilen usw.)',
+      firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website',
+      firecrawlTitle: 'Extrahieren von Webinhalten mit 🔥Firecrawl',
+      maxDepthTooltip: 'Maximale Tiefe für das Crawlen relativ zur eingegebenen URL. Tiefe 0 scrapt nur die Seite der eingegebenen URL, Tiefe 1 scrapt die URL und alles nach der eingegebenen URL + ein / und so weiter.',
+      crawlSubPage: 'Unterseiten crawlen',
+      scrapTimeInfo: 'Insgesamt {{total}} Seiten innerhalb von {{time}}s gescrapt',
+    },
+  },
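The {{total}} and {{time}} slots in scrapTimeInfo above (like {{seconds}} and {{field}} in common.ts) are interpolation placeholders filled at render time, which is why the malformed triple brace {{{total}} corrected above would have leaked raw braces into the UI. A minimal sketch of the lookup, assuming an i18next-style interpolator (which this {{…}} syntax matches); the inline resource table and values are illustrative only:

import i18next from 'i18next'

// Register one key copied from the de-DE bundle above as an inline resource.
await i18next.init({
  lng: 'de',
  resources: {
    de: {
      translation: {
        scrapTimeInfo: 'Insgesamt {{total}} Seiten innerhalb von {{time}}s gescrapt',
      },
    },
  },
})

// Both placeholders are substituted:
// "Insgesamt 12 Seiten innerhalb von 3.4s gescrapt"
console.log(i18next.t('scrapTimeInfo', { total: 12, time: 3.4 }))

// A token such as '{{{total}}' is not recognized by the interpolator,
// so the stray braces would appear verbatim in the rendered string.
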
+  stepTwo: {
+    segmentation: 'Chunk-Einstellungen',
+    auto: 'Automatisch',
+    autoDescription: 'Stellt Chunk- und Vorverarbeitungsregeln automatisch ein. Unbekannten Benutzern wird dies empfohlen.',
+    custom: 'Benutzerdefiniert',
+    customDescription: 'Chunk-Regeln, Chunk-Länge und Vorverarbeitungsregeln usw. anpassen.',
+    separator: 'Segmentidentifikator',
+    separatorPlaceholder: 'Zum Beispiel Neuer Absatz (\\\\n) oder spezieller Separator (wie "***")',
+    maxLength: 'Maximale Chunk-Länge',
+    overlap: 'Chunk-Überlappung',
+    overlapTip: 'Die Einstellung der Chunk-Überlappung kann die semantische Relevanz zwischen ihnen aufrechterhalten und so die Abrufeffekt verbessern. Es wird empfohlen, 10%-25% der maximalen Chunk-Größe einzustellen.',
+    overlapCheck: 'Chunk-Überlappung sollte nicht größer als maximale Chunk-Länge sein',
+    rules: 'Textvorverarbeitungsregeln',
+    removeExtraSpaces: 'Mehrfache Leerzeichen, Zeilenumbrüche und Tabulatoren ersetzen',
+    removeUrlEmails: 'Alle URLs und E-Mail-Adressen löschen',
+    removeStopwords: 'Stopwörter wie "ein", "eine", "der" entfernen',
+    preview: 'Bestätigen & Vorschau',
+    reset: 'Zurücksetzen',
+    indexMode: 'Indexmodus',
+    qualified: 'Hohe Qualität',
+    recommend: 'Empfehlen',
+    qualifiedTip: 'Ruft standardmäßige Systemeinbettungsschnittstelle für die Verarbeitung auf, um höhere Genauigkeit bei Benutzerabfragen zu bieten.',
+    warning: 'Bitte zuerst den API-Schlüssel des Modellanbieters einrichten.',
+    click: 'Zu den Einstellungen gehen',
+    economical: 'Ökonomisch',
+    economicalTip: 'Verwendet Offline-Vektor-Engines, Schlagwortindizes usw., um die Genauigkeit ohne Tokenverbrauch zu reduzieren',
+    QATitle: 'Segmentierung im Frage-und-Antwort-Format',
+    QATip: 'Diese Option zu aktivieren, wird mehr Tokens verbrauchen',
+    QALanguage: 'Segmentierung verwenden',
+    estimateCost: 'Schätzung',
+    estimateSegment: 'Geschätzte Chunks',
+    segmentCount: 'Chunks',
+    calculating: 'Berechnung...',
+    fileSource: 'Dokumente vorverarbeiten',
+    notionSource: 'Seiten vorverarbeiten',
+    other: 'und weitere ',
+    fileUnit: ' Dateien',
+    notionUnit: ' Seiten',
+    previousStep: 'Vorheriger Schritt',
+    nextStep: 'Speichern & Verarbeiten',
+    save: 'Speichern & Verarbeiten',
+    cancel: 'Abbrechen',
+    sideTipTitle: 'Warum segmentieren und vorverarbeiten?',
+    sideTipP1: 'Bei der Verarbeitung von Textdaten sind Segmentierung und Bereinigung zwei wichtige Vorverarbeitungsschritte.',
+    sideTipP2: 'Segmentierung teilt langen Text in Absätze, damit Modelle ihn besser verstehen können. Dies verbessert die Qualität und Relevanz der Modellergebnisse.',
+    sideTipP3: 'Bereinigung entfernt unnötige Zeichen und Formate, macht das Wissen sauberer und leichter zu parsen.',
+    sideTipP4: 'Richtige Segmentierung und Bereinigung verbessern die Modellleistung und liefern genauere und wertvollere Ergebnisse.',
+    previewTitle: 'Vorschau',
+    previewTitleButton: 'Vorschau',
+    previewButton: 'Umschalten zum Frage-und-Antwort-Format',
+    previewSwitchTipStart: 'Die aktuelle Chunk-Vorschau ist im Textformat, ein Wechsel zur Vorschau im Frage-und-Antwort-Format wird',
+    previewSwitchTipEnd: ' zusätzliche Tokens verbrauchen',
+    characters: 'Zeichen',
+    indexSettingTip: 'Um die Indexmethode zu ändern, bitte gehen Sie zu den ',
+    retrievalSettingTip: 'Um die Indexmethode zu ändern, bitte gehen Sie zu den ',
+    datasetSettingLink: 'Wissenseinstellungen.',
+    websiteSource: 'Website vorverarbeiten',
+    webpageUnit: ' Seiten',
+  },
+  stepThree: {
+    creationTitle: '🎉 Wissen erstellt',
+    creationContent: 'Wir haben das Wissen automatisch benannt, Sie können es jederzeit ändern',
+    label: 'Wissensname',
+    additionTitle: '🎉 Dokument hochgeladen',
+    additionP1: 'Das Dokument wurde zum Wissen hinzugefügt',
+    additionP2: ', Sie können es in der Dokumentenliste des Wissens finden.',
+    stop: 'Verarbeitung stoppen',
+    resume: 'Verarbeitung fortsetzen',
+    navTo: 'Zum Dokument gehen',
+    sideTipTitle: 'Was kommt als Nächstes',
+    sideTipContent: 'Nachdem das Dokument indiziert wurde, kann das Wissen in die Anwendung als Kontext integriert werden, Sie finden die Kontexteinstellung auf der Seite zur Eingabeaufforderungen-Orchestrierung. Sie können es auch als unabhängiges ChatGPT-Indexierungsplugin zur Veröffentlichung erstellen.',
+    modelTitle: 'Sind Sie sicher, dass Sie die Einbettung stoppen möchten?',
+    modelContent: 'Wenn Sie die Verarbeitung später fortsetzen möchten, werden Sie dort weitermachen, wo Sie aufgehört haben.',
+    modelButtonConfirm: 'Bestätigen',
+    modelButtonCancel: 'Abbrechen',
+  },
+  firecrawl: {
+    apiKeyPlaceholder: 'API-Schlüssel von firecrawl.dev',
+    configFirecrawl: 'Konfigurieren von 🔥Firecrawl',
+    getApiKeyLinkText: 'Holen Sie sich Ihren API-Schlüssel von firecrawl.dev',
+  },
+}
+
+export default translation
diff --git a/web/i18n/de-DE/dataset-documents.ts b/web/i18n/de-DE/dataset-documents.ts
index e72f808392..114a73544d 100644
--- a/web/i18n/de-DE/dataset-documents.ts
+++ b/web/i18n/de-DE/dataset-documents.ts
@@ -1,349 +1,352 @@
-const translation = {
-  list: {
-    title: 'Dokumente',
-    desc: 'Alle Dateien des Wissens werden hier angezeigt, und das gesamte Wissen kann mit Dify-Zitaten verknüpft oder über das Chat-Plugin indiziert werden.',
-    addFile: 'Datei hinzufügen',
-    addPages: 'Seiten hinzufügen',
-    table: {
-      header: {
-        fileName: 'DATEINAME',
-        words: 'WÖRTER',
-        hitCount: 'SUCHANFRAGEN',
-        uploadTime: 'HOCHLADEZEIT',
-        status: 'STATUS',
-        action: 'AKTION',
-      },
-    },
-    action: {
-      uploadFile: 'Neue Datei hochladen',
-      settings: 'Segment-Einstellungen',
-      addButton: 'Chunk hinzufügen',
-      add: 'Einen Chunk hinzufügen',
-      batchAdd: 'Batch hinzufügen',
-      archive: 'Archivieren',
-      unarchive: 'Archivierung aufheben',
-      delete: 'Löschen',
-      enableWarning: 'Archivierte Datei kann nicht aktiviert werden',
-      sync: 'Synchronisieren',
-    },
-    index: {
-      enable: 'Aktivieren',
-      disable: 'Deaktivieren',
-      all: 'Alle',
-      enableTip: 'Die Datei kann indiziert werden',
-      disableTip: 'Die Datei kann nicht indiziert werden',
-    },
-    status: {
-      queuing: 'In Warteschlange',
-      indexing: 'Indizierung',
-      paused:
'Pausiert', - error: 'Fehler', - available: 'Verfügbar', - enabled: 'Aktiviert', - disabled: 'Deaktiviert', - archived: 'Archiviert', - }, - empty: { - title: 'Es gibt noch keine Dokumentation', - upload: { - tip: 'Sie können Dateien hochladen, von der Website oder von Web-Apps wie Notion, GitHub usw. synchronisieren.', - }, - sync: { - tip: 'Dify wird periodisch Dateien von Ihrem Notion herunterladen und die Verarbeitung abschließen.', - }, - }, - delete: { - title: 'Sind Sie sicher, dass Sie löschen möchten?', - content: 'Wenn Sie die Verarbeitung später fortsetzen müssen, werden Sie dort weitermachen, wo Sie aufgehört haben', - }, - batchModal: { - title: 'Chunks in Batch hinzufügen', - csvUploadTitle: 'Ziehen Sie Ihre CSV-Datei hierher oder ', - browse: 'durchsuchen', - tip: 'Die CSV-Datei muss der folgenden Struktur entsprechen:', - question: 'Frage', - answer: 'Antwort', - contentTitle: 'Chunk-Inhalt', - content: 'Inhalt', - template: 'Laden Sie die Vorlage hier herunter', - cancel: 'Abbrechen', - run: 'Batch ausführen', - runError: 'Batch-Ausführung fehlgeschlagen', - processing: 'In Batch-Verarbeitung', - completed: 'Import abgeschlossen', - error: 'Importfehler', - ok: 'OK', - }, - }, - metadata: { - title: 'Metadaten', - desc: 'Das Kennzeichnen von Metadaten für Dokumente ermöglicht es der KI, sie rechtzeitig zu erreichen und die Quelle der Referenzen für die Benutzer offenzulegen.', - dateTimeFormat: 'MMMM D, YYYY hh:mm A', - docTypeSelectTitle: 'Bitte wählen Sie einen Dokumenttyp', - docTypeChangeTitle: 'Dokumenttyp ändern', - docTypeSelectWarning: - 'Wenn der Dokumenttyp geändert wird, werden die jetzt ausgefüllten Metadaten nicht mehr erhalten bleiben', - firstMetaAction: 'Los geht\'s', - placeholder: { - add: 'Hinzufügen ', - select: 'Auswählen ', - }, - source: { - upload_file: 'Datei hochladen', - notion: 'Von Notion synchronisieren', - github: 'Von Github synchronisieren', - }, - type: { - book: 'Buch', - webPage: 'Webseite', - paper: 'Aufsatz', - socialMediaPost: 'Social Media Beitrag', - personalDocument: 'Persönliches Dokument', - businessDocument: 'Geschäftsdokument', - IMChat: 'IM Chat', - wikipediaEntry: 'Wikipedia-Eintrag', - notion: 'Von Notion synchronisieren', - github: 'Von Github synchronisieren', - technicalParameters: 'Technische Parameter', - }, - field: { - processRule: { - processDoc: 'Dokument verarbeiten', - segmentRule: 'Chunk-Regel', - segmentLength: 'Chunk-Länge', - processClean: 'Textverarbeitung bereinigen', - }, - book: { - title: 'Titel', - language: 'Sprache', - author: 'Autor', - publisher: 'Verlag', - publicationDate: 'Veröffentlichungsdatum', - ISBN: 'ISBN', - category: 'Kategorie', - }, - webPage: { - title: 'Titel', - url: 'URL', - language: 'Sprache', - authorPublisher: 'Autor/Verlag', - publishDate: 'Veröffentlichungsdatum', - topicsKeywords: 'Themen/Schlüsselwörter', - description: 'Beschreibung', - }, - paper: { - title: 'Titel', - language: 'Sprache', - author: 'Autor', - publishDate: 'Veröffentlichungsdatum', - journalConferenceName: 'Zeitschrift/Konferenzname', - volumeIssuePage: 'Band/Ausgabe/Seite', - DOI: 'DOI', - topicsKeywords: 'Themen/Schlüsselwörter', - abstract: 'Zusammenfassung', - }, - socialMediaPost: { - platform: 'Plattform', - authorUsername: 'Autor/Benutzername', - publishDate: 'Veröffentlichungsdatum', - postURL: 'Beitrags-URL', - topicsTags: 'Themen/Tags', - }, - personalDocument: { - title: 'Titel', - author: 'Autor', - creationDate: 'Erstellungsdatum', - lastModifiedDate: 'Letztes Änderungsdatum', - documentType: 
'Dokumenttyp', - tagsCategory: 'Tags/Kategorie', - }, - businessDocument: { - title: 'Titel', - author: 'Autor', - creationDate: 'Erstellungsdatum', - lastModifiedDate: 'Letztes Änderungsdatum', - documentType: 'Dokumenttyp', - departmentTeam: 'Abteilung/Team', - }, - IMChat: { - chatPlatform: 'Chat-Plattform', - chatPartiesGroupName: 'Chat-Parteien/Gruppenname', - participants: 'Teilnehmer', - startDate: 'Startdatum', - endDate: 'Enddatum', - topicsKeywords: 'Themen/Schlüsselwörter', - fileType: 'Dateityp', - }, - wikipediaEntry: { - title: 'Titel', - language: 'Sprache', - webpageURL: 'Webseiten-URL', - editorContributor: 'Editor/Beitragender', - lastEditDate: 'Letztes Bearbeitungsdatum', - summaryIntroduction: 'Zusammenfassung/Einführung', - }, - notion: { - title: 'Titel', - language: 'Sprache', - author: 'Autor', - createdTime: 'Erstellungszeit', - lastModifiedTime: 'Letzte Änderungszeit', - url: 'URL', - tag: 'Tag', - description: 'Beschreibung', - }, - github: { - repoName: 'Repository-Name', - repoDesc: 'Repository-Beschreibung', - repoOwner: 'Repository-Eigentümer', - fileName: 'Dateiname', - filePath: 'Dateipfad', - programmingLang: 'Programmiersprache', - url: 'URL', - license: 'Lizenz', - lastCommitTime: 'Letzte Commit-Zeit', - lastCommitAuthor: 'Letzter Commit-Autor', - }, - originInfo: { - originalFilename: 'Originaldateiname', - originalFileSize: 'Originaldateigröße', - uploadDate: 'Hochladedatum', - lastUpdateDate: 'Letztes Änderungsdatum', - source: 'Quelle', - }, - technicalParameters: { - segmentSpecification: 'Chunk-Spezifikation', - segmentLength: 'Chunk-Länge', - avgParagraphLength: 'Durchschn. Absatzlänge', - paragraphs: 'Absätze', - hitCount: 'Abrufanzahl', - embeddingTime: 'Einbettungszeit', - embeddedSpend: 'Einbettungsausgaben', - }, - }, - languageMap: { - zh: 'Chinesisch', - en: 'Englisch', - es: 'Spanisch', - fr: 'Französisch', - de: 'Deutsch', - ja: 'Japanisch', - ko: 'Koreanisch', - ru: 'Russisch', - ar: 'Arabisch', - pt: 'Portugiesisch', - it: 'Italienisch', - nl: 'Niederländisch', - pl: 'Polnisch', - sv: 'Schwedisch', - tr: 'Türkisch', - he: 'Hebräisch', - hi: 'Hindi', - da: 'Dänisch', - fi: 'Finnisch', - no: 'Norwegisch', - hu: 'Ungarisch', - el: 'Griechisch', - cs: 'Tschechisch', - th: 'Thai', - id: 'Indonesisch', - }, - categoryMap: { - book: { - fiction: 'Fiktion', - biography: 'Biografie', - history: 'Geschichte', - science: 'Wissenschaft', - technology: 'Technologie', - education: 'Bildung', - philosophy: 'Philosophie', - religion: 'Religion', - socialSciences: 'Sozialwissenschaften', - art: 'Kunst', - travel: 'Reisen', - health: 'Gesundheit', - selfHelp: 'Selbsthilfe', - businessEconomics: 'Wirtschaft', - cooking: 'Kochen', - childrenYoungAdults: 'Kinder & Jugendliche', - comicsGraphicNovels: 'Comics & Grafische Romane', - poetry: 'Poesie', - drama: 'Drama', - other: 'Andere', - }, - personalDoc: { - notes: 'Notizen', - blogDraft: 'Blog-Entwurf', - diary: 'Tagebuch', - researchReport: 'Forschungsbericht', - bookExcerpt: 'Buchauszug', - schedule: 'Zeitplan', - list: 'Liste', - projectOverview: 'Projektübersicht', - photoCollection: 'Fotosammlung', - creativeWriting: 'Kreatives Schreiben', - codeSnippet: 'Code-Snippet', - designDraft: 'Design-Entwurf', - personalResume: 'Persönlicher Lebenslauf', - other: 'Andere', - }, - businessDoc: { - meetingMinutes: 'Protokolle', - researchReport: 'Forschungsbericht', - proposal: 'Vorschlag', - employeeHandbook: 'Mitarbeiterhandbuch', - trainingMaterials: 'Schulungsmaterialien', - requirementsDocument: 
'Anforderungsdokumentation', - designDocument: 'Design-Dokument', - productSpecification: 'Produktspezifikation', - financialReport: 'Finanzbericht', - marketAnalysis: 'Marktanalyse', - projectPlan: 'Projektplan', - teamStructure: 'Teamstruktur', - policiesProcedures: 'Richtlinien & Verfahren', - contractsAgreements: 'Verträge & Vereinbarungen', - emailCorrespondence: 'E-Mail-Korrespondenz', - other: 'Andere', - }, - }, - }, - embedding: { - processing: 'Einbettungsverarbeitung...', - paused: 'Einbettung pausiert', - completed: 'Einbettung abgeschlossen', - error: 'Einbettungsfehler', - docName: 'Dokument vorbereiten', - mode: 'Segmentierungsregel', - segmentLength: 'Chunk-Länge', - textCleaning: 'Textvordefinition und -bereinigung', - segments: 'Absätze', - highQuality: 'Hochwertiger Modus', - economy: 'Wirtschaftlicher Modus', - estimate: 'Geschätzter Verbrauch', - stop: 'Verarbeitung stoppen', - resume: 'Verarbeitung fortsetzen', - automatic: 'Automatisch', - custom: 'Benutzerdefiniert', - previewTip: 'Absatzvorschau ist nach Abschluss der Einbettung verfügbar', - }, - segment: { - paragraphs: 'Absätze', - keywords: 'Schlüsselwörter', - addKeyWord: 'Schlüsselwort hinzufügen', - keywordError: 'Die maximale Länge des Schlüsselworts beträgt 20', - characters: 'Zeichen', - hitCount: 'Abrufanzahl', - vectorHash: 'Vektor-Hash: ', - questionPlaceholder: 'Frage hier hinzufügen', - questionEmpty: 'Frage darf nicht leer sein', - answerPlaceholder: 'Antwort hier hinzufügen', - answerEmpty: 'Antwort darf nicht leer sein', - contentPlaceholder: 'Inhalt hier hinzufügen', - contentEmpty: 'Inhalt darf nicht leer sein', - newTextSegment: 'Neues Textsegment', - newQaSegment: 'Neues Q&A-Segment', - delete: 'Diesen Chunk löschen?', - }, -} - -export default translation +const translation = { + list: { + title: 'Dokumente', + desc: 'Alle Dateien des Wissens werden hier angezeigt, und das gesamte Wissen kann mit Dify-Zitaten verknüpft oder über das Chat-Plugin indiziert werden.', + addFile: 'Datei hinzufügen', + addPages: 'Seiten hinzufügen', + table: { + header: { + fileName: 'DATEINAME', + words: 'WÖRTER', + hitCount: 'SUCHANFRAGEN', + uploadTime: 'HOCHLADEZEIT', + status: 'STATUS', + action: 'AKTION', + }, + name: 'Name', + rename: 'Umbenennen', + }, + action: { + uploadFile: 'Neue Datei hochladen', + settings: 'Segment-Einstellungen', + addButton: 'Chunk hinzufügen', + add: 'Einen Chunk hinzufügen', + batchAdd: 'Batch hinzufügen', + archive: 'Archivieren', + unarchive: 'Archivierung aufheben', + delete: 'Löschen', + enableWarning: 'Archivierte Datei kann nicht aktiviert werden', + sync: 'Synchronisieren', + }, + index: { + enable: 'Aktivieren', + disable: 'Deaktivieren', + all: 'Alle', + enableTip: 'Die Datei kann indiziert werden', + disableTip: 'Die Datei kann nicht indiziert werden', + }, + status: { + queuing: 'In Warteschlange', + indexing: 'Indizierung', + paused: 'Pausiert', + error: 'Fehler', + available: 'Verfügbar', + enabled: 'Aktiviert', + disabled: 'Deaktiviert', + archived: 'Archiviert', + }, + empty: { + title: 'Es gibt noch keine Dokumentation', + upload: { + tip: 'Sie können Dateien hochladen, von der Website oder von Web-Apps wie Notion, GitHub usw. 
synchronisieren.', + }, + sync: { + tip: 'Dify wird periodisch Dateien von Ihrem Notion herunterladen und die Verarbeitung abschließen.', + }, + }, + delete: { + title: 'Sind Sie sicher, dass Sie löschen möchten?', + content: 'Wenn Sie die Verarbeitung später fortsetzen müssen, werden Sie dort weitermachen, wo Sie aufgehört haben', + }, + batchModal: { + title: 'Chunks in Batch hinzufügen', + csvUploadTitle: 'Ziehen Sie Ihre CSV-Datei hierher oder ', + browse: 'durchsuchen', + tip: 'Die CSV-Datei muss der folgenden Struktur entsprechen:', + question: 'Frage', + answer: 'Antwort', + contentTitle: 'Chunk-Inhalt', + content: 'Inhalt', + template: 'Laden Sie die Vorlage hier herunter', + cancel: 'Abbrechen', + run: 'Batch ausführen', + runError: 'Batch-Ausführung fehlgeschlagen', + processing: 'In Batch-Verarbeitung', + completed: 'Import abgeschlossen', + error: 'Importfehler', + ok: 'OK', + }, + addUrl: 'URL hinzufügen', + }, + metadata: { + title: 'Metadaten', + desc: 'Das Kennzeichnen von Metadaten für Dokumente ermöglicht es der KI, sie rechtzeitig zu erreichen und die Quelle der Referenzen für die Benutzer offenzulegen.', + dateTimeFormat: 'MMMM D, YYYY hh:mm A', + docTypeSelectTitle: 'Bitte wählen Sie einen Dokumenttyp', + docTypeChangeTitle: 'Dokumenttyp ändern', + docTypeSelectWarning: + 'Wenn der Dokumenttyp geändert wird, werden die jetzt ausgefüllten Metadaten nicht mehr erhalten bleiben', + firstMetaAction: 'Los geht\'s', + placeholder: { + add: 'Hinzufügen ', + select: 'Auswählen ', + }, + source: { + upload_file: 'Datei hochladen', + notion: 'Von Notion synchronisieren', + github: 'Von Github synchronisieren', + }, + type: { + book: 'Buch', + webPage: 'Webseite', + paper: 'Aufsatz', + socialMediaPost: 'Social Media Beitrag', + personalDocument: 'Persönliches Dokument', + businessDocument: 'Geschäftsdokument', + IMChat: 'IM Chat', + wikipediaEntry: 'Wikipedia-Eintrag', + notion: 'Von Notion synchronisieren', + github: 'Von Github synchronisieren', + technicalParameters: 'Technische Parameter', + }, + field: { + processRule: { + processDoc: 'Dokument verarbeiten', + segmentRule: 'Chunk-Regel', + segmentLength: 'Chunk-Länge', + processClean: 'Textverarbeitung bereinigen', + }, + book: { + title: 'Titel', + language: 'Sprache', + author: 'Autor', + publisher: 'Verlag', + publicationDate: 'Veröffentlichungsdatum', + ISBN: 'ISBN', + category: 'Kategorie', + }, + webPage: { + title: 'Titel', + url: 'URL', + language: 'Sprache', + authorPublisher: 'Autor/Verlag', + publishDate: 'Veröffentlichungsdatum', + topicsKeywords: 'Themen/Schlüsselwörter', + description: 'Beschreibung', + }, + paper: { + title: 'Titel', + language: 'Sprache', + author: 'Autor', + publishDate: 'Veröffentlichungsdatum', + journalConferenceName: 'Zeitschrift/Konferenzname', + volumeIssuePage: 'Band/Ausgabe/Seite', + DOI: 'DOI', + topicsKeywords: 'Themen/Schlüsselwörter', + abstract: 'Zusammenfassung', + }, + socialMediaPost: { + platform: 'Plattform', + authorUsername: 'Autor/Benutzername', + publishDate: 'Veröffentlichungsdatum', + postURL: 'Beitrags-URL', + topicsTags: 'Themen/Tags', + }, + personalDocument: { + title: 'Titel', + author: 'Autor', + creationDate: 'Erstellungsdatum', + lastModifiedDate: 'Letztes Änderungsdatum', + documentType: 'Dokumenttyp', + tagsCategory: 'Tags/Kategorie', + }, + businessDocument: { + title: 'Titel', + author: 'Autor', + creationDate: 'Erstellungsdatum', + lastModifiedDate: 'Letztes Änderungsdatum', + documentType: 'Dokumenttyp', + departmentTeam: 'Abteilung/Team', + }, + IMChat: 
{ + chatPlatform: 'Chat-Plattform', + chatPartiesGroupName: 'Chat-Parteien/Gruppenname', + participants: 'Teilnehmer', + startDate: 'Startdatum', + endDate: 'Enddatum', + topicsKeywords: 'Themen/Schlüsselwörter', + fileType: 'Dateityp', + }, + wikipediaEntry: { + title: 'Titel', + language: 'Sprache', + webpageURL: 'Webseiten-URL', + editorContributor: 'Editor/Beitragender', + lastEditDate: 'Letztes Bearbeitungsdatum', + summaryIntroduction: 'Zusammenfassung/Einführung', + }, + notion: { + title: 'Titel', + language: 'Sprache', + author: 'Autor', + createdTime: 'Erstellungszeit', + lastModifiedTime: 'Letzte Änderungszeit', + url: 'URL', + tag: 'Tag', + description: 'Beschreibung', + }, + github: { + repoName: 'Repository-Name', + repoDesc: 'Repository-Beschreibung', + repoOwner: 'Repository-Eigentümer', + fileName: 'Dateiname', + filePath: 'Dateipfad', + programmingLang: 'Programmiersprache', + url: 'URL', + license: 'Lizenz', + lastCommitTime: 'Letzte Commit-Zeit', + lastCommitAuthor: 'Letzter Commit-Autor', + }, + originInfo: { + originalFilename: 'Originaldateiname', + originalFileSize: 'Originaldateigröße', + uploadDate: 'Hochladedatum', + lastUpdateDate: 'Letztes Änderungsdatum', + source: 'Quelle', + }, + technicalParameters: { + segmentSpecification: 'Chunk-Spezifikation', + segmentLength: 'Chunk-Länge', + avgParagraphLength: 'Durchschn. Absatzlänge', + paragraphs: 'Absätze', + hitCount: 'Abrufanzahl', + embeddingTime: 'Einbettungszeit', + embeddedSpend: 'Einbettungsausgaben', + }, + }, + languageMap: { + zh: 'Chinesisch', + en: 'Englisch', + es: 'Spanisch', + fr: 'Französisch', + de: 'Deutsch', + ja: 'Japanisch', + ko: 'Koreanisch', + ru: 'Russisch', + ar: 'Arabisch', + pt: 'Portugiesisch', + it: 'Italienisch', + nl: 'Niederländisch', + pl: 'Polnisch', + sv: 'Schwedisch', + tr: 'Türkisch', + he: 'Hebräisch', + hi: 'Hindi', + da: 'Dänisch', + fi: 'Finnisch', + no: 'Norwegisch', + hu: 'Ungarisch', + el: 'Griechisch', + cs: 'Tschechisch', + th: 'Thai', + id: 'Indonesisch', + }, + categoryMap: { + book: { + fiction: 'Fiktion', + biography: 'Biografie', + history: 'Geschichte', + science: 'Wissenschaft', + technology: 'Technologie', + education: 'Bildung', + philosophy: 'Philosophie', + religion: 'Religion', + socialSciences: 'Sozialwissenschaften', + art: 'Kunst', + travel: 'Reisen', + health: 'Gesundheit', + selfHelp: 'Selbsthilfe', + businessEconomics: 'Wirtschaft', + cooking: 'Kochen', + childrenYoungAdults: 'Kinder & Jugendliche', + comicsGraphicNovels: 'Comics & Grafische Romane', + poetry: 'Poesie', + drama: 'Drama', + other: 'Andere', + }, + personalDoc: { + notes: 'Notizen', + blogDraft: 'Blog-Entwurf', + diary: 'Tagebuch', + researchReport: 'Forschungsbericht', + bookExcerpt: 'Buchauszug', + schedule: 'Zeitplan', + list: 'Liste', + projectOverview: 'Projektübersicht', + photoCollection: 'Fotosammlung', + creativeWriting: 'Kreatives Schreiben', + codeSnippet: 'Code-Snippet', + designDraft: 'Design-Entwurf', + personalResume: 'Persönlicher Lebenslauf', + other: 'Andere', + }, + businessDoc: { + meetingMinutes: 'Protokolle', + researchReport: 'Forschungsbericht', + proposal: 'Vorschlag', + employeeHandbook: 'Mitarbeiterhandbuch', + trainingMaterials: 'Schulungsmaterialien', + requirementsDocument: 'Anforderungsdokumentation', + designDocument: 'Design-Dokument', + productSpecification: 'Produktspezifikation', + financialReport: 'Finanzbericht', + marketAnalysis: 'Marktanalyse', + projectPlan: 'Projektplan', + teamStructure: 'Teamstruktur', + policiesProcedures: 'Richtlinien & 
Verfahren',
+        contractsAgreements: 'Verträge & Vereinbarungen',
+        emailCorrespondence: 'E-Mail-Korrespondenz',
+        other: 'Andere',
+      },
+    },
+  },
+  embedding: {
+    processing: 'Einbettungsverarbeitung...',
+    paused: 'Einbettung pausiert',
+    completed: 'Einbettung abgeschlossen',
+    error: 'Einbettungsfehler',
+    docName: 'Dokument vorbereiten',
+    mode: 'Segmentierungsregel',
+    segmentLength: 'Chunk-Länge',
+    textCleaning: 'Textvordefinition und -bereinigung',
+    segments: 'Absätze',
+    highQuality: 'Hochwertiger Modus',
+    economy: 'Wirtschaftlicher Modus',
+    estimate: 'Geschätzter Verbrauch',
+    stop: 'Verarbeitung stoppen',
+    resume: 'Verarbeitung fortsetzen',
+    automatic: 'Automatisch',
+    custom: 'Benutzerdefiniert',
+    previewTip: 'Absatzvorschau ist nach Abschluss der Einbettung verfügbar',
+  },
+  segment: {
+    paragraphs: 'Absätze',
+    keywords: 'Schlüsselwörter',
+    addKeyWord: 'Schlüsselwort hinzufügen',
+    keywordError: 'Die maximale Länge des Schlüsselworts beträgt 20',
+    characters: 'Zeichen',
+    hitCount: 'Abrufanzahl',
+    vectorHash: 'Vektor-Hash: ',
+    questionPlaceholder: 'Frage hier hinzufügen',
+    questionEmpty: 'Frage darf nicht leer sein',
+    answerPlaceholder: 'Antwort hier hinzufügen',
+    answerEmpty: 'Antwort darf nicht leer sein',
+    contentPlaceholder: 'Inhalt hier hinzufügen',
+    contentEmpty: 'Inhalt darf nicht leer sein',
+    newTextSegment: 'Neues Textsegment',
+    newQaSegment: 'Neues Q&A-Segment',
+    delete: 'Diesen Chunk löschen?',
+  },
+}
+
+export default translation
diff --git a/web/i18n/de-DE/dataset-hit-testing.ts b/web/i18n/de-DE/dataset-hit-testing.ts
index 89f90a57a7..baf88016b3 100644
--- a/web/i18n/de-DE/dataset-hit-testing.ts
+++ b/web/i18n/de-DE/dataset-hit-testing.ts
@@ -1,28 +1,28 @@
-const translation = {
-  title: 'Abruf-Test',
-  desc: 'Testen Sie die Treffereffektivität des Wissens anhand des gegebenen Abfragetextes.',
-  dateTimeFormat: 'MM/TT/JJJJ hh:mm A',
-  recents: 'Kürzlich',
-  table: {
-    header: {
-      source: 'Quelle',
-      text: 'Text',
-      time: 'Zeit',
-    },
-  },
-  input: {
-    title: 'Quelltext',
-    placeholder: 'Bitte geben Sie einen Text ein, ein kurzer aussagekräftiger Satz wird empfohlen.',
-    countWarning: 'Bis zu 200 Zeichen.',
-    indexWarning: 'Nur Wissen hoher Qualität.',
-    testing: 'Testen',
-  },
-  hit: {
-    title: 'ABRUFPARAGRAFEN',
-    emptyTip: 'Ergebnisse des Abruf-Tests werden hier angezeigt',
-  },
-  noRecentTip: 'Keine kürzlichen Abfrageergebnisse hier',
-  viewChart: 'VEKTORDIAGRAMM ansehen',
-}
-
-export default translation
+const translation = {
+  title: 'Abruf-Test',
+  desc: 'Testen Sie die Treffereffektivität des Wissens anhand des gegebenen Abfragetextes.',
+  dateTimeFormat: 'MM/DD/YYYY hh:mm A',
+  recents: 'Kürzlich',
+  table: {
+    header: {
+      source: 'Quelle',
+      text: 'Text',
+      time: 'Zeit',
+    },
+  },
+  input: {
+    title: 'Quelltext',
+    placeholder: 'Bitte geben Sie einen Text ein, ein kurzer aussagekräftiger Satz wird empfohlen.',
+    countWarning: 'Bis zu 200 Zeichen.',
+    indexWarning: 'Nur Wissen hoher Qualität.',
+    testing: 'Testen',
+  },
+  hit: {
+    title: 'ABRUFPARAGRAFEN',
+    emptyTip: 'Ergebnisse des Abruf-Tests werden hier angezeigt',
+  },
+  noRecentTip: 'Keine kürzlichen Abfrageergebnisse hier',
+  viewChart: 'VEKTORDIAGRAMM ansehen',
+}
+
+export default translation
diff --git a/web/i18n/de-DE/dataset-settings.ts b/web/i18n/de-DE/dataset-settings.ts
index c986334f15..b29e778075 100644
--- a/web/i18n/de-DE/dataset-settings.ts
+++ b/web/i18n/de-DE/dataset-settings.ts
@@ -1,33 +1,35 @@
-const translation = {
-  title: 'Wissenseinstellungen',
-  desc:
'Hier können Sie die Eigenschaften und Arbeitsweisen des Wissens anpassen.', - form: { - name: 'Wissensname', - namePlaceholder: 'Bitte geben Sie den Namen des Wissens ein', - nameError: 'Name darf nicht leer sein', - desc: 'Wissensbeschreibung', - descInfo: 'Bitte schreiben Sie eine klare textuelle Beschreibung, um den Inhalt des Wissens zu umreißen. Diese Beschreibung wird als Grundlage für die Auswahl aus mehreren Wissensdatenbanken zur Inferenz verwendet.', - descPlaceholder: 'Beschreiben Sie, was in diesem Wissen enthalten ist. Eine detaillierte Beschreibung ermöglicht es der KI, zeitnah auf den Inhalt des Wissens zuzugreifen. Wenn leer, verwendet Dify die Standard-Treffstrategie.', - descWrite: 'Erfahren Sie, wie man eine gute Wissensbeschreibung schreibt.', - permissions: 'Berechtigungen', - permissionsOnlyMe: 'Nur ich', - permissionsAllMember: 'Alle Teammitglieder', - indexMethod: 'Indexierungsmethode', - indexMethodHighQuality: 'Hohe Qualität', - indexMethodHighQualityTip: 'Den Embedding-Modell zur Verarbeitung aufrufen, um bei Benutzeranfragen eine höhere Genauigkeit zu bieten.', - indexMethodEconomy: 'Ökonomisch', - indexMethodEconomyTip: 'Verwendet Offline-Vektor-Engines, Schlagwortindizes usw., um die Genauigkeit ohne Tokenverbrauch zu reduzieren', - embeddingModel: 'Einbettungsmodell', - embeddingModelTip: 'Ändern Sie das eingebettete Modell, bitte gehen Sie zu ', - embeddingModelTipLink: 'Einstellungen', - retrievalSetting: { - title: 'Abrufeinstellung', - learnMore: 'Mehr erfahren', - description: ' über die Abrufmethode.', - longDescription: ' über die Abrufmethode, dies kann jederzeit in den Wissenseinstellungen geändert werden.', - }, - save: 'Speichern', - }, -} - -export default translation +const translation = { + title: 'Wissenseinstellungen', + desc: 'Hier können Sie die Eigenschaften und Arbeitsweisen des Wissens anpassen.', + form: { + name: 'Wissensname', + namePlaceholder: 'Bitte geben Sie den Namen des Wissens ein', + nameError: 'Name darf nicht leer sein', + desc: 'Wissensbeschreibung', + descInfo: 'Bitte schreiben Sie eine klare textuelle Beschreibung, um den Inhalt des Wissens zu umreißen. Diese Beschreibung wird als Grundlage für die Auswahl aus mehreren Wissensdatenbanken zur Inferenz verwendet.', + descPlaceholder: 'Beschreiben Sie, was in diesem Wissen enthalten ist. Eine detaillierte Beschreibung ermöglicht es der KI, zeitnah auf den Inhalt des Wissens zuzugreifen. 
Wenn leer, verwendet Dify die Standard-Treffstrategie.',
+    descWrite: 'Erfahren Sie, wie man eine gute Wissensbeschreibung schreibt.',
+    permissions: 'Berechtigungen',
+    permissionsOnlyMe: 'Nur ich',
+    permissionsAllMember: 'Alle Teammitglieder',
+    indexMethod: 'Indexierungsmethode',
+    indexMethodHighQuality: 'Hohe Qualität',
+    indexMethodHighQualityTip: 'Das Embedding-Modell zur Verarbeitung aufrufen, um bei Benutzeranfragen eine höhere Genauigkeit zu bieten.',
+    indexMethodEconomy: 'Ökonomisch',
+    indexMethodEconomyTip: 'Verwendet Offline-Vektor-Engines, Schlagwortindizes usw., um die Genauigkeit ohne Tokenverbrauch zu reduzieren',
+    embeddingModel: 'Einbettungsmodell',
+    embeddingModelTip: 'Um das Embedding-Modell zu ändern, gehen Sie bitte zu ',
+    embeddingModelTipLink: 'Einstellungen',
+    retrievalSetting: {
+      title: 'Abrufeinstellung',
+      learnMore: 'Mehr erfahren',
+      description: ' über die Abrufmethode.',
+      longDescription: ' über die Abrufmethode, dies kann jederzeit in den Wissenseinstellungen geändert werden.',
+    },
+    save: 'Speichern',
+    permissionsInvitedMembers: 'Bestimmte Teammitglieder',
+    me: '(Sie)',
+  },
+}
+
+export default translation
diff --git a/web/i18n/de-DE/dataset.ts b/web/i18n/de-DE/dataset.ts
index c6586ceee8..6462e56ded 100644
--- a/web/i18n/de-DE/dataset.ts
+++ b/web/i18n/de-DE/dataset.ts
@@ -1,76 +1,77 @@
-const translation = {
-  knowledge: 'Wissen',
-  documentCount: ' Dokumente',
-  wordCount: ' k Wörter',
-  appCount: ' verknüpfte Apps',
-  createDataset: 'Wissen erstellen',
-  createDatasetIntro: 'Importiere deine eigenen Textdaten oder schreibe Daten in Echtzeit über Webhook für die LLM-Kontextverbesserung.',
-  deleteDatasetConfirmTitle: 'Dieses Wissen löschen?',
-  deleteDatasetConfirmContent:
-    'Das Löschen des Wissens ist unwiderruflich. Benutzer werden nicht mehr auf Ihr Wissen zugreifen können und alle Eingabeaufforderungen, Konfigurationen und Protokolle werden dauerhaft gelöscht.',
-  datasetUsedByApp: 'Das Wissen wird von einigen Apps verwendet. Apps werden dieses Wissen nicht mehr nutzen können, und alle Prompt-Konfigurationen und Protokolle werden dauerhaft gelöscht.',
-  datasetDeleted: 'Wissen gelöscht',
-  datasetDeleteFailed: 'Löschen des Wissens fehlgeschlagen',
-  didYouKnow: 'Wusstest du schon?',
-  intro1: 'Das Wissen kann in die Dify-Anwendung ',
-  intro2: 'als Kontext',
-  intro3: ',',
-  intro4: 'oder es ',
-  intro5: 'kann erstellt werden',
-  intro6: ' als ein eigenständiges ChatGPT-Index-Plugin zum Veröffentlichen',
-  unavailable: 'Nicht verfügbar',
-  unavailableTip: 'Einbettungsmodell ist nicht verfügbar, das Standard-Einbettungsmodell muss konfiguriert werden',
-  datasets: 'WISSEN',
-  datasetsApi: 'API',
-  retrieval: {
-    semantic_search: {
-      title: 'Vektorsuche',
-      description: 'Erzeuge Abfrage-Einbettungen und suche nach dem Textstück, das seiner Vektorrepräsentation am ähnlichsten ist.',
-    },
-    full_text_search: {
-      title: 'Volltextsuche',
-      description: 'Indiziere alle Begriffe im Dokument, sodass Benutzer jeden Begriff suchen und den relevanten Textabschnitt finden können, der diese Begriffe enthält.',
-    },
-    hybrid_search: {
-      title: 'Hybridsuche',
-      description: 'Führe Volltextsuche und Vektorsuchen gleichzeitig aus, ordne neu, um die beste Übereinstimmung für die Abfrage des Benutzers auszuwählen. Konfiguration des Rerank-Modell-APIs ist notwendig.',
-      recommend: 'Empfehlen',
-    },
-    invertedIndex: {
-      title: 'Invertierter Index',
-      description: 'Ein invertierter Index ist eine Struktur, die für effiziente Abfragen verwendet wird.
Organisiert nach Begriffen, zeigt jeder Begriff auf Dokumente oder Webseiten, die ihn enthalten.', - }, - change: 'Ändern', - changeRetrievalMethod: 'Abfragemethode ändern', - }, - docsFailedNotice: 'Dokumente konnten nicht indiziert werden', - retry: 'Wiederholen', - indexingTechnique: { - high_quality: 'HQ', - economy: 'ECO', - }, - indexingMethod: { - semantic_search: 'VEKTOR', - full_text_search: 'VOLLTEXT', - hybrid_search: 'HYBRID', - invertedIndex: 'INVERTIERT', - }, - mixtureHighQualityAndEconomicTip: 'Für die Mischung von hochwertigen und wirtschaftlichen Wissensbasen ist das Rerank-Modell erforderlich.', - inconsistentEmbeddingModelTip: 'Das Rerank-Modell ist erforderlich, wenn die Embedding-Modelle der ausgewählten Wissensbasen inkonsistent sind.', - retrievalSettings: 'Abrufeinstellungen', - rerankSettings: 'Rerank-Einstellungen', - weightedScore: { - title: 'Gewichtete Bewertung', - description: 'Durch Anpassung der zugewiesenen Gewichte bestimmt diese Rerank-Strategie, ob semantische oder Schlüsselwort-Übereinstimmung priorisiert werden soll.', - semanticFirst: 'Semantik zuerst', - keywordFirst: 'Schlüsselwort zuerst', - customized: 'Angepasst', - semantic: 'Semantisch', - keyword: 'Schlüsselwort', - }, - nTo1RetrievalLegacy: 'N-zu-1-Abruf wird ab September offiziell eingestellt. Es wird empfohlen, den neuesten Multi-Pfad-Abruf zu verwenden, um bessere Ergebnisse zu erzielen.', - nTo1RetrievalLegacyLink: 'Mehr erfahren', - nTo1RetrievalLegacyLinkText: 'N-zu-1-Abruf wird im September offiziell eingestellt.', -} - -export default translation +const translation = { + knowledge: 'Wissen', + documentCount: ' Dokumente', + wordCount: ' k Wörter', + appCount: ' verknüpfte Apps', + createDataset: 'Wissen erstellen', + createDatasetIntro: 'Importiere deine eigenen Textdaten oder schreibe Daten in Echtzeit über Webhook für die LLM-Kontextverbesserung.', + deleteDatasetConfirmTitle: 'Dieses Wissen löschen?', + deleteDatasetConfirmContent: + 'Das Löschen des Wissens ist unwiderruflich. Benutzer werden nicht mehr auf Ihr Wissen zugreifen können und alle Eingabeaufforderungen, Konfigurationen und Protokolle werden dauerhaft gelöscht.', + datasetUsedByApp: 'Das Wissen wird von einigen Apps verwendet. Apps werden dieses Wissen nicht mehr nutzen können, und alle Prompt-Konfigurationen und Protokolle werden dauerhaft gelöscht.', + datasetDeleted: 'Wissen gelöscht', + datasetDeleteFailed: 'Löschen des Wissens fehlgeschlagen', + didYouKnow: 'Wusstest du schon?', + intro1: 'Das Wissen kann in die Dify-Anwendung ', + intro2: 'als Kontext', + intro3: ',', + intro4: 'oder es ', + intro5: 'kann erstellt werden', + intro6: ' als ein eigenständiges ChatGPT-Index-Plugin zum Veröffentlichen', + unavailable: 'Nicht verfügbar', + unavailableTip: 'Einbettungsmodell ist nicht verfügbar, das Standard-Einbettungsmodell muss konfiguriert werden', + datasets: 'WISSEN', + datasetsApi: 'API', + retrieval: { + semantic_search: { + title: 'Vektorsuche', + description: 'Erzeuge Abfrage-Einbettungen und suche nach dem Textstück, das seiner Vektorrepräsentation am ähnlichsten ist.', + }, + full_text_search: { + title: 'Volltextsuche', + description: 'Indiziere alle Begriffe im Dokument, sodass Benutzer jeden Begriff suchen und den relevanten Textabschnitt finden können, der diese Begriffe enthält.', + }, + hybrid_search: { + title: 'Hybridsuche', + description: 'Führe Volltextsuche und Vektorsuchen gleichzeitig aus, ordne neu, um die beste Übereinstimmung für die Abfrage des Benutzers auszuwählen. 
Konfiguration des Rerank-Modell-APIs ist notwendig.', + recommend: 'Empfehlen', + }, + invertedIndex: { + title: 'Invertierter Index', + description: 'Ein invertierter Index ist eine Struktur, die für effiziente Abfragen verwendet wird. Organisiert nach Begriffen, zeigt jeder Begriff auf Dokumente oder Webseiten, die ihn enthalten.', + }, + change: 'Ändern', + changeRetrievalMethod: 'Abfragemethode ändern', + }, + docsFailedNotice: 'Dokumente konnten nicht indiziert werden', + retry: 'Wiederholen', + indexingTechnique: { + high_quality: 'HQ', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'VEKTOR', + full_text_search: 'VOLLTEXT', + hybrid_search: 'HYBRID', + invertedIndex: 'INVERTIERT', + }, + mixtureHighQualityAndEconomicTip: 'Für die Mischung von hochwertigen und wirtschaftlichen Wissensbasen ist das Rerank-Modell erforderlich.', + inconsistentEmbeddingModelTip: 'Das Rerank-Modell ist erforderlich, wenn die Embedding-Modelle der ausgewählten Wissensbasen inkonsistent sind.', + retrievalSettings: 'Abrufeinstellungen', + rerankSettings: 'Rerank-Einstellungen', + weightedScore: { + title: 'Gewichtete Bewertung', + description: 'Durch Anpassung der zugewiesenen Gewichte bestimmt diese Rerank-Strategie, ob semantische oder Schlüsselwort-Übereinstimmung priorisiert werden soll.', + semanticFirst: 'Semantik zuerst', + keywordFirst: 'Schlüsselwort zuerst', + customized: 'Angepasst', + semantic: 'Semantisch', + keyword: 'Schlüsselwort', + }, + nTo1RetrievalLegacy: 'N-zu-1-Abruf wird ab September offiziell eingestellt. Es wird empfohlen, den neuesten Multi-Pfad-Abruf zu verwenden, um bessere Ergebnisse zu erzielen.', + nTo1RetrievalLegacyLink: 'Mehr erfahren', + nTo1RetrievalLegacyLinkText: 'N-zu-1-Abruf wird im September offiziell eingestellt.', + defaultRetrievalTip: 'Standardmäßig wird der Multi-Path-Abruf verwendet. 
Das Wissen wird aus mehreren Wissensdatenbanken abgerufen und dann neu eingestuft.', +} + +export default translation diff --git a/web/i18n/de-DE/explore.ts b/web/i18n/de-DE/explore.ts index 02fc90f89c..1e6f8c80d7 100644 --- a/web/i18n/de-DE/explore.ts +++ b/web/i18n/de-DE/explore.ts @@ -1,41 +1,41 @@ -const translation = { - title: 'Entdecken', - sidebar: { - discovery: 'Entdeckung', - chat: 'Chat', - workspace: 'Arbeitsbereich', - action: { - pin: 'Anheften', - unpin: 'Lösen', - rename: 'Umbenennen', - delete: 'Löschen', - }, - delete: { - title: 'App löschen', - content: 'Sind Sie sicher, dass Sie diese App löschen möchten?', - }, - }, - apps: { - title: 'Apps von Dify erkunden', - description: 'Nutzen Sie diese Vorlagen-Apps sofort oder passen Sie Ihre eigenen Apps basierend auf den Vorlagen an.', - allCategories: 'Alle Kategorien', - }, - appCard: { - addToWorkspace: 'Zum Arbeitsbereich hinzufügen', - customize: 'Anpassen', - }, - appCustomize: { - title: 'App aus {{name}} erstellen', - subTitle: 'App-Symbol & Name', - nameRequired: 'App-Name ist erforderlich', - }, - category: { - Assistant: 'Assistent', - Writing: 'Schreiben', - Translate: 'Übersetzen', - Programming: 'Programmieren', - HR: 'Personalwesen', - }, -} - -export default translation +const translation = { + title: 'Entdecken', + sidebar: { + discovery: 'Entdeckung', + chat: 'Chat', + workspace: 'Arbeitsbereich', + action: { + pin: 'Anheften', + unpin: 'Lösen', + rename: 'Umbenennen', + delete: 'Löschen', + }, + delete: { + title: 'App löschen', + content: 'Sind Sie sicher, dass Sie diese App löschen möchten?', + }, + }, + apps: { + title: 'Apps von Dify erkunden', + description: 'Nutzen Sie diese Vorlagen-Apps sofort oder passen Sie Ihre eigenen Apps basierend auf den Vorlagen an.', + allCategories: 'Alle Kategorien', + }, + appCard: { + addToWorkspace: 'Zum Arbeitsbereich hinzufügen', + customize: 'Anpassen', + }, + appCustomize: { + title: 'App aus {{name}} erstellen', + subTitle: 'App-Symbol & Name', + nameRequired: 'App-Name ist erforderlich', + }, + category: { + Assistant: 'Assistent', + Writing: 'Schreiben', + Translate: 'Übersetzen', + Programming: 'Programmieren', + HR: 'Personalwesen', + }, +} + +export default translation diff --git a/web/i18n/de-DE/login.ts b/web/i18n/de-DE/login.ts index f932f92976..4597228463 100644 --- a/web/i18n/de-DE/login.ts +++ b/web/i18n/de-DE/login.ts @@ -31,7 +31,7 @@ const translation = { pp: 'Datenschutzbestimmungen', tosDesc: 'Mit der Anmeldung stimmst du unseren', goToInit: 'Wenn du das Konto noch nicht initialisiert hast, gehe bitte zur Initialisierungsseite', - donthave: 'Hast du nicht?', + dontHave: 'Hast du nicht?', invalidInvitationCode: 'Ungültiger Einladungscode', accountAlreadyInited: 'Konto bereits initialisiert', forgotPassword: 'Passwort vergessen?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: 'Name wird benötigt', passwordEmpty: 'Passwort wird benötigt', passwordInvalid: 'Das Passwort muss Buchstaben und Zahlen enthalten und länger als 8 Zeichen sein', + passwordLengthInValid: 'Das Passwort muss mindestens 8 Zeichen lang sein', }, license: { tip: 'Bevor du mit Dify Community Edition beginnst, lies die', @@ -68,6 +69,7 @@ const translation = { activated: 'Jetzt anmelden', adminInitPassword: 'Admin-Initialpasswort', validate: 'Validieren', + sso: 'Mit SSO fortfahren', } export default translation diff --git a/web/i18n/de-DE/run-log.ts b/web/i18n/de-DE/run-log.ts index 7c0257b513..5f7610c68d 100644 --- a/web/i18n/de-DE/run-log.ts +++ b/web/i18n/de-DE/run-log.ts 
@@ -23,6 +23,7 @@ const translation = { tipLeft: 'Bitte gehen Sie zum ', Link: 'Detailpanel', tipRight: 'ansehen.', + link: 'Gruppe Detail', }, } diff --git a/web/i18n/de-DE/share-app.ts b/web/i18n/de-DE/share-app.ts index 6a35b959a5..5ea67dd08f 100644 --- a/web/i18n/de-DE/share-app.ts +++ b/web/i18n/de-DE/share-app.ts @@ -1,74 +1,74 @@ -const translation = { - common: { - welcome: '', - appUnavailable: 'App ist nicht verfügbar', - appUnkonwError: 'App ist nicht verfügbar', - }, - chat: { - newChat: 'Neuer Chat', - pinnedTitle: 'Angeheftet', - unpinnedTitle: 'Chats', - newChatDefaultName: 'Neues Gespräch', - resetChat: 'Gespräch zurücksetzen', - powerBy: 'Bereitgestellt von', - prompt: 'Aufforderung', - privatePromptConfigTitle: 'Konversationseinstellungen', - publicPromptConfigTitle: 'Anfängliche Aufforderung', - configStatusDes: 'Vor dem Start können Sie die Konversationseinstellungen ändern', - configDisabled: - 'Voreinstellungen der vorherigen Sitzung wurden für diese Sitzung verwendet.', - startChat: 'Chat starten', - privacyPolicyLeft: - 'Bitte lesen Sie die ', - privacyPolicyMiddle: - 'Datenschutzrichtlinien', - privacyPolicyRight: - ', die vom App-Entwickler bereitgestellt wurden.', - deleteConversation: { - title: 'Konversation löschen', - content: 'Sind Sie sicher, dass Sie diese Konversation löschen möchten?', - }, - tryToSolve: 'Versuchen zu lösen', - temporarySystemIssue: 'Entschuldigung, vorübergehendes Systemproblem.', - }, - generation: { - tabs: { - create: 'Einmal ausführen', - batch: 'Stapelverarbeitung', - saved: 'Gespeichert', - }, - savedNoData: { - title: 'Sie haben noch kein Ergebnis gespeichert!', - description: 'Beginnen Sie mit der Inhaltserstellung und finden Sie hier Ihre gespeicherten Ergebnisse.', - startCreateContent: 'Beginnen Sie mit der Inhaltserstellung', - }, - title: 'KI-Vervollständigung', - queryTitle: 'Abfrageinhalt', - completionResult: 'Vervollständigungsergebnis', - queryPlaceholder: 'Schreiben Sie Ihren Abfrageinhalt...', - run: 'Ausführen', - copy: 'Kopieren', - resultTitle: 'KI-Vervollständigung', - noData: 'KI wird Ihnen hier geben, was Sie möchten.', - csvUploadTitle: 'Ziehen Sie Ihre CSV-Datei hierher oder ', - browse: 'durchsuchen', - csvStructureTitle: 'Die CSV-Datei muss der folgenden Struktur entsprechen:', - downloadTemplate: 'Laden Sie die Vorlage hier herunter', - field: 'Feld', - batchFailed: { - info: '{{num}} fehlgeschlagene Ausführungen', - retry: 'Wiederholen', - outputPlaceholder: 'Kein Ausgabeanhalt', - }, - errorMsg: { - empty: 'Bitte geben Sie Inhalte in die hochgeladene Datei ein.', - fileStructNotMatch: 'Die hochgeladene CSV-Datei entspricht nicht der Struktur.', - emptyLine: 'Zeile {{rowIndex}} ist leer', - invalidLine: 'Zeile {{rowIndex}}: {{varName}} Wert darf nicht leer sein', - moreThanMaxLengthLine: 'Zeile {{rowIndex}}: {{varName}} Wert darf nicht mehr als {{maxLength}} Zeichen sein', - atLeastOne: 'Bitte geben Sie mindestens eine Zeile in die hochgeladene Datei ein.', - }, - }, -} - -export default translation +const translation = { + common: { + welcome: '', + appUnavailable: 'App ist nicht verfügbar', + appUnknownError: 'App ist nicht verfügbar', + }, + chat: { + newChat: 'Neuer Chat', + pinnedTitle: 'Angeheftet', + unpinnedTitle: 'Chats', + newChatDefaultName: 'Neues Gespräch', + resetChat: 'Gespräch zurücksetzen', + poweredBy: 'Bereitgestellt von', + prompt: 'Aufforderung', + privatePromptConfigTitle: 'Konversationseinstellungen', + publicPromptConfigTitle: 'Anfängliche Aufforderung', + configStatusDes: 'Vor dem 
Start können Sie die Konversationseinstellungen ändern', + configDisabled: + 'Voreinstellungen der vorherigen Sitzung wurden für diese Sitzung verwendet.', + startChat: 'Chat starten', + privacyPolicyLeft: + 'Bitte lesen Sie die ', + privacyPolicyMiddle: + 'Datenschutzrichtlinien', + privacyPolicyRight: + ', die vom App-Entwickler bereitgestellt wurden.', + deleteConversation: { + title: 'Konversation löschen', + content: 'Sind Sie sicher, dass Sie diese Konversation löschen möchten?', + }, + tryToSolve: 'Versuchen zu lösen', + temporarySystemIssue: 'Entschuldigung, vorübergehendes Systemproblem.', + }, + generation: { + tabs: { + create: 'Einmal ausführen', + batch: 'Stapelverarbeitung', + saved: 'Gespeichert', + }, + savedNoData: { + title: 'Sie haben noch kein Ergebnis gespeichert!', + description: 'Beginnen Sie mit der Inhaltserstellung und finden Sie hier Ihre gespeicherten Ergebnisse.', + startCreateContent: 'Beginnen Sie mit der Inhaltserstellung', + }, + title: 'KI-Vervollständigung', + queryTitle: 'Abfrageinhalt', + completionResult: 'Vervollständigungsergebnis', + queryPlaceholder: 'Schreiben Sie Ihren Abfrageinhalt...', + run: 'Ausführen', + copy: 'Kopieren', + resultTitle: 'KI-Vervollständigung', + noData: 'KI wird Ihnen hier geben, was Sie möchten.', + csvUploadTitle: 'Ziehen Sie Ihre CSV-Datei hierher oder ', + browse: 'durchsuchen', + csvStructureTitle: 'Die CSV-Datei muss der folgenden Struktur entsprechen:', + downloadTemplate: 'Laden Sie die Vorlage hier herunter', + field: 'Feld', + batchFailed: { + info: '{{num}} fehlgeschlagene Ausführungen', + retry: 'Wiederholen', + outputPlaceholder: 'Kein Ausgabeinhalt', + }, + errorMsg: { + empty: 'Bitte geben Sie Inhalte in die hochgeladene Datei ein.', + fileStructNotMatch: 'Die hochgeladene CSV-Datei entspricht nicht der Struktur.', + emptyLine: 'Zeile {{rowIndex}} ist leer', + invalidLine: 'Zeile {{rowIndex}}: {{varName}} Wert darf nicht leer sein', + moreThanMaxLengthLine: 'Zeile {{rowIndex}}: {{varName}} Wert darf nicht mehr als {{maxLength}} Zeichen sein', + atLeastOne: 'Bitte geben Sie mindestens eine Zeile in die hochgeladene Datei ein.', + }, + }, +} + +export default translation diff --git a/web/i18n/de-DE/tools.ts b/web/i18n/de-DE/tools.ts index a45d0da1b1..3be01b8350 100644 --- a/web/i18n/de-DE/tools.ts +++ b/web/i18n/de-DE/tools.ts @@ -1,119 +1,153 @@ -const translation = { - title: 'Werkzeuge', - createCustomTool: 'Eigenes Werkzeug erstellen', - type: { - all: 'Alle', - builtIn: 'Integriert', - custom: 'Benutzerdefiniert', - }, - contribute: { - line1: 'Ich interessiere mich dafür, ', - line2: 'Werkzeuge zu Dify beizutragen.', - viewGuide: 'Leitfaden anzeigen', - }, - author: 'Von', - auth: { - unauthorized: 'Zur Autorisierung', - authorized: 'Autorisiert', - setup: 'Autorisierung einrichten, um zu nutzen', - setupModalTitle: 'Autorisierung einrichten', - setupModalTitleDescription: 'Nach der Konfiguration der Anmeldeinformationen können alle Mitglieder im Arbeitsbereich dieses Werkzeug beim Orchestrieren von Anwendungen nutzen.', - }, - includeToolNum: '{{num}} Werkzeuge inkludiert', - addTool: 'Werkzeug hinzufügen', - createTool: { - title: 'Eigenes Werkzeug erstellen', - editAction: 'Konfigurieren', - editTitle: 'Eigenes Werkzeug bearbeiten', - name: 'Name', - toolNamePlaceHolder: 'Geben Sie den Werkzeugnamen ein', - schema: 'Schema', - schemaPlaceHolder: 'Geben Sie hier Ihr OpenAPI-Schema ein', - viewSchemaSpec: 'Die OpenAPI-Swagger-Spezifikation anzeigen', - importFromUrl: 'Von URL importieren', - 
importFromUrlPlaceHolder: 'https://...', - urlError: 'Bitte geben Sie eine gültige URL ein', - examples: 'Beispiele', - exampleOptions: { - json: 'Wetter(JSON)', - yaml: 'Pet Store(YAML)', - blankTemplate: 'Leere Vorlage', - }, - availableTools: { - title: 'Verfügbare Werkzeuge', - name: 'Name', - description: 'Beschreibung', - method: 'Methode', - path: 'Pfad', - action: 'Aktionen', - test: 'Test', - }, - authMethod: { - title: 'Autorisierungsmethode', - type: 'Autorisierungstyp', - keyTooltip: 'Http Header Key, Sie können es bei "Authorization" belassen, wenn Sie nicht wissen, was es ist, oder auf einen benutzerdefinierten Wert setzen', - types: { - none: 'Keine', - api_key: 'API-Key', - apiKeyPlaceholder: 'HTTP-Headername für API-Key', - apiValuePlaceholder: 'API-Key eingeben', - }, - key: 'Schlüssel', - value: 'Wert', - }, - authHeaderPrefix: { - title: 'Auth-Typ', - types: { - basic: 'Basic', - bearer: 'Bearer', - custom: 'Benutzerdefiniert', - }, - }, - privacyPolicy: 'Datenschutzrichtlinie', - privacyPolicyPlaceholder: 'Bitte Datenschutzrichtlinie eingeben', - customDisclaimer: 'Benutzer Haftungsausschluss', - customDisclaimerPlaceholder: 'Bitte benutzerdefinierten Haftungsausschluss eingeben', - deleteToolConfirmTitle: 'Löschen Sie dieses Werkzeug?', - deleteToolConfirmContent: 'Das Löschen des Werkzeugs ist irreversibel. Benutzer können Ihr Werkzeug nicht mehr verwenden.', - }, - test: { - title: 'Test', - parametersValue: 'Parameter & Wert', - parameters: 'Parameter', - value: 'Wert', - testResult: 'Testergebnisse', - testResultPlaceholder: 'Testergebnis wird hier angezeigt', - }, - thought: { - using: 'Nutzung', - used: 'Genutzt', - requestTitle: 'Anfrage an', - responseTitle: 'Antwort von', - }, - setBuiltInTools: { - info: 'Info', - setting: 'Einstellung', - toolDescription: 'Werkzeugbeschreibung', - parameters: 'Parameter', - string: 'Zeichenkette', - number: 'Nummer', - required: 'Erforderlich', - infoAndSetting: 'Info & Einstellungen', - }, - noCustomTool: { - title: 'Keine benutzerdefinierten Werkzeuge!', - content: 'Fügen Sie hier Ihre benutzerdefinierten Werkzeuge hinzu und verwalten Sie sie, um KI-Apps zu erstellen.', - createTool: 'Werkzeug erstellen', - }, - noSearchRes: { - title: 'Leider keine Ergebnisse!', - content: 'Wir konnten keine Werkzeuge finden, die Ihrer Suche entsprechen.', - reset: 'Suche zurücksetzen', - }, - builtInPromptTitle: 'Aufforderung', - toolRemoved: 'Werkzeug entfernt', - notAuthorized: 'Werkzeug nicht autorisiert', - howToGet: 'Wie erhält man', -} - -export default translation +const translation = { + title: 'Werkzeuge', + createCustomTool: 'Eigenes Werkzeug erstellen', + type: { + all: 'Alle', + builtIn: 'Integriert', + custom: 'Benutzerdefiniert', + workflow: 'Arbeitsablauf', + }, + contribute: { + line1: 'Ich interessiere mich dafür, ', + line2: 'Werkzeuge zu Dify beizutragen.', + viewGuide: 'Leitfaden anzeigen', + }, + author: 'Von', + auth: { + unauthorized: 'Zur Autorisierung', + authorized: 'Autorisiert', + setup: 'Autorisierung einrichten, um zu nutzen', + setupModalTitle: 'Autorisierung einrichten', + setupModalTitleDescription: 'Nach der Konfiguration der Anmeldeinformationen können alle Mitglieder im Arbeitsbereich dieses Werkzeug beim Orchestrieren von Anwendungen nutzen.', + }, + includeToolNum: '{{num}} Werkzeuge inkludiert', + addTool: 'Werkzeug hinzufügen', + createTool: { + title: 'Eigenes Werkzeug erstellen', + editAction: 'Konfigurieren', + editTitle: 'Eigenes Werkzeug bearbeiten', + name: 'Name', + toolNamePlaceHolder: 
'Geben Sie den Werkzeugnamen ein', + schema: 'Schema', + schemaPlaceHolder: 'Geben Sie hier Ihr OpenAPI-Schema ein', + viewSchemaSpec: 'Die OpenAPI-Swagger-Spezifikation anzeigen', + importFromUrl: 'Von URL importieren', + importFromUrlPlaceHolder: 'https://...', + urlError: 'Bitte geben Sie eine gültige URL ein', + examples: 'Beispiele', + exampleOptions: { + json: 'Wetter(JSON)', + yaml: 'Pet Store(YAML)', + blankTemplate: 'Leere Vorlage', + }, + availableTools: { + title: 'Verfügbare Werkzeuge', + name: 'Name', + description: 'Beschreibung', + method: 'Methode', + path: 'Pfad', + action: 'Aktionen', + test: 'Test', + }, + authMethod: { + title: 'Autorisierungsmethode', + type: 'Autorisierungstyp', + keyTooltip: 'Http Header Key, Sie können es bei "Authorization" belassen, wenn Sie nicht wissen, was es ist, oder auf einen benutzerdefinierten Wert setzen', + types: { + none: 'Keine', + api_key: 'API-Key', + apiKeyPlaceholder: 'HTTP-Headername für API-Key', + apiValuePlaceholder: 'API-Key eingeben', + }, + key: 'Schlüssel', + value: 'Wert', + }, + authHeaderPrefix: { + title: 'Auth-Typ', + types: { + basic: 'Basic', + bearer: 'Bearer', + custom: 'Benutzerdefiniert', + }, + }, + privacyPolicy: 'Datenschutzrichtlinie', + privacyPolicyPlaceholder: 'Bitte Datenschutzrichtlinie eingeben', + customDisclaimer: 'Benutzerdefinierter Haftungsausschluss', + customDisclaimerPlaceholder: 'Bitte benutzerdefinierten Haftungsausschluss eingeben', + deleteToolConfirmTitle: 'Löschen Sie dieses Werkzeug?', + deleteToolConfirmContent: 'Das Löschen des Werkzeugs ist irreversibel. Benutzer können Ihr Werkzeug nicht mehr verwenden.', + toolInput: { + description: 'Beschreibung', + methodParameterTip: 'Wird während der Inferenz vom LLM ausgefüllt', + method: 'Methode', + methodParameter: 'Parameter', + label: 'Tags', + required: 'Erforderlich', + methodSetting: 'Einstellung', + name: 'Name', + title: 'Werkzeug-Eingabe', + methodSettingTip: 'Der Benutzer füllt die Werkzeugkonfiguration aus', + descriptionPlaceholder: 'Beschreibung der Bedeutung des Parameters', + labelPlaceholder: 'Tags auswählen (optional)', + }, + description: 'Beschreibung', + confirmTip: 'Apps, die dieses Tool verwenden, sind davon betroffen', + nameForToolCallTip: 'Unterstützt nur Zahlen, Buchstaben und Unterstriche.', + nameForToolCall: 'Name des Werkzeugaufrufs', + confirmTitle: 'Bestätigen, um zu speichern?', + nameForToolCallPlaceHolder: 'Wird für die Maschinenerkennung verwendet, z. B. getCurrentWeather, list_pets', + descriptionPlaceholder: 'Kurze Beschreibung des Zwecks des Werkzeugs, z. B. 
um die Temperatur für einen bestimmten Ort zu ermitteln.', + }, + test: { + title: 'Test', + parametersValue: 'Parameter & Wert', + parameters: 'Parameter', + value: 'Wert', + testResult: 'Testergebnisse', + testResultPlaceholder: 'Testergebnis wird hier angezeigt', + }, + thought: { + using: 'Nutzung', + used: 'Genutzt', + requestTitle: 'Anfrage an', + responseTitle: 'Antwort von', + }, + setBuiltInTools: { + info: 'Info', + setting: 'Einstellung', + toolDescription: 'Werkzeugbeschreibung', + parameters: 'Parameter', + string: 'Zeichenkette', + number: 'Nummer', + required: 'Erforderlich', + infoAndSetting: 'Info & Einstellungen', + }, + noCustomTool: { + title: 'Keine benutzerdefinierten Werkzeuge!', + content: 'Fügen Sie hier Ihre benutzerdefinierten Werkzeuge hinzu und verwalten Sie sie, um KI-Apps zu erstellen.', + createTool: 'Werkzeug erstellen', + }, + noSearchRes: { + title: 'Leider keine Ergebnisse!', + content: 'Wir konnten keine Werkzeuge finden, die Ihrer Suche entsprechen.', + reset: 'Suche zurücksetzen', + }, + builtInPromptTitle: 'Aufforderung', + toolRemoved: 'Werkzeug entfernt', + notAuthorized: 'Werkzeug nicht autorisiert', + howToGet: 'Wie erhält man', + addToolModal: { + added: 'Hinzugefügt', + manageInTools: 'Verwalten in Tools', + add: 'Hinzufügen', + category: 'Kategorie', + emptyTitle: 'Kein Workflow-Tool verfügbar', + type: 'Art', + emptyTip: 'Gehen Sie zu "Workflow -> Als Tool veröffentlichen"', + }, + toolNameUsageTip: 'Name des Tool-Aufrufs für die Argumentation und Aufforderung des Agenten', + customToolTip: 'Erfahren Sie mehr über benutzerdefinierte Dify-Tools', + openInStudio: 'In Studio öffnen', +} + +export default translation diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index c1ef1d408e..5e154aeca5 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Variable suchen', variableNamePlaceholder: 'Variablenname', setVarValuePlaceholder: 'Variable setzen', - needConnecttip: 'Dieser Schritt ist mit nichts verbunden', + needConnectTip: 'Dieser Schritt ist mit nichts verbunden', maxTreeDepth: 'Maximales Limit von {{depth}} Knoten pro Ast', needEndNode: 'Der Endblock muss hinzugefügt werden', needAnswerNode: 'Der Antwortblock muss hinzugefügt werden', @@ -69,6 +69,30 @@ const translation = { manageInTools: 'In den Tools verwalten', workflowAsToolTip: 'Nach dem Workflow-Update ist eine Neukonfiguration des Tools erforderlich.', viewDetailInTracingPanel: 'Details anzeigen', + importDSL: 'DSL importieren', + importFailure: 'Import fehlgeschlagen', + syncingData: 'Daten werden synchronisiert, nur wenige Sekunden.', + chooseDSL: 'Wählen Sie eine DSL(yml)-Datei', + importSuccess: 'Import erfolgreich', + importDSLTip: 'Der aktuelle Entwurf wird überschrieben. 
Exportieren Sie den Workflow vor dem Import als Backup.', + overwriteAndImport: 'Überschreiben und Importieren', + backupCurrentDraft: 'Aktuellen Entwurf sichern', + parallelTip: { + click: { + title: 'Klicken', + desc: 'zum Hinzufügen', + }, + drag: { + title: 'Ziehen', + desc: 'um eine Verbindung herzustellen', + }, + limit: 'Die Parallelität ist auf {{num}} Zweige beschränkt.', + depthLimit: 'Begrenzung der parallelen Verschachtelungsschicht von {{num}} Schichten', + }, + parallelRun: 'Paralleler Lauf', + disconnect: 'Trennen', + jumpToNode: 'Zu diesem Knoten springen', + addParallelNode: 'Parallelen Knoten hinzufügen', }, env: { envPanelTitle: 'Umgebungsvariablen', @@ -178,6 +202,7 @@ const translation = { 'transform': 'Transformieren', 'utilities': 'Dienstprogramme', 'noResult': 'Kein Ergebnis gefunden', + 'searchTool': 'Suchwerkzeug', }, blocks: { 'start': 'Start', @@ -403,10 +428,12 @@ const translation = { 'not empty': 'ist nicht leer', 'null': 'ist null', 'not null': 'ist nicht null', + 'regex match': 'Regex-Übereinstimmung', }, enterValue: 'Wert eingeben', addCondition: 'Bedingung hinzufügen', conditionNotSetup: 'Bedingung NICHT eingerichtet', + selectVariable: 'Variable auswählen...', }, variableAssigner: { title: 'Variablen zuweisen', @@ -502,6 +529,25 @@ const translation = { iteration_other: '{{count}} Iterationen', currentIteration: 'Aktuelle Iteration', }, + note: { + editor: { + strikethrough: 'Durchgestrichen', + large: 'Groß', + bulletList: 'Aufzählung', + italic: 'Kursiv', + small: 'Klein', + bold: 'Fett', + placeholder: 'Schreiben Sie Ihre Notiz...', + openLink: 'Öffnen', + showAuthor: 'Autor anzeigen', + medium: 'Mittel', + unlink: 'Trennen', + link: 'Verbinden', + enterUrl: 'URL eingeben...', + invalidUrl: 'Ungültige URL', + }, + addNote: 'Notiz hinzufügen', + }, }, tracing: { stopBy: 'Gestoppt von {{user}}', diff --git a/web/i18n/en-US/app-api.ts b/web/i18n/en-US/app-api.ts index f36708c1d0..631faeee9a 100644 --- a/web/i18n/en-US/app-api.ts +++ b/web/i18n/en-US/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'Pause', playing: 'Playing', loading: 'Loading', - merMaind: { + merMaid: { rerender: 'Redo Rerender', }, never: 'Never', diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts index 86c5f720c3..b1f3f33cd8 100644 --- a/web/i18n/en-US/app-debug.ts +++ b/web/i18n/en-US/app-debug.ts @@ -301,7 +301,7 @@ const translation = { historyNoBeEmpty: 'Conversation history must be set in the prompt', queryNoBeEmpty: 'Query must be set in the prompt', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Add Input Field', 'editModalTitle': 'Edit Input Field', 'description': 'Setting for variable {{varName}}', diff --git a/web/i18n/en-US/app.ts b/web/i18n/en-US/app.ts index 90724098de..3377a9b2f3 100644 --- a/web/i18n/en-US/app.ts +++ b/web/i18n/en-US/app.ts @@ -77,13 +77,18 @@ const translation = { emoji: 'Emoji', image: 'Image', }, + answerIcon: { + title: 'Use WebApp icon to replace 🤖', + description: 'Whether to use the WebApp icon to replace 🤖 in the shared application', + descriptionInExplore: 'Whether to use the WebApp icon to replace 🤖 in Explore', + }, switch: 'Switch to Workflow Orchestrate', switchTipStart: 'A new app copy will be created for you, and the new copy will switch to Workflow Orchestrate. 
The new copy will ', switchTip: 'not allow', switchTipEnd: ' switching back to Basic Orchestrate.', switchLabel: 'The app copy to be created', removeOriginal: 'Delete the original app', - switchStart: 'Start swtich', + switchStart: 'Start switch', typeSelector: { all: 'ALL Types', chatbot: 'Chatbot', diff --git a/web/i18n/en-US/common.ts b/web/i18n/en-US/common.ts index 87dab5cb71..23e301485a 100644 --- a/web/i18n/en-US/common.ts +++ b/web/i18n/en-US/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Params', duplicate: 'Duplicate', rename: 'Rename', + audioSourceUnavailable: 'AudioSource is unavailable', }, errorMsg: { fieldRequired: '{{field}} is required', @@ -132,7 +133,8 @@ const translation = { workspace: 'Workspace', createWorkspace: 'Create Workspace', helpCenter: 'Help', - roadmapAndFeedback: 'Feedback', + communityFeedback: 'Feedback', + roadmap: 'Roadmap', community: 'Community', about: 'About', logout: 'Log out', @@ -198,7 +200,7 @@ const translation = { invitationSent: 'Invitation sent', invitationSentTip: 'Invitation sent, and they can sign in to Dify to access your team data.', invitationLink: 'Invitation Link', - failedinvitationEmails: 'Below users were not invited successfully', + failedInvitationEmails: 'The users below were not invited successfully', ok: 'OK', removeFromTeam: 'Remove from team', removeFromTeamTip: 'Will remove team access', @@ -206,7 +208,7 @@ const translation = { setMember: 'Set to ordinary member', setBuilder: 'Set as builder', setEditor: 'Set as editor', - disinvite: 'Cancel the invitation', + disInvite: 'Cancel the invitation', deleteMember: 'Delete Member', you: '(You)', }, @@ -390,7 +392,7 @@ const translation = { selector: { pageSelected: 'Pages Selected', searchPages: 'Search pages...', - noSearchResult: 'No search resluts', + noSearchResult: 'No search results', addPages: 'Add pages', preview: 'PREVIEW', }, diff --git a/web/i18n/en-US/dataset-creation.ts b/web/i18n/en-US/dataset-creation.ts index 1ead19f7e6..40463593f9 100644 --- a/web/i18n/en-US/dataset-creation.ts +++ b/web/i18n/en-US/dataset-creation.ts @@ -50,7 +50,7 @@ const translation = { input: 'Knowledge name', placeholder: 'Please input', nameNotEmpty: 'Name cannot be empty', - nameLengthInvaild: 'Name must be between 1 to 40 characters', + nameLengthInvalid: 'Name must be between 1 and 40 characters', cancelButton: 'Cancel', confirmButton: 'Create', failed: 'Creation failed', @@ -86,12 +86,12 @@ const translation = { autoDescription: 'Automatically set chunk and preprocessing rules. Unfamiliar users are recommended to select this.', custom: 'Custom', customDescription: 'Customize chunks rules, chunks length, and preprocessing rules, etc.', - separator: 'Segment identifier', + separator: 'Delimiter', separatorPlaceholder: 'For example, newline (\\\\n) or special separator (such as "***")', maxLength: 'Maximum chunk length', overlap: 'Chunk overlap', overlapTip: 'Setting the chunk overlap can maintain the semantic relevance between them, enhancing the retrieve effect. 
It is recommended to set 10%-25% of the maximum chunk size.', - overlapCheck: 'chunk overlap should not bigger than maximun chunk length', + overlapCheck: 'chunk overlap should not be bigger than maximum chunk length', rules: 'Text preprocessing rules', removeExtraSpaces: 'Replace consecutive spaces, newlines and tabs', removeUrlEmails: 'Delete all URLs and email addresses', @@ -109,8 +109,8 @@ const translation = { QATitle: 'Segmenting in Question & Answer format', QATip: 'Enable this option will consume more tokens', QALanguage: 'Segment using', - emstimateCost: 'Estimation', - emstimateSegment: 'Estimated chunks', + estimateCost: 'Estimation', + estimateSegment: 'Estimated chunks', segmentCount: 'chunks', calculating: 'Calculating...', fileSource: 'Preprocess documents', @@ -135,8 +135,8 @@ const translation = { previewSwitchTipStart: 'The current chunk preview is in text format, switching to a question-and-answer format preview will', previewSwitchTipEnd: ' consume additional tokens', characters: 'characters', - indexSettedTip: 'To change the index method, please go to the ', - retrivalSettedTip: 'To change the index method, please go to the ', + indexSettingTip: 'To change the index method & embedding model, please go to the ', + retrievalSettingTip: 'To change the retrieval setting, please go to the ', datasetSettingLink: 'Knowledge settings.', }, stepThree: { diff --git a/web/i18n/en-US/dataset.ts b/web/i18n/en-US/dataset.ts index e6914b4a00..a15efe5dc0 100644 --- a/web/i18n/en-US/dataset.ts +++ b/web/i18n/en-US/dataset.ts @@ -55,6 +55,7 @@ const translation = { hybrid_search: 'HYBRID', invertedIndex: 'INVERTED', }, + defaultRetrievalTip: 'Multi-path retrieval is used by default. Knowledge is retrieved from multiple knowledge bases and then re-ranked.', mixtureHighQualityAndEconomicTip: 'The Rerank model is required for mixture of high quality and economical knowledge bases.', inconsistentEmbeddingModelTip: 'The Rerank model is required if the Embedding models of the selected knowledge bases are inconsistent.', retrievalSettings: 'Retrieval Setting', diff --git a/web/i18n/en-US/login.ts b/web/i18n/en-US/login.ts index 2cb6ecb785..03b0d27ed5 100644 --- a/web/i18n/en-US/login.ts +++ b/web/i18n/en-US/login.ts @@ -32,7 +32,7 @@ const translation = { pp: 'Privacy Policy', tosDesc: 'By signing up, you agree to our', goToInit: 'If you have not initialized the account, please go to the initialization page', - donthave: 'Don\'t have?', + dontHave: 'Don\'t have?', invalidInvitationCode: 'Invalid invitation code', accountAlreadyInited: 'Account already initialized', forgotPassword: 'Forgot your password?', diff --git a/web/i18n/en-US/share-app.ts b/web/i18n/en-US/share-app.ts index f66e923561..b5a219c998 100644 --- a/web/i18n/en-US/share-app.ts +++ b/web/i18n/en-US/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'App is unavailable', - appUnkonwError: 'App is unavailable', + appUnknownError: 'App is unavailable', }, chat: { newChat: 'New chat', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Chats', newChatDefaultName: 'New conversation', resetChat: 'Reset conversation', - powerBy: 'Powered by', + poweredBy: 'Powered by', prompt: 'Prompt', privatePromptConfigTitle: 'Conversation settings', publicPromptConfigTitle: 'Initial Prompt', diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index e0613a110f..b83d213cb8 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -36,7 +36,7 @@ const translation = { 
searchVar: 'Search variable', variableNamePlaceholder: 'Variable name', setVarValuePlaceholder: 'Set variable', - needConnecttip: 'This step is not connected to anything', + needConnectTip: 'This step is not connected to anything', maxTreeDepth: 'Maximum limit of {{depth}} nodes per branch', needEndNode: 'The End block must be added', needAnswerNode: 'The Answer block must be added', @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Overwrite and Import', importFailure: 'Import failure', importSuccess: 'Import success', + parallelRun: 'Parallel Run', + parallelTip: { + click: { + title: 'Click', + desc: ' to add', + }, + drag: { + title: 'Drag', + desc: ' to connect', + }, + limit: 'Parallelism is limited to {{num}} branches.', + depthLimit: 'Parallel nesting layer limit of {{num}} layers', + }, + disconnect: 'Disconnect', + jumpToNode: 'Jump to this node', + addParallelNode: 'Add Parallel Node', }, env: { envPanelTitle: 'Environment Variables', diff --git a/web/i18n/es-ES/app-api.ts b/web/i18n/es-ES/app-api.ts index 5d2bc078e3..1474e5fb1d 100644 --- a/web/i18n/es-ES/app-api.ts +++ b/web/i18n/es-ES/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'Pausa', playing: 'Reproduciendo', loading: 'Cargando', - merMaind: { + merMaid: { rerender: 'Rehacer Rerender', }, never: 'Nunca', diff --git a/web/i18n/es-ES/app-debug.ts b/web/i18n/es-ES/app-debug.ts index 68088c26a6..ab5b82e7d1 100644 --- a/web/i18n/es-ES/app-debug.ts +++ b/web/i18n/es-ES/app-debug.ts @@ -259,7 +259,7 @@ const translation = { historyNoBeEmpty: 'El historial de conversaciones debe establecerse en la indicación', queryNoBeEmpty: 'La consulta debe establecerse en la indicación', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Agregar Campo de Entrada', 'editModalTitle': 'Editar Campo de Entrada', 'description': 'Configuración para la variable {{varName}}', diff --git a/web/i18n/es-ES/app-overview.ts b/web/i18n/es-ES/app-overview.ts index f3aaf1f737..8beb7545dc 100644 --- a/web/i18n/es-ES/app-overview.ts +++ b/web/i18n/es-ES/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Pasos del flujo de trabajo', show: 'Mostrar', hide: 'Ocultar', + subTitle: 'Detalles del flujo de trabajo', + showDesc: 'Mostrar u ocultar detalles del flujo de trabajo en WebApp', }, chatColorTheme: 'Tema de color del chat', chatColorThemeDesc: 'Establece el tema de color del chatbot', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Ingresa el texto de descargo de responsabilidad personalizado', customDisclaimerTip: 'El texto de descargo de responsabilidad personalizado se mostrará en el lado del cliente, proporcionando información adicional sobre la aplicación', }, + sso: { + description: 'Todos los usuarios deben iniciar sesión con SSO antes de usar WebApp', + tooltip: 'Póngase en contacto con el administrador para habilitar el inicio de sesión único de WebApp', + label: 'Autenticación SSO', + title: 'WebApp SSO', + }, }, embedded: { entry: 'Incrustado', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Mensajes totales', - explanation: 'Recuento diario de interacciones de IA; excluye la ingeniería/depuración de prompts.', + explanation: 'Recuento diario de interacciones con IA.', + }, + totalConversations: { + title: 'Conversaciones totales', + explanation: 'Recuento diario de conversaciones con IA; ingeniería/depuración de prompts excluida.', }, activeUsers: { title: 'Usuarios activos', diff --git a/web/i18n/es-ES/app.ts b/web/i18n/es-ES/app.ts index 
739439ff58..b29f8d36e4 100644 --- a/web/i18n/es-ES/app.ts +++ b/web/i18n/es-ES/app.ts @@ -122,7 +122,17 @@ const translation = { removeConfirmTitle: '¿Eliminar la configuración de {{key}}?', removeConfirmContent: 'La configuración actual está en uso, eliminarla desactivará la función de rastreo.', }, + view: 'Vista', }, + answerIcon: { + title: 'Usar el icono de la aplicación web para reemplazar 🤖', + descriptionInExplore: 'Si se debe usar el icono de la aplicación web para reemplazar 🤖 en Explore', + description: 'Si se va a usar el icono de la aplicación web para reemplazar 🤖 en la aplicación compartida', + }, + importFromDSLUrl: 'URL de origen', + importFromDSLUrlPlaceholder: 'Pegar enlace DSL aquí', + importFromDSL: 'Importar desde DSL', + importFromDSLFile: 'Desde el archivo DSL', } export default translation diff --git a/web/i18n/es-ES/common.ts b/web/i18n/es-ES/common.ts index fc37775263..2ba907361f 100644 --- a/web/i18n/es-ES/common.ts +++ b/web/i18n/es-ES/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parámetros', duplicate: 'Duplicar', rename: 'Renombrar', + audioSourceUnavailable: 'AudioSource no está disponible', }, errorMsg: { fieldRequired: '{{field}} es requerido', @@ -132,7 +133,8 @@ const translation = { workspace: 'Espacio de trabajo', createWorkspace: 'Crear espacio de trabajo', helpCenter: 'Ayuda', - roadmapAndFeedback: 'Comentarios', + communityFeedback: 'Comentarios', + roadmap: 'Hoja de ruta', community: 'Comunidad', about: 'Acerca de', logout: 'Cerrar sesión', @@ -198,7 +200,7 @@ const translation = { invitationSent: 'Invitación enviada', invitationSentTip: 'Invitación enviada, y pueden iniciar sesión en Dify para acceder a tus datos del equipo.', invitationLink: 'Enlace de invitación', - failedinvitationEmails: 'Los siguientes usuarios no fueron invitados exitosamente', + failedInvitationEmails: 'Los siguientes usuarios no fueron invitados exitosamente', ok: 'OK', removeFromTeam: 'Eliminar del equipo', removeFromTeamTip: 'Se eliminará el acceso al equipo', @@ -206,7 +208,7 @@ const translation = { setMember: 'Establecer como miembro ordinario', setBuilder: 'Establecer como constructor', setEditor: 'Establecer como editor', - disinvite: 'Cancelar la invitación', + disInvite: 'Cancelar la invitación', deleteMember: 'Eliminar miembro', you: '(Tú)', }, diff --git a/web/i18n/es-ES/dataset-creation.ts b/web/i18n/es-ES/dataset-creation.ts index 66b8e9b302..132c9cbb9b 100644 --- a/web/i18n/es-ES/dataset-creation.ts +++ b/web/i18n/es-ES/dataset-creation.ts @@ -50,7 +50,7 @@ const translation = { input: 'Nombre del conocimiento', placeholder: 'Por favor ingresa', nameNotEmpty: 'El nombre no puede estar vacío', - nameLengthInvaild: 'El nombre debe tener entre 1 y 40 caracteres', + nameLengthInvalid: 'El nombre debe tener entre 1 y 40 caracteres', cancelButton: 'Cancelar', confirmButton: 'Crear', failed: 'Error al crear', @@ -109,8 +109,8 @@ const translation = { QATitle: 'Segmentación en formato de pregunta y respuesta', QATip: 'Habilitar esta opción consumirá más tokens', QALanguage: 'Segmentar usando', - emstimateCost: 'Estimación', - emstimateSegment: 'Fragmentos estimados', + estimateCost: 'Estimación', + estimateSegment: 'Fragmentos estimados', segmentCount: 'fragmentos', calculating: 'Calculando...', fileSource: 'Preprocesar documentos', @@ -135,8 +135,8 @@ const translation = { previewSwitchTipStart: 'La vista previa actual del fragmento está en formato de texto, cambiar a una vista previa en formato de pregunta y respuesta', previewSwitchTipEnd: ' 
consumirá tokens adicionales', characters: 'caracteres', - indexSettedTip: 'Para cambiar el método de índice, por favor ve a la ', - retrivalSettedTip: 'Para cambiar el método de índice, por favor ve a la ', + indexSettingTip: 'Para cambiar el método de índice, por favor ve a la ', + retrievalSettingTip: 'Para cambiar la configuración de recuperación, por favor ve a la ', datasetSettingLink: 'configuración del conocimiento.', }, stepThree: { diff --git a/web/i18n/es-ES/dataset.ts b/web/i18n/es-ES/dataset.ts index e4fc362efa..4eefb621d2 100644 --- a/web/i18n/es-ES/dataset.ts +++ b/web/i18n/es-ES/dataset.ts @@ -71,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'La recuperación N-a-1 será oficialmente obsoleta a partir de septiembre. Se recomienda utilizar la última recuperación de múltiples rutas para obtener mejores resultados.', nTo1RetrievalLegacyLink: 'Más información', nTo1RetrievalLegacyLinkText: 'La recuperación N-a-1 será oficialmente obsoleta en septiembre.', + defaultRetrievalTip: 'De forma predeterminada, se utiliza la recuperación de varias rutas. El conocimiento se recupera de múltiples bases de conocimiento y luego se vuelve a clasificar.', } export default translation diff --git a/web/i18n/es-ES/login.ts b/web/i18n/es-ES/login.ts index dc12cfc32f..e56161895e 100644 --- a/web/i18n/es-ES/login.ts +++ b/web/i18n/es-ES/login.ts @@ -32,7 +32,7 @@ const translation = { pp: 'Política de privacidad', tosDesc: 'Al registrarte, aceptas nuestros', goToInit: 'Si no has inicializado la cuenta, por favor ve a la página de inicialización', - donthave: '¿No tienes?', + dontHave: '¿No tienes?', invalidInvitationCode: 'Código de invitación inválido', accountAlreadyInited: 'La cuenta ya está inicializada', forgotPassword: '¿Olvidaste tu contraseña?', diff --git a/web/i18n/es-ES/share-app.ts b/web/i18n/es-ES/share-app.ts index 2e436c4327..b1ac171389 100644 --- a/web/i18n/es-ES/share-app.ts +++ b/web/i18n/es-ES/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'La aplicación no está disponible', - appUnkonwError: 'La aplicación no está disponible', + appUnknownError: 'La aplicación no está disponible', }, chat: { newChat: 'Nuevo chat', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Chats', newChatDefaultName: 'Nueva conversación', resetChat: 'Reiniciar conversación', - powerBy: 'Desarrollado por', + poweredBy: 'Desarrollado por', prompt: 'Indicación', privatePromptConfigTitle: 'Configuración de la conversación', publicPromptConfigTitle: 'Indicación inicial', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 38f2ad68b1..0efc996f91 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Buscar variable', variableNamePlaceholder: 'Nombre de la variable', setVarValuePlaceholder: 'Establecer variable', - needConnecttip: 'Este paso no está conectado a nada', + needConnectTip: 'Este paso no está conectado a nada', maxTreeDepth: 'Límite máximo de {{depth}} nodos por rama', needEndNode: 'Debe agregarse el bloque de Fin', needAnswerNode: 'Debe agregarse el bloque de Respuesta', @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Sobrescribir e importar', importFailure: 'Error al importar', importSuccess: 'Importación exitosa', + parallelTip: { + click: { + title: 'Clic', + desc: 'Para agregar', + }, + drag: { + title: 'Arrastrar', + desc: 'Para conectarse', + }, + limit: 'El paralelismo se limita a {{num}} ramas.', + depthLimit: 'Límite de capa de anidamiento 
paralelo de capas {{num}}', + }, + parallelRun: 'Ejecución paralela', + disconnect: 'Desconectar', + jumpToNode: 'Saltar a este nodo', + addParallelNode: 'Agregar nodo paralelo', }, env: { envPanelTitle: 'Variables de Entorno', @@ -186,6 +202,7 @@ const translation = { 'transform': 'Transformar', 'utilities': 'Utilidades', 'noResult': 'No se encontraron coincidencias', + 'searchTool': 'Herramienta de búsqueda', }, blocks: { 'start': 'Inicio', @@ -411,10 +428,12 @@ const translation = { 'not empty': 'no está vacío', 'null': 'es nulo', 'not null': 'no es nulo', + 'regex match': 'Coincidencia de expresiones regulares', }, enterValue: 'Ingresa un valor', addCondition: 'Agregar condición', conditionNotSetup: 'Condición NO configurada', + selectVariable: 'Seleccionar variable...', }, variableAssigner: { title: 'Asignar variables', @@ -533,6 +552,9 @@ const translation = { stopBy: 'Detenido por {{user}}', }, }, + tracing: { + stopBy: 'Detenido por {{user}}', + }, } export default translation diff --git a/web/i18n/fa-IR/app-api.ts b/web/i18n/fa-IR/app-api.ts index 0548ef2a2b..7f65481fcf 100644 --- a/web/i18n/fa-IR/app-api.ts +++ b/web/i18n/fa-IR/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'مکث', playing: 'در حال پخش', loading: 'در حال بارگذاری', - merMaind: { + merMaid: { rerender: 'بازسازی مجدد', }, never: 'هرگز', diff --git a/web/i18n/fa-IR/app-debug.ts b/web/i18n/fa-IR/app-debug.ts index 1ce222581d..00891f3b17 100644 --- a/web/i18n/fa-IR/app-debug.ts +++ b/web/i18n/fa-IR/app-debug.ts @@ -294,7 +294,7 @@ const translation = { historyNoBeEmpty: 'تاریخچه مکالمه باید در پرس و جو تنظیم شود', queryNoBeEmpty: 'پرس و جو باید در پرس و جو تنظیم شود', }, - variableConig: { + variableConfig: { 'addModalTitle': 'افزودن فیلد ورودی', 'editModalTitle': 'ویرایش فیلد ورودی', 'description': 'تنظیم برای متغیر {{varName}}', diff --git a/web/i18n/fa-IR/app-overview.ts b/web/i18n/fa-IR/app-overview.ts index 1bbd7a0283..8a0057fb82 100644 --- a/web/i18n/fa-IR/app-overview.ts +++ b/web/i18n/fa-IR/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'مراحل کاری', show: 'نمایش', hide: 'مخفی کردن', + showDesc: 'نمایش یا پنهان کردن جزئیات گردش کار در WebApp', + subTitle: 'جزئیات گردش کار', }, chatColorTheme: 'تم رنگی چت', chatColorThemeDesc: 'تم رنگی چت‌بات را تنظیم کنید', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'متن سلب مسئولیت سفارشی را وارد کنید', customDisclaimerTip: 'متن سلب مسئولیت سفارشی در سمت مشتری نمایش داده می‌شود و اطلاعات بیشتری درباره برنامه ارائه می‌دهد', }, + sso: { + title: 'WebApp SSO', + label: 'احراز هویت SSO', + description: 'همه کاربران باید قبل از استفاده از WebApp با SSO وارد شوند', + tooltip: 'برای فعال کردن WebApp SSO با سرپرست تماس بگیرید', + }, }, embedded: { entry: 'جاسازی شده', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'توکن/ثانیه', totalMessages: { title: 'کل پیام‌ها', - explanation: 'تعداد تعاملات روزانه با AI؛ مهندسی/اشکال‌زدایی دستورات مستثنی هستند.', + explanation: 'تعداد تعاملات روزانه با هوش مصنوعی.', + }, + totalConversations: { + title: 'کل مکالمات', + explanation: 'تعداد مکالمات روزانه با هوش مصنوعی؛ مهندسی/اشکال‌زدایی پرامپت مستثنی است.', }, activeUsers: { title: 'کاربران فعال', diff --git a/web/i18n/fa-IR/app.ts b/web/i18n/fa-IR/app.ts index b9dd179809..9283b04287 100644 --- a/web/i18n/fa-IR/app.ts +++ b/web/i18n/fa-IR/app.ts @@ -126,6 +126,12 @@ const translation = { removeConfirmTitle: 'حذف پیکربندی {{key}}؟', removeConfirmContent: 'پیکربندی فعلی در حال استفاده است، حذف آن ویژگی ردیابی را غیرفعال خواهد کرد.', }, + 
view: 'مشاهده', + }, + answerIcon: { + descriptionInExplore: 'آیا از نماد WebApp برای جایگزینی 🤖 در Explore استفاده کنیم یا خیر', + description: 'آیا از نماد WebApp برای جایگزینی 🤖 در برنامه مشترک استفاده کنیم یا خیر', + title: 'از نماد WebApp برای جایگزینی 🤖 استفاده کنید', }, } diff --git a/web/i18n/fa-IR/common.ts b/web/i18n/fa-IR/common.ts index e4417bcbcc..c75ab11a63 100644 --- a/web/i18n/fa-IR/common.ts +++ b/web/i18n/fa-IR/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'پارامترها', duplicate: 'تکرار', rename: 'تغییر نام', + audioSourceUnavailable: 'منبع صوتی در دسترس نیست', }, errorMsg: { fieldRequired: '{{field}} الزامی است', @@ -132,7 +133,8 @@ const translation = { workspace: 'فضای کاری', createWorkspace: 'ایجاد فضای کاری', helpCenter: 'راهنما', - roadmapAndFeedback: 'بازخورد', + communityFeedback: 'بازخورد', + roadmap: 'نقشه راه', community: 'انجمن', about: 'درباره', logout: 'خروج', @@ -198,7 +200,7 @@ const translation = { invitationSent: 'دعوت‌نامه ارسال شد', invitationSentTip: 'دعوت‌نامه ارسال شد و آنها می‌توانند وارد Dify شوند تا به داده‌های تیم شما دسترسی پیدا کنند.', invitationLink: 'لینک دعوت', - failedinvitationEmails: 'کاربران زیر با موفقیت دعوت نشدند', + failedInvitationEmails: 'کاربران زیر با موفقیت دعوت نشدند', ok: 'تایید', removeFromTeam: 'حذف از تیم', removeFromTeamTip: 'دسترسی تیم را حذف می‌کند', @@ -206,7 +208,7 @@ const translation = { setMember: 'تنظیم به عنوان عضو عادی', setBuilder: 'تنظیم به عنوان سازنده', setEditor: 'تنظیم به عنوان ویرایشگر', - disinvite: 'لغو دعوت', + disInvite: 'لغو دعوت', deleteMember: 'حذف عضو', you: '(شما)', }, diff --git a/web/i18n/fa-IR/dataset-creation.ts b/web/i18n/fa-IR/dataset-creation.ts index f8483af140..e6e6ad5bfb 100644 --- a/web/i18n/fa-IR/dataset-creation.ts +++ b/web/i18n/fa-IR/dataset-creation.ts @@ -50,7 +50,7 @@ const translation = { input: 'نام دانش', placeholder: 'لطفاً وارد کنید', nameNotEmpty: 'نام نمیتواند خالی باشد', - nameLengthInvaild: 'نام باید بین 1 تا 40 کاراکتر باشد', + nameLengthInvalid: 'نام باید بین 1 تا 40 کاراکتر باشد', cancelButton: 'لغو', confirmButton: 'ایجاد', failed: 'ایجاد ناموفق بود', @@ -109,8 +109,8 @@ const translation = { QATitle: 'بخشبندی در قالب پرسش و پاسخ', QATip: 'فعال کردن این گزینه توکنهای بیشتری مصرف خواهد کرد', QALanguage: 'بخشبندی با استفاده از', - emstimateCost: 'برآورد', - emstimateSegment: 'بخشهای برآورد شده', + estimateCost: 'برآورد', + estimateSegment: 'بخشهای برآورد شده', segmentCount: 'بخشها', calculating: 'در حال محاسبه...', fileSource: 'پیشپردازش اسناد', @@ -135,8 +135,8 @@ const translation = { previewSwitchTipStart: 'پیشنمایش بخش فعلی در قالب متن است، تغییر به پیشنمایش قالب پرسش و پاسخ', previewSwitchTipEnd: ' توکنهای اضافی مصرف خواهد کرد', characters: 'کاراکترها', - indexSettedTip: 'برای تغییر روش شاخص، لطفاً به', - retrivalSettedTip: 'برای تغییر روش شاخص، لطفاً به', + indexSettingTip: 'برای تغییر روش شاخص، لطفاً به', + retrievalSettingTip: 'برای تغییر تنظیمات بازیابی، لطفاً به', datasetSettingLink: 'تنظیمات دانش بروید.', }, stepThree: { diff --git a/web/i18n/fa-IR/dataset.ts b/web/i18n/fa-IR/dataset.ts index 30036dc68f..e3b9d70e07 100644 --- a/web/i18n/fa-IR/dataset.ts +++ b/web/i18n/fa-IR/dataset.ts @@ -71,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'بازیابی N-to-1 از سپتامبر به طور رسمی منسوخ خواهد شد. 
توصیه می‌شود از بازیابی چند مسیر جدید استفاده کنید تا نتایج بهتری بدست آورید.', nTo1RetrievalLegacyLink: 'بیشتر بدانید', nTo1RetrievalLegacyLinkText: ' بازیابی N-to-1 از سپتامبر به طور رسمی منسوخ خواهد شد.', + defaultRetrievalTip: 'بازیابی چند مسیره به طور پیش فرض استفاده می شود. دانش از چندین پایگاه دانش بازیابی می شود و سپس دوباره رتبه بندی می شود.', } export default translation diff --git a/web/i18n/fa-IR/login.ts b/web/i18n/fa-IR/login.ts index 8912561efe..4ac06c866d 100644 --- a/web/i18n/fa-IR/login.ts +++ b/web/i18n/fa-IR/login.ts @@ -32,7 +32,7 @@ const translation = { pp: 'سیاست حفظ حریم خصوصی', tosDesc: 'با ثبت نام، شما با شرایط ما موافقت می‌کنید', goToInit: 'اگر حساب را اولیه نکرده‌اید، لطفاً به صفحه اولیه‌سازی بروید', - donthave: 'ندارید؟', + dontHave: 'ندارید؟', invalidInvitationCode: 'کد دعوت نامعتبر است', accountAlreadyInited: 'حساب قبلاً اولیه شده است', forgotPassword: 'رمز عبور خود را فراموش کرده‌اید؟', diff --git a/web/i18n/fa-IR/share-app.ts b/web/i18n/fa-IR/share-app.ts index b74c893e6e..f3f1360a92 100644 --- a/web/i18n/fa-IR/share-app.ts +++ b/web/i18n/fa-IR/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'اپ در دسترس نیست', - appUnkonwError: 'اپ در دسترس نیست', + appUnknownError: 'اپ در دسترس نیست', }, chat: { newChat: 'چت جدید', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'چت‌ها', newChatDefaultName: 'مکالمه جدید', resetChat: 'بازنشانی مکالمه', - powerBy: 'قدرت‌گرفته از', + poweredBy: 'قدرت‌گرفته از', prompt: 'پیشنهاد', privatePromptConfigTitle: 'تنظیمات مکالمه', publicPromptConfigTitle: 'پیشنهاد اولیه', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index 6dc326e829..67020f5025 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'جستجوی متغیر', variableNamePlaceholder: 'نام متغیر', setVarValuePlaceholder: 'تنظیم متغیر', - needConnecttip: 'این مرحله به هیچ چیزی متصل نیست', + needConnectTip: 'این مرحله به هیچ چیزی متصل نیست', maxTreeDepth: 'حداکثر عمق {{depth}} نود در هر شاخه', needEndNode: 'بلوک پایان باید اضافه شود', needAnswerNode: 'بلوک پاسخ باید اضافه شود', @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'بازنویسی و وارد کردن', importFailure: 'خطا در وارد کردن', importSuccess: 'وارد کردن موفقیت‌آمیز', + parallelTip: { + click: { + title: 'کلیک کنید', + desc: 'اضافه کردن', + }, + drag: { + desc: 'برای اتصال', + title: 'کشیدن', + }, + depthLimit: 'حد لایه تودرتو موازی لایه های {{num}}', + limit: 'موازی سازی به شاخه های {{num}} محدود می شود.', + }, + disconnect: 'قطع', + jumpToNode: 'پرش به این گره', + parallelRun: 'اجرای موازی', + addParallelNode: 'افزودن گره موازی', }, env: { envPanelTitle: 'متغیرهای محیطی', @@ -186,6 +202,7 @@ const translation = { 'transform': 'تبدیل', 'utilities': 'ابزارهای کاربردی', 'noResult': 'نتیجه‌ای پیدا نشد', + 'searchTool': 'ابزار جستجو', }, blocks: { 'start': 'شروع', @@ -411,6 +428,7 @@ const translation = { 'not empty': 'خالی نیست', 'null': 'خالی', 'not null': 'خالی نیست', + 'regex match': 'تطبیق regex', }, enterValue: 'مقدار را وارد کنید', addCondition: 'افزودن شرط', diff --git a/web/i18n/fr-FR/app-api.ts b/web/i18n/fr-FR/app-api.ts index c214e0a9c9..af3f752f98 100644 --- a/web/i18n/fr-FR/app-api.ts +++ b/web/i18n/fr-FR/app-api.ts @@ -9,7 +9,7 @@ const translation = { play: 'Jouer', pause: 'Pause', playing: 'Jouant', - merMaind: { + merMaid: { rerender: 'Refaire Rerendu', }, never: 'Jamais', @@ -77,6 +77,7 @@ const translation = { pathParams: 'Params de chemin', query: 
'Requête', }, + loading: 'Chargement', } export default translation diff --git a/web/i18n/fr-FR/app-debug.ts b/web/i18n/fr-FR/app-debug.ts index b71d251956..2fd863742b 100644 --- a/web/i18n/fr-FR/app-debug.ts +++ b/web/i18n/fr-FR/app-debug.ts @@ -248,7 +248,7 @@ const translation = { historyNoBeEmpty: 'L\'historique de la conversation doit être défini dans le prompt', queryNoBeEmpty: 'La requête doit être définie dans le prompt', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Add Input Field', 'editModalTitle': 'Edit Input Field', 'description': 'Setting for variable {{varName}}', diff --git a/web/i18n/fr-FR/app-overview.ts b/web/i18n/fr-FR/app-overview.ts index 23032f9897..26d538e903 100644 --- a/web/i18n/fr-FR/app-overview.ts +++ b/web/i18n/fr-FR/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Étapes du workflow', show: 'Afficher', hide: 'Masquer', + showDesc: 'Afficher ou masquer les détails du flux de travail dans WebApp', + subTitle: 'Détails du flux de travail', }, chatColorTheme: 'Thème de couleur du chatbot', chatColorThemeDesc: 'Définir le thème de couleur du chatbot', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Entrez le texte de la clause de non-responsabilité personnalisée', customDisclaimerTip: 'Le texte de la clause de non-responsabilité personnalisée sera affiché côté client, fournissant des informations supplémentaires sur l\'application', }, + sso: { + label: 'Authentification SSO', + title: 'WebApp SSO', + tooltip: 'Contactez l’administrateur pour activer l’authentification unique WebApp', + description: 'Tous les utilisateurs doivent se connecter avec l’authentification unique avant d’utiliser WebApp', + }, }, embedded: { entry: 'Intégré', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Total des messages', - explanation: 'Nombre d\'interactions quotidiennes avec l\'IA ; l\'ingénierie/le débogage des prompts sont exclus.', + explanation: 'Nombre d\'interactions quotidiennes avec l\'IA.', + }, + totalConversations: { + title: 'Conversations totales', + explanation: 'Nombre de conversations quotidiennes avec l\'IA ; ingénierie/débogage des prompts exclus.', }, activeUsers: { title: 'Utilisateurs actifs', diff --git a/web/i18n/fr-FR/app.ts b/web/i18n/fr-FR/app.ts index fab9b98f99..55966f6966 100644 --- a/web/i18n/fr-FR/app.ts +++ b/web/i18n/fr-FR/app.ts @@ -122,7 +122,17 @@ const translation = { removeConfirmTitle: 'Supprimer la configuration de {{key}} ?', removeConfirmContent: 'La configuration actuelle est en cours d\'utilisation, sa suppression désactivera la fonction de Traçage.', }, + view: 'Vue', }, + answerIcon: { + description: 'S’il faut utiliser l’icône WebApp pour remplacer 🤖 dans l’application partagée', + title: 'Utiliser l’icône WebApp pour remplacer 🤖', + descriptionInExplore: 'Utilisation de l’icône WebApp pour remplacer 🤖 dans Explore', + }, + importFromDSLUrlPlaceholder: 'Collez le lien DSL ici', + importFromDSL: 'Importation à partir d’une DSL', + importFromDSLUrl: 'À partir de l’URL', + importFromDSLFile: 'À partir d’un fichier DSL', } export default translation diff --git a/web/i18n/fr-FR/billing.ts b/web/i18n/fr-FR/billing.ts index 09c8ca43a8..2bcdfd5b23 100644 --- a/web/i18n/fr-FR/billing.ts +++ b/web/i18n/fr-FR/billing.ts @@ -60,6 +60,8 @@ const translation = { bulkUpload: 'Téléchargement en masse de documents', agentMode: 'Mode Agent', workflow: 'Flux de travail', + llmLoadingBalancingTooltip: 'Ajoutez plusieurs clés API aux modèles, en contournant 
efficacement les limites de débit de l’API.', + llmLoadingBalancing: 'Équilibrage de charge LLM', }, comingSoon: 'Bientôt disponible', member: 'Membre', @@ -74,6 +76,7 @@ const translation = { }, ragAPIRequestTooltip: 'Fait référence au nombre d\'appels API invoquant uniquement les capacités de traitement de la base de connaissances de Dify.', receiptInfo: 'Seuls le propriétaire de l\'équipe et l\'administrateur de l\'équipe peuvent s\'abonner et consulter les informations de facturation', + annotationQuota: 'Quota d’annotation', }, plans: { sandbox: { diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts index 2ae0731006..c4fed4405d 100644 --- a/web/i18n/fr-FR/common.ts +++ b/web/i18n/fr-FR/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Paramètres', duplicate: 'Dupliquer', rename: 'Renommer', + audioSourceUnavailable: 'AudioSource n’est pas disponible', }, placeholder: { input: 'Veuillez entrer', @@ -128,7 +129,8 @@ const translation = { workspace: 'Espace de travail', createWorkspace: 'Créer un Espace de Travail', helpCenter: 'Aide', - roadmapAndFeedback: 'Retour d\'information', + communityFeedback: 'Retour d\'information', + roadmap: 'Feuille de route', community: 'Communauté', about: 'À propos', logout: 'Se déconnecter', @@ -190,16 +192,21 @@ const translation = { invitationSent: 'Invitation envoyée', invitationSentTip: 'Invitation envoyée, et ils peuvent se connecter à Dify pour accéder aux données de votre équipe.', invitationLink: 'Lien d\'invitation', - failedinvitationEmails: 'Les utilisateurs ci-dessous n\'ont pas été invités avec succès', + failedInvitationEmails: 'Les utilisateurs ci-dessous n\'ont pas été invités avec succès', ok: 'D\'accord', removeFromTeam: 'Retirer de l\'équipe', removeFromTeamTip: 'Supprimera l\'accès de l\'équipe', setAdmin: 'Définir comme administrateur', setMember: 'Définir en tant que membre ordinaire', setEditor: 'Définir en tant qu\'éditeur', - disinvite: 'Annuler l\'invitation', + disInvite: 'Annuler l\'invitation', deleteMember: 'Supprimer Membre', you: '(Vous)', + builder: 'Constructeur', + datasetOperatorTip: 'Peut uniquement gérer la base de connaissances', + datasetOperator: 'Administrateur des connaissances', + setBuilder: 'Définir en tant que constructeur', + builderTip: 'Peut créer et modifier ses propres applications', }, integrations: { connected: 'Connecté', @@ -346,6 +353,22 @@ const translation = { quotaTip: 'Tokens gratuits restants disponibles', loadPresets: 'Charger les Présents', parameters: 'PARAMÈTRES', + modelHasBeenDeprecated: 'Ce modèle est obsolète', + providerManagedDescription: 'Utilisez l’ensemble unique d’informations d’identification fourni par le fournisseur de modèle.', + loadBalancingHeadline: 'Équilibrage', + loadBalancing: 'Équilibrage', + loadBalancingLeastKeyWarning: 'Pour activer l’équilibrage de charge, au moins 2 clés doivent être activées.', + apiKey: 'API-KEY', + apiKeyStatusNormal: 'L’état de l’APIKey est normal', + configLoadBalancing: 'Équilibrage de charge de configuration', + loadBalancingInfo: 'Par défaut, l’équilibrage de charge utilise la stratégie Round-robin. 
Si la limitation de vitesse est déclenchée, une période de recharge de 1 minute sera appliquée.', + editConfig: 'Modifier la configuration', + addConfig: 'Ajouter une configuration', + apiKeyRateLimit: 'La limite de débit a été atteinte, disponible après {{seconds}}s', + defaultConfig: 'Configuration par défaut', + loadBalancingDescription: 'Réduisez la pression grâce à plusieurs ensembles d’informations d’identification.', + providerManaged: 'Géré par le fournisseur', + upgradeForLoadBalancing: 'Mettez à niveau votre plan pour activer l’équilibrage de charge.', }, dataSource: { add: 'Ajouter une source de données', @@ -369,6 +392,15 @@ const translation = { preview: 'APERÇU', }, }, + website: { + configuredCrawlers: 'Robots d’exploration configurés', + with: 'Avec', + inactive: 'Inactif', + active: 'Actif', + title: 'Site internet', + description: 'Importez du contenu à partir de sites Web à l’aide du robot d’indexation.', + }, + configure: 'Configurer', }, plugin: { serpapi: { @@ -537,6 +569,10 @@ const translation = { created: 'Tag créé avec succès', failed: 'La création de la balise a échoué', }, + errorMsg: { + fieldRequired: '{{field}} est obligatoire', + urlError: 'L’URL doit commencer par http:// ou https://', + }, } export default translation diff --git a/web/i18n/fr-FR/dataset-creation.ts b/web/i18n/fr-FR/dataset-creation.ts index da3ac8d476..c08a3e5731 100644 --- a/web/i18n/fr-FR/dataset-creation.ts +++ b/web/i18n/fr-FR/dataset-creation.ts @@ -45,11 +45,35 @@ const translation = { input: 'Nom de la connaissance', placeholder: 'Veuillez entrer', nameNotEmpty: 'Le nom ne peut pas être vide', - nameLengthInvaild: 'Le nom doit comporter entre 1 et 40 caractères.', + nameLengthInvalid: 'Le nom doit comporter entre 1 et 40 caractères.', cancelButton: 'Annuler', confirmButton: 'Créer', failed: 'Création échouée', }, + website: { + limit: 'Limite', + fireCrawlNotConfiguredDescription: 'Configurez Firecrawl avec la clé API pour l’utiliser.', + selectAll: 'Tout sélectionner', + unknownError: 'Erreur inconnue', + firecrawlDoc: 'Docs Firecrawl', + totalPageScraped: 'Nombre total de pages extraites :', + preview: 'Aperçu', + crawlSubPage: 'Explorer les sous-pages', + configure: 'Configurer', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + maxDepth: 'Profondeur maximale', + fireCrawlNotConfigured: 'Firecrawl n’est pas configuré', + firecrawlTitle: 'Extraire du contenu web avec 🔥Firecrawl', + scrapTimeInfo: 'Pages récupérées au total dans un délai de {{time}}s', + options: 'Options', + exceptionErrorTitle: 'Une exception s’est produite lors de l’exécution de la tâche Firecrawl :', + includeOnlyPaths: 'Inclure uniquement les chemins d’accès', + resetAll: 'Tout réinitialiser', + run: 'Exécuter', + extractOnlyMainContent: 'Extraire uniquement le contenu principal (pas d’en-têtes, de navigations, de pieds de page, etc.)', + excludePaths: 'Exclure les chemins d’accès', + maxDepthTooltip: 'Profondeur maximale à explorer par rapport à l’URL saisie. 
La profondeur 0 extrait simplement la page de l’URL saisie, la profondeur 1 extrait l’URL et tout ce qui suit l’URL saisie + un /, et ainsi de suite.', + }, }, stepTwo: { segmentation: 'Paramètres de bloc', @@ -80,8 +104,8 @@ const translation = { QATitle: 'Segmentation en format Question & Réponse', QATip: 'Activer cette option consommera plus de jetons', QALanguage: 'Segmenter en utilisant', - emstimateCost: 'Estimation', - emstimateSegment: 'Morceaux estimés', + estimateCost: 'Estimation', + estimateSegment: 'Morceaux estimés', segmentCount: 'morceaux', calculating: 'En calcul...', fileSource: 'Prétraiter les documents', @@ -104,9 +128,11 @@ const translation = { previewSwitchTipStart: 'L\'aperçu actuel du morceau est en format texte, passer à un aperçu en format de questions-réponses va', previewSwitchTipEnd: 'consommer des tokens supplémentaires', characters: 'personnages', - indexSettedTip: 'Pour changer la méthode d\'index, veuillez aller à la', - retrivalSettedTip: 'Pour changer la méthode d\'index, veuillez aller à la', + indexSettingTip: 'Pour changer la méthode d\'index, veuillez aller à la', + retrievalSettingTip: 'Pour modifier le paramètre de récupération, veuillez aller à la', datasetSettingLink: 'Paramètres de connaissance.', + webpageUnit: 'Pages', + websiteSource: 'Site web de prétraitement', }, stepThree: { creationTitle: '🎉 Connaissance créée', @@ -125,6 +151,11 @@ const translation = { modelButtonConfirm: 'Confirmer', modelButtonCancel: 'Annuler', }, + firecrawl: { + apiKeyPlaceholder: 'Clé API de firecrawl.dev', + configFirecrawl: 'Configurer 🔥Firecrawl', + getApiKeyLinkText: 'Obtenez votre clé API auprès de firecrawl.dev', + }, } export default translation diff --git a/web/i18n/fr-FR/dataset-documents.ts b/web/i18n/fr-FR/dataset-documents.ts index c6b0fca1df..1aad7870a8 100644 --- a/web/i18n/fr-FR/dataset-documents.ts +++ b/web/i18n/fr-FR/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: 'STATUT', action: 'ACTION', }, + rename: 'Renommer', + name: 'Nom', }, action: { uploadFile: 'Télécharger un nouveau fichier', @@ -74,6 +76,7 @@ const translation = { error: 'Erreur d\'Importation', ok: 'D\'accord', }, + addUrl: 'Ajouter une URL', }, metadata: { title: 'Métadonnées', diff --git a/web/i18n/fr-FR/dataset-settings.ts b/web/i18n/fr-FR/dataset-settings.ts index 84d1692dff..9b1f44f54f 100644 --- a/web/i18n/fr-FR/dataset-settings.ts +++ b/web/i18n/fr-FR/dataset-settings.ts @@ -27,6 +27,8 @@ const translation = { longDescription: 'À propos de la méthode de récupération, vous pouvez la modifier à tout moment dans les paramètres de Connaissance.', }, save: 'Enregistrer', + me: '(Vous)', + permissionsInvitedMembers: 'Membres partiels de l’équipe', }, } diff --git a/web/i18n/fr-FR/dataset.ts b/web/i18n/fr-FR/dataset.ts index 014168e006..9c8df9f79d 100644 --- a/web/i18n/fr-FR/dataset.ts +++ b/web/i18n/fr-FR/dataset.ts @@ -71,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'La récupération N-à-1 sera officiellement obsolète à partir de septembre. Il est recommandé d\'utiliser la dernière récupération multi-chemins pour obtenir de meilleurs résultats.', nTo1RetrievalLegacyLink: 'En savoir plus', nTo1RetrievalLegacyLinkText: 'La récupération N-à-1 sera officiellement obsolète en septembre.', + defaultRetrievalTip: 'La récupération à chemins multiples est utilisée par défaut. 
Les connaissances sont extraites de plusieurs bases de connaissances, puis reclassées.', } export default translation diff --git a/web/i18n/fr-FR/login.ts b/web/i18n/fr-FR/login.ts index c905320b22..cee09cf0e7 100644 --- a/web/i18n/fr-FR/login.ts +++ b/web/i18n/fr-FR/login.ts @@ -31,7 +31,7 @@ const translation = { pp: 'Politique de Confidentialité', tosDesc: 'En vous inscrivant, vous acceptez nos', goToInit: 'Si vous n\'avez pas initialisé le compte, veuillez vous rendre sur la page d\'initialisation', - donthave: 'Vous n\'avez pas ?', + dontHave: 'Vous n\'avez pas ?', invalidInvitationCode: 'Code d\'invitation invalide', accountAlreadyInited: 'Compte déjà initialisé', forgotPassword: 'Mot de passe oublié?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: 'Le nom est requis', passwordEmpty: 'Un mot de passe est requis', passwordInvalid: 'Le mot de passe doit contenir des lettres et des chiffres, et la longueur doit être supérieure à 8.', + passwordLengthInValid: 'Le mot de passe doit comporter au moins 8 caractères.', }, license: { tip: 'Avant de commencer Dify Community Edition, lisez le GitHub', @@ -68,6 +69,7 @@ const translation = { activated: 'Connectez-vous maintenant', adminInitPassword: 'Mot de passe d\'initialisation de l\'administrateur', validate: 'Valider', + sso: 'Poursuivre avec l’authentification unique', } export default translation diff --git a/web/i18n/fr-FR/share-app.ts b/web/i18n/fr-FR/share-app.ts index 8f9e04e941..44d03b1e35 100644 --- a/web/i18n/fr-FR/share-app.ts +++ b/web/i18n/fr-FR/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'L\'application n\'est pas disponible', - appUnkonwError: 'L\'application n\'est pas disponible', + appUnknownError: 'L\'application n\'est pas disponible', }, chat: { newChat: 'Nouveau chat', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Discussions', newChatDefaultName: 'Nouvelle conversation', resetChat: 'Réinitialiser la conversation', - powerBy: 'Propulsé par', + poweredBy: 'Propulsé par', prompt: 'Prompt', privatePromptConfigTitle: 'Paramètres de conversation', publicPromptConfigTitle: 'Prompt Initial', diff --git a/web/i18n/fr-FR/tools.ts b/web/i18n/fr-FR/tools.ts index 5e2c770fea..34c71e7764 100644 --- a/web/i18n/fr-FR/tools.ts +++ b/web/i18n/fr-FR/tools.ts @@ -5,6 +5,7 @@ const translation = { all: 'Tout', builtIn: 'Intégré', custom: 'Personnalisé', + workflow: 'Flux de travail', }, contribute: { line1: 'Je suis intéressé par', @@ -75,6 +76,27 @@ const translation = { customDisclaimerPlaceholder: 'Entrez le texte de la clause de non-responsabilité personnalisée', deleteToolConfirmTitle: 'Supprimer cet outil ?', deleteToolConfirmContent: 'La suppression de l\'outil est irréversible. 
Les utilisateurs ne pourront plus accéder à votre outil.',
+ toolInput: {
+ required: 'Obligatoire',
+ name: 'Nom',
+ label: 'Étiquettes',
+ title: 'Entrée d’outil',
+ methodSetting: 'Réglage',
+ labelPlaceholder: 'Choisir des balises (facultatif)',
+ descriptionPlaceholder: 'Description de la signification du paramètre',
+ method: 'Méthode',
+ methodParameter: 'Paramètre',
+ methodSettingTip: 'L’utilisateur renseigne la configuration de l’outil',
+ methodParameterTip: 'Remplissages LLM pendant l’inférence',
+ description: 'Description',
+ },
+ nameForToolCallTip: 'Ne prend en charge que les chiffres, les lettres et les traits de soulignement.',
+ confirmTitle: 'Confirmer pour enregistrer ?',
+ nameForToolCall: 'Nom de l’appel de l’outil',
+ confirmTip: 'Les applications utilisant cet outil seront affectées',
+ description: 'Description',
+ nameForToolCallPlaceHolder: 'Utilisé pour la reconnaissance automatique, tels que getCurrentWeather, list_pets',
+ descriptionPlaceholder: 'Brève description de l’objectif de l’outil, par exemple, obtenir la température d’un endroit spécifique.',
},
test: {
title: 'Test',
@@ -114,6 +136,18 @@ const translation = {
toolRemoved: 'Outil supprimé',
notAuthorized: 'Outil non autorisé',
howToGet: 'Comment obtenir',
+ addToolModal: {
+ type: 'type',
+ emptyTitle: 'Aucun outil de flux de travail disponible',
+ added: 'Ajouté',
+ add: 'ajouter',
+ category: 'catégorie',
+ manageInTools: 'Gérer dans Outils',
+ emptyTip: 'Allez dans « Flux de travail -> Publier en tant qu’outil »',
+ },
+ openInStudio: 'Ouvrir dans Studio',
+ customToolTip: 'En savoir plus sur les outils personnalisés Dify',
+ toolNameUsageTip: 'Nom de l’appel de l’outil pour le raisonnement et l’invite de l’agent',
}
export default translation
diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts
index d3518d5742..3e56246e0c 100644
--- a/web/i18n/fr-FR/workflow.ts
+++ b/web/i18n/fr-FR/workflow.ts
@@ -36,7 +36,7 @@ const translation = {
searchVar: 'Rechercher une variable',
variableNamePlaceholder: 'Nom de la variable',
setVarValuePlaceholder: 'Définir la valeur de la variable',
- needConnecttip: 'Cette étape n\'est connectée à rien',
+ needConnectTip: 'Cette étape n\'est connectée à rien',
maxTreeDepth: 'Limite maximale de {{depth}} nœuds par branche',
needEndNode: 'Le bloc de fin doit être ajouté',
needAnswerNode: 'Le bloc de réponse doit être ajouté',
@@ -77,6 +77,22 @@ const translation = {
overwriteAndImport: 'Écraser et importer',
importFailure: 'Echec de l\'importation',
importSuccess: 'Import avec succès',
+ parallelTip: {
+ click: {
+ title: 'Cliquer',
+ desc: 'à ajouter',
+ },
+ drag: {
+ title: 'Glisser',
+ desc: 'pour se connecter',
+ },
+ limit: 'Le parallélisme est limité à {{num}} branches.',
+ depthLimit: 'Limite de couches d’imbrication parallèle de {{num}} couches',
+ },
+ parallelRun: 'Exécution parallèle',
+ disconnect: 'Déconnecter',
+ jumpToNode: 'Aller à ce nœud',
+ addParallelNode: 'Ajouter un nœud parallèle',
},
env: {
envPanelTitle: 'Variables d\'Environnement',
@@ -186,6 +202,7 @@ const translation = {
'transform': 'Transformer',
'utilities': 'Utilitaires',
'noResult': 'Aucun résultat trouvé',
+ 'searchTool': 'Rechercher un outil',
},
blocks: {
'start': 'Début',
@@ -411,10 +428,12 @@ const translation = {
'not empty': 'n\'est pas vide',
'null': 'est nul',
'not null': 'n\'est pas nul',
+ 'regex match': 'correspondance regex',
},
enterValue: 'Entrez la valeur',
addCondition: 'Ajouter une condition',
conditionNotSetup: 'Condition NON configurée',
+ 
selectVariable: 'Sélectionner une variable...',
},
variableAssigner: {
title: 'Attribuer des variables',
diff --git a/web/i18n/hi-IN/app-api.ts b/web/i18n/hi-IN/app-api.ts
index 3c9ceadd29..b983fd0ccb 100644
--- a/web/i18n/hi-IN/app-api.ts
+++ b/web/i18n/hi-IN/app-api.ts
@@ -10,7 +10,7 @@ const translation = {
pause: 'विराम',
playing: 'चल रहा है',
loading: 'लोड हो रहा है',
- merMaind: {
+ merMaid: {
rerender: 'पुनः रीरेंडर करें',
},
never: 'कभी नहीं',
diff --git a/web/i18n/hi-IN/app-debug.ts b/web/i18n/hi-IN/app-debug.ts
index 29944c8d84..1b0633ef32 100644
--- a/web/i18n/hi-IN/app-debug.ts
+++ b/web/i18n/hi-IN/app-debug.ts
@@ -290,7 +290,7 @@ const translation = {
historyNoBeEmpty: 'संवाद इतिहास प्रॉम्प्ट में सेट होना चाहिए',
queryNoBeEmpty: 'प्रश्न प्रॉम्प्ट में सेट होना चाहिए',
},
- variableConig: {
+ variableConfig: {
'addModalTitle': 'इनपुट फ़ील्ड जोड़ें',
'editModalTitle': 'इनपुट फ़ील्ड संपादित करें',
'description': 'वेरिएबल {{varName}} के लिए सेटिंग',
diff --git a/web/i18n/hi-IN/app-log.ts b/web/i18n/hi-IN/app-log.ts
index 7ed8718668..668ae12d65 100644
--- a/web/i18n/hi-IN/app-log.ts
+++ b/web/i18n/hi-IN/app-log.ts
@@ -93,6 +93,10 @@ const translation = {
promptTemplate: 'प्रॉम्प्ट टेम्पलेट',
promptInput: 'प्रॉम्प्ट इनपुट',
response: 'प्रतिक्रिया',
+ iterations: 'पुनरावृत्तियाँ',
+ toolUsed: 'प्रयुक्त उपकरण',
+ finalProcessing: 'अंतिम प्रसंस्करण',
+ iteration: 'पुनरावृत्ति',
},
}
diff --git a/web/i18n/hi-IN/app-overview.ts b/web/i18n/hi-IN/app-overview.ts
index b75206e032..b930fd413a 100644
--- a/web/i18n/hi-IN/app-overview.ts
+++ b/web/i18n/hi-IN/app-overview.ts
@@ -52,6 +52,8 @@ const translation = {
title: 'वर्कफ़्लो स्टेप्स',
show: 'दिखाएं',
hide: 'छुपाएं',
+ subTitle: 'कार्यप्रवाह विवरण',
+ showDesc: 'WebApp में वर्कफ़्लो विवरण दिखाएँ या छुपाएँ',
},
chatColorTheme: 'चैटबॉट का रंग थीम',
chatColorThemeDesc: 'चैटबॉट का रंग थीम निर्धारित करें',
@@ -70,6 +72,12 @@ const translation = {
customDisclaimerTip: 'कस्टम अस्वीकरण टेक्स्ट क्लाइंट साइड पर प्रदर्शित होगा, जो एप्लिकेशन के बारे में अतिरिक्त जानकारी प्रदान करेगा',
},
+ sso: {
+ title: 'वेबएप एसएसओ',
+ label: 'SSO प्रमाणीकरण',
+ description: 'WebApp का उपयोग करने से पहले सभी उपयोगकर्ताओं को SSO के साथ लॉगिन करना आवश्यक है',
+ tooltip: 'WebApp SSO को सक्षम करने के लिए व्यवस्थापक से संपर्क करें',
+ },
},
embedded: {
entry: 'एम्बेडेड',
@@ -130,8 +138,11 @@ const translation = {
tokenPS: 'टोकन/से.',
totalMessages: {
title: 'कुल संदेश',
- explanation:
- 'दैनिक एआई इंटरैक्शन की गिनती; प्रॉम्प्ट इंजीनियरिंग/डीबगिंग को शामिल नहीं किया गया।',
+ explanation: 'दैनिक AI इंटरैक्शन की गिनती।',
+ },
+ totalConversations: {
+ title: 'कुल वार्तालाप',
+ explanation: 'दैनिक AI वार्तालाप की गिनती; प्रॉम्प्ट इंजीनियरिंग/डीबगिंग शामिल नहीं।',
},
activeUsers: {
title: 'सक्रिय उपयोगकर्ता',
diff --git a/web/i18n/hi-IN/app.ts b/web/i18n/hi-IN/app.ts
index 3f22b6701b..06c342cf61 100644
--- a/web/i18n/hi-IN/app.ts
+++ b/web/i18n/hi-IN/app.ts
@@ -122,7 +122,17 @@ const translation = {
removeConfirmTitle: '{{key}} कॉन्फ़िगरेशन हटाएं?',
removeConfirmContent: 'वर्तमान कॉन्फ़िगरेशन उपयोग में है, इसे हटाने से ट्रेसिंग सुविधा बंद हो जाएगी।',
},
+ view: 'देखें',
},
+ answerIcon: {
+ title: 'बदलने 🤖 के लिए WebApp चिह्न का उपयोग करें',
+ descriptionInExplore: 'एक्सप्लोर में बदलने 🤖 के लिए वेबऐप आइकन का उपयोग करना है या नहीं',
+ description: 'साझा अनुप्रयोग में प्रतिस्थापित 🤖 करने के लिए WebApp चिह्न का उपयोग करना है या नहीं',
+ },
+ importFromDSLFile: 'डीएसएल फ़ाइल से',
+ importFromDSLUrl: 'यूआरएल से',
+ importFromDSL: 'DSL से आयात करें',
+ importFromDSLUrlPlaceholder: 'डीएसएल 
लिंक यहां पेस्ट करें', } export default translation diff --git a/web/i18n/hi-IN/common.ts b/web/i18n/hi-IN/common.ts index 0a210072e1..256cb9d426 100644 --- a/web/i18n/hi-IN/common.ts +++ b/web/i18n/hi-IN/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'पैरामीटर', duplicate: 'डुप्लिकेट', rename: 'नाम बदलें', + audioSourceUnavailable: 'ऑडियो स्रोत अनुपलब्ध है', }, errorMsg: { fieldRequired: '{{field}} आवश्यक है', @@ -137,7 +138,8 @@ const translation = { workspace: 'वर्कस्पेस', createWorkspace: 'वर्कस्पेस बनाएं', helpCenter: 'सहायता', - roadmapAndFeedback: 'प्रतिक्रिया', + communityFeedback: 'प्रतिक्रिया', + roadmap: 'रोडमैप', community: 'समुदाय', about: 'के बारे में', logout: 'लॉग आउट', @@ -172,6 +174,9 @@ const translation = { langGeniusAccountTip: 'आपका Dify खाता और संबंधित उपयोगकर्ता डेटा।', editName: 'नाम संपादित करें', showAppLength: '{{length}} ऐप्स दिखाएं', + deleteConfirmTip: 'पुष्टि करने के लिए, कृपया अपने पंजीकृत ईमेल से निम्नलिखित भेजें', + delete: 'खाता हटाएं', + deleteTip: 'अपना खाता हटाने से आपका सारा डेटा स्थायी रूप से मिट जाएगा और इसे पुनर्प्राप्त नहीं किया जा सकता है।', }, members: { team: 'टीम', @@ -201,7 +206,7 @@ const translation = { invitationSentTip: 'आमंत्रण भेजा गया, और वे साइन इन करके आपकी टीम डेटा तक पहुंच सकते हैं।', invitationLink: 'आमंत्रण लिंक', - failedinvitationEmails: + failedInvitationEmails: 'नीचे दिए गए उपयोगकर्ताओं को सफलतापूर्वक आमंत्रित नहीं किया गया', ok: 'ठीक है', removeFromTeam: 'टीम से हटाएं', @@ -210,9 +215,11 @@ const translation = { setMember: 'सामान्य सदस्य के रूप में सेट करें', setBuilder: 'निर्माता के रूप में सेट करें', setEditor: 'संपादक के रूप में सेट करें', - disinvite: 'आमंत्रण रद्द करें', + disInvite: 'आमंत्रण रद्द करें', deleteMember: 'सदस्य को हटाएं', you: '(आप)', + datasetOperator: 'ज्ञान व्यवस्थापक', + datasetOperatorTip: 'केवल नॉलेज बेस प्रबंधित कर सकते हैं', }, integrations: { connected: 'कनेक्टेड', diff --git a/web/i18n/hi-IN/dataset-creation.ts b/web/i18n/hi-IN/dataset-creation.ts index 59913c71ea..0fa71acf4a 100644 --- a/web/i18n/hi-IN/dataset-creation.ts +++ b/web/i18n/hi-IN/dataset-creation.ts @@ -51,7 +51,7 @@ const translation = { input: 'ज्ञान का नाम', placeholder: 'कृपया दर्ज करें', nameNotEmpty: 'नाम खाली नहीं हो सकता', - nameLengthInvaild: 'नाम 1 से 40 वर्णों के बीच होना चाहिए', + nameLengthInvalid: 'नाम 1 से 40 वर्णों के बीच होना चाहिए', cancelButton: 'रद्द करें', confirmButton: 'बनाएं', failed: 'बनाना विफल रहा', @@ -121,8 +121,8 @@ const translation = { QATitle: 'प्रश्न और उत्तर प्रारूप में खंड करना', QATip: 'इस विकल्प को सक्षम करने से अधिक टोकन खर्च होंगे', QALanguage: 'का उपयोग करके खंड करना', - emstimateCost: 'अनुमानित लागत', - emstimateSegment: 'अनुमानित खंड', + estimateCost: 'अनुमानित लागत', + estimateSegment: 'अनुमानित खंड', segmentCount: 'खंड', calculating: 'गणना कर रहा है...', fileSource: 'दस्तावेज़ों को पूर्व-प्रसंस्करण करें', @@ -152,8 +152,8 @@ const translation = { 'वर्तमान खंड पूर्वावलोकन पाठ प्रारूप में है, प्रश्न-उत्तर प्रारूप में स्विच करने से', previewSwitchTipEnd: ' अतिरिक्त टोकन खर्च होंगे', characters: 'वर्ण', - indexSettedTip: 'इंडेक्स विधि बदलने के लिए, कृपया जाएं ', - retrivalSettedTip: 'इंडेक्स विधि बदलने के लिए, कृपया जाएं ', + indexSettingTip: 'इंडेक्स विधि बदलने के लिए, कृपया जाएं ', + retrievalSettingTip: 'इंडेक्स विधि बदलने के लिए, कृपया जाएं ', datasetSettingLink: 'ज्ञान सेटिंग्स।', }, stepThree: { diff --git a/web/i18n/hi-IN/dataset-settings.ts b/web/i18n/hi-IN/dataset-settings.ts index cd798e4fab..129643dd85 100644 --- a/web/i18n/hi-IN/dataset-settings.ts +++ 
b/web/i18n/hi-IN/dataset-settings.ts
@@ -32,6 +32,8 @@ const translation = {
'प्राप्ति पद्धति के बारे में, आप इसे किसी भी समय ज्ञान सेटिंग्ज में बदल सकते हैं।',
},
save: 'सेवना',
+ me: '(आप)',
+ permissionsInvitedMembers: 'आंशिक टीम के सदस्य',
},
}
diff --git a/web/i18n/hi-IN/dataset.ts b/web/i18n/hi-IN/dataset.ts
index de33113d2b..afd8cd277c 100644
--- a/web/i18n/hi-IN/dataset.ts
+++ b/web/i18n/hi-IN/dataset.ts
@@ -78,6 +78,7 @@ const translation = {
nTo1RetrievalLegacy: 'N-से-1 पुनर्प्राप्ति सितंबर से आधिकारिक तौर पर बंद कर दी जाएगी। बेहतर परिणाम प्राप्त करने के लिए नवीनतम बहु-मार्ग पुनर्प्राप्ति का उपयोग करने की सिफारिश की जाती है।',
nTo1RetrievalLegacyLink: 'और जानें',
nTo1RetrievalLegacyLinkText: 'N-से-1 पुनर्प्राप्ति सितंबर में आधिकारिक तौर पर बंद कर दी जाएगी।',
+ defaultRetrievalTip: 'मल्टी-पाथ रिट्रीवल का उपयोग डिफ़ॉल्ट रूप से किया जाता है। ज्ञान को कई ज्ञान आधारों से पुनर्प्राप्त किया जाता है और फिर फिर से रैंक किया जाता है।',
}
export default translation
diff --git a/web/i18n/hi-IN/login.ts b/web/i18n/hi-IN/login.ts
index 3ecba9a186..b3ca0b1a52 100644
--- a/web/i18n/hi-IN/login.ts
+++ b/web/i18n/hi-IN/login.ts
@@ -36,7 +36,7 @@ const translation = {
tosDesc: 'साइन अप करके, आप हमारी सहमति देते हैं',
goToInit: 'यदि आपने खाता प्रारंभ नहीं किया है, तो कृपया प्रारंभिक पृष्ठ पर जाएं',
- donthave: 'नहीं है?',
+ dontHave: 'नहीं है?',
invalidInvitationCode: 'अवैध निमंत्रण कोड',
accountAlreadyInited: 'खाता पहले से प्रारंभ किया गया है',
forgotPassword: 'क्या आपने अपना पासवर्ड भूल गए हैं?',
diff --git a/web/i18n/hi-IN/share-app.ts b/web/i18n/hi-IN/share-app.ts
index a3884be706..a5c7816fe2 100644
--- a/web/i18n/hi-IN/share-app.ts
+++ b/web/i18n/hi-IN/share-app.ts
@@ -3,6 +3,6 @@ const translation = {
welcome: 'आपका स्वागत है',
appUnavailable: 'ऐप उपलब्ध नहीं है',
- appUnknownError: 'अज्ञात त्रुटि, कृपया पुनः प्रयास करें',
+ appUnknownError: 'ऐप अनुपलब्ध है',
},
chat: {
newChat: 'नया चैट',
@@ -10,7 +10,7 @@ const translation = {
unpinnedTitle: 'चैट',
newChatDefaultName: 'नया संवाद',
resetChat: 'संवाद रीसेट करें',
- powerBy: 'संचालित है',
+ poweredBy: 'संचालित है',
prompt: 'प्रॉम्प्ट',
privatePromptConfigTitle: 'संवाद सेटिंग्स',
publicPromptConfigTitle: 'प्रारंभिक प्रॉम्प्ट',
diff --git a/web/i18n/hi-IN/tools.ts b/web/i18n/hi-IN/tools.ts
index ea8e915ea3..6b0cccebad 100644
--- a/web/i18n/hi-IN/tools.ts
+++ b/web/i18n/hi-IN/tools.ts
@@ -103,6 +103,7 @@ const translation = {
label: 'टैग',
labelPlaceholder: 'टैग चुनें(वैकल्पिक)',
description: 'पैरामीटर के अर्थ का विवरण',
+ descriptionPlaceholder: 'पैरामीटर के अर्थ का विवरण',
},
customDisclaimer: 'कस्टम अस्वीकरण',
customDisclaimerPlaceholder: 'कस्टम अस्वीकरण दर्ज करें',
diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts
index 8891fa8bc1..072e4874e3 100644
--- a/web/i18n/hi-IN/workflow.ts
+++ b/web/i18n/hi-IN/workflow.ts
@@ -36,7 +36,7 @@ const translation = {
searchVar: 'वेरिएबल खोजें',
variableNamePlaceholder: 'वेरिएबल नाम',
setVarValuePlaceholder: 'वेरिएबल सेट करें',
- needConnecttip: 'यह चरण किसी से जुड़ा नहीं है',
+ needConnectTip: 'यह चरण किसी से जुड़ा नहीं है',
maxTreeDepth: 'प्रति शाखा अधिकतम {{depth}} नोड्स की सीमा',
needEndNode: 'अंत ब्लॉक जोड़ा जाना चाहिए',
needAnswerNode: 'उत्तर ब्लॉक जोड़ा जाना चाहिए',
@@ -73,6 +73,29 @@ const translation = {
'कार्यप्रवाह अपडेट के बाद टूल पुनः कॉन्फ़िगरेशन आवश्यक है।',
viewDetailInTracingPanel: 'विवरण देखें',
syncingData: 'डेटा सिंक हो रहा है, बस कुछ सेकंड।',
+ overwriteAndImport: 'अधिलेखित और आयात',
+ importSuccess: 'आयात सफल',
+ chooseDSL: 'डीएसएल (वाईएमएल) फ़ाइल चुनें',
+ 
importDSL: 'DSL आयात करें', + backupCurrentDraft: 'बैकअप वर्तमान ड्राफ्ट', + importFailure: 'आयात विफलता', + importDSLTip: 'वर्तमान ड्राफ्ट ओवरराइट हो जाएगा। आयात करने से पहले वर्कफ़्लो को बैकअप के रूप में निर्यात करें.', + parallelTip: { + click: { + title: 'क्लिक करना', + desc: 'जोड़ने के लिए', + }, + drag: { + title: 'खींचना', + desc: 'कनेक्ट करने के लिए', + }, + limit: 'समांतरता {{num}} शाखाओं तक सीमित है।', + depthLimit: '{{num}} परतों की समानांतर नेस्टिंग परत सीमा', + }, + disconnect: 'अलग करना', + parallelRun: 'समानांतर रन', + jumpToNode: 'इस नोड पर जाएं', + addParallelNode: 'समानांतर नोड जोड़ें', }, env: { envPanelTitle: 'पर्यावरण चर', @@ -182,6 +205,7 @@ const translation = { 'transform': 'परिवर्तन', 'utilities': 'उपयोगिताएं', 'noResult': 'कोई मिलान नहीं मिला', + 'searchTool': 'खोज उपकरण', }, blocks: { 'start': 'प्रारंभ', @@ -419,10 +443,12 @@ const translation = { 'not empty': 'खाली नहीं है', 'null': 'शून्य है', 'not null': 'शून्य नहीं है', + 'regex match': 'रेगेक्स मैच', }, enterValue: 'मान दर्ज करें', addCondition: 'शर्त जोड़ें', conditionNotSetup: 'शर्त सेटअप नहीं है', + selectVariable: 'चर का चयन करें...', }, variableAssigner: { title: 'वेरिएबल्स असाइन करें', diff --git a/web/i18n/it-IT/app-api.ts b/web/i18n/it-IT/app-api.ts index a31a771a57..48aceb5c48 100644 --- a/web/i18n/it-IT/app-api.ts +++ b/web/i18n/it-IT/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'Pausa', playing: 'In Riproduzione', loading: 'Caricamento', - merMaind: { + merMaid: { rerender: 'Rifare il rendering', }, never: 'Mai', diff --git a/web/i18n/it-IT/app-debug.ts b/web/i18n/it-IT/app-debug.ts index a4cf7bba2d..e4555b973a 100644 --- a/web/i18n/it-IT/app-debug.ts +++ b/web/i18n/it-IT/app-debug.ts @@ -293,7 +293,7 @@ const translation = { 'La cronologia delle conversazioni deve essere impostata nel prompt', queryNoBeEmpty: 'La query deve essere impostata nel prompt', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Aggiungi Campo Input', 'editModalTitle': 'Modifica Campo Input', 'description': 'Impostazione per la variabile {{varName}}', diff --git a/web/i18n/it-IT/app-overview.ts b/web/i18n/it-IT/app-overview.ts index 380f7e46ad..cd545df6c6 100644 --- a/web/i18n/it-IT/app-overview.ts +++ b/web/i18n/it-IT/app-overview.ts @@ -52,6 +52,8 @@ const translation = { title: 'Fasi del Workflow', show: 'Mostra', hide: 'Nascondi', + subTitle: 'Dettagli del flusso di lavoro', + showDesc: 'Mostrare o nascondere i dettagli del flusso di lavoro in WebApp', }, chatColorTheme: 'Tema colore chat', chatColorThemeDesc: 'Imposta il tema colore del chatbot', @@ -72,6 +74,12 @@ const translation = { customDisclaimerTip: 'Il testo del disclaimer personalizzato verrà visualizzato sul lato client, fornendo informazioni aggiuntive sull\'applicazione', }, + sso: { + label: 'Autenticazione SSO', + title: 'WebApp SSO', + description: 'Tutti gli utenti devono effettuare l\'accesso con SSO prima di utilizzare WebApp', + tooltip: 'Contattare l\'amministratore per abilitare l\'SSO di WebApp', + }, }, embedded: { entry: 'Incorporato', @@ -132,8 +140,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Totale Messaggi', - explanation: - 'Conteggio delle interazioni giornaliere con l\'AI; ingegneria dei prompt/debug esclusi.', + explanation: 'Conteggio delle interazioni giornaliere con l\'IA.', + }, + totalConversations: { + title: 'Conversazioni totali', + explanation: 'Conteggio delle conversazioni giornaliere con l\'IA; ingegneria/debug dei prompt esclusi.', }, activeUsers: { title: 'Utenti Attivi', 
diff --git a/web/i18n/it-IT/app.ts b/web/i18n/it-IT/app.ts index 265cb58ec4..f28b000b58 100644 --- a/web/i18n/it-IT/app.ts +++ b/web/i18n/it-IT/app.ts @@ -134,7 +134,17 @@ const translation = { removeConfirmContent: 'La configurazione attuale è in uso, rimuovendola disattiverà la funzione di Tracciamento.', }, + view: 'Vista', }, + answerIcon: { + description: 'Se utilizzare l\'icona WebApp per la sostituzione 🤖 nell\'applicazione condivisa', + title: 'Usa l\'icona WebApp per sostituire 🤖', + descriptionInExplore: 'Se utilizzare l\'icona WebApp per sostituirla 🤖 in Esplora', + }, + importFromDSLUrl: 'Dall\'URL', + importFromDSLFile: 'Da file DSL', + importFromDSL: 'Importazione da DSL', + importFromDSLUrlPlaceholder: 'Incolla qui il link DSL', } export default translation diff --git a/web/i18n/it-IT/common.ts b/web/i18n/it-IT/common.ts index 595a5075eb..aa675bb471 100644 --- a/web/i18n/it-IT/common.ts +++ b/web/i18n/it-IT/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parametri', duplicate: 'Duplica', rename: 'Rinomina', + audioSourceUnavailable: 'AudioSource non è disponibile', }, errorMsg: { fieldRequired: '{{field}} è obbligatorio', @@ -137,7 +138,8 @@ const translation = { workspace: 'Workspace', createWorkspace: 'Crea Workspace', helpCenter: 'Aiuto', - roadmapAndFeedback: 'Feedback', + communityFeedback: 'Feedback', + roadmap: 'Tabella di marcia', community: 'Comunità', about: 'Informazioni', logout: 'Esci', @@ -208,7 +210,7 @@ const translation = { invitationSentTip: 'Invito inviato, e possono accedere a Dify per accedere ai dati del tuo team.', invitationLink: 'Link di Invito', - failedinvitationEmails: + failedInvitationEmails: 'Gli utenti seguenti non sono stati invitati con successo', ok: 'OK', removeFromTeam: 'Rimuovi dal team', @@ -217,7 +219,7 @@ const translation = { setMember: 'Imposta come membro ordinario', setBuilder: 'Imposta come builder', setEditor: 'Imposta come editor', - disinvite: 'Annulla l\'invito', + disInvite: 'Annulla l\'invito', deleteMember: 'Elimina Membro', you: '(Tu)', }, diff --git a/web/i18n/it-IT/dataset-creation.ts b/web/i18n/it-IT/dataset-creation.ts index 553c8218c4..1629776bf3 100644 --- a/web/i18n/it-IT/dataset-creation.ts +++ b/web/i18n/it-IT/dataset-creation.ts @@ -52,7 +52,7 @@ const translation = { input: 'Nome della Conoscenza', placeholder: 'Per favore inserisci', nameNotEmpty: 'Il nome non può essere vuoto', - nameLengthInvaild: 'Il nome deve essere tra 1 e 40 caratteri', + nameLengthInvalid: 'Il nome deve essere tra 1 e 40 caratteri', cancelButton: 'Annulla', confirmButton: 'Crea', failed: 'Creazione fallita', @@ -124,8 +124,8 @@ const translation = { QATitle: 'Segmentazione in formato Domanda & Risposta', QATip: 'Abilitare questa opzione consumerà più token', QALanguage: 'Segmenta usando', - emstimateCost: 'Stima', - emstimateSegment: 'Blocchi stimati', + estimateCost: 'Stima', + estimateSegment: 'Blocchi stimati', segmentCount: 'blocchi', calculating: 'Calcolo in corso...', fileSource: 'Preprocessa documenti', @@ -155,8 +155,8 @@ const translation = { 'L\'anteprima del blocco corrente è in formato testo, il passaggio a un\'anteprima in formato domanda e risposta', previewSwitchTipEnd: ' consumerà token aggiuntivi', characters: 'caratteri', - indexSettedTip: 'Per cambiare il metodo di indicizzazione, vai alle ', - retrivalSettedTip: 'Per cambiare il metodo di indicizzazione, vai alle ', + indexSettingTip: 'Per cambiare il metodo di indicizzazione, vai alle ', + retrievalSettingTip: 'Per cambiare il metodo di indicizzazione, 
vai alle ', datasetSettingLink: 'impostazioni della Conoscenza.', }, stepThree: { diff --git a/web/i18n/it-IT/dataset.ts b/web/i18n/it-IT/dataset.ts index 9223a3a96d..fc79c9b4c7 100644 --- a/web/i18n/it-IT/dataset.ts +++ b/web/i18n/it-IT/dataset.ts @@ -78,6 +78,7 @@ const translation = { nTo1RetrievalLegacy: 'Il recupero N-a-1 sarà ufficialmente deprecato da settembre. Si consiglia di utilizzare il più recente recupero multi-percorso per ottenere risultati migliori.', nTo1RetrievalLegacyLink: 'Scopri di più', nTo1RetrievalLegacyLinkText: 'Il recupero N-a-1 sarà ufficialmente deprecato a settembre.', + defaultRetrievalTip: 'Per impostazione predefinita, il recupero a percorsi multipli viene utilizzato. Le informazioni vengono recuperate da più knowledge base e quindi riclassificate.', } export default translation diff --git a/web/i18n/it-IT/login.ts b/web/i18n/it-IT/login.ts index 018f9dca46..b46960aa45 100644 --- a/web/i18n/it-IT/login.ts +++ b/web/i18n/it-IT/login.ts @@ -38,7 +38,7 @@ const translation = { tosDesc: 'Iscrivendoti, accetti i nostri', goToInit: 'Se non hai inizializzato l\'account, vai alla pagina di inizializzazione', - donthave: 'Non hai?', + dontHave: 'Non hai?', invalidInvitationCode: 'Codice di invito non valido', accountAlreadyInited: 'Account già inizializzato', forgotPassword: 'Hai dimenticato la password?', diff --git a/web/i18n/it-IT/share-app.ts b/web/i18n/it-IT/share-app.ts index b1f99d0ba1..772a6e902d 100644 --- a/web/i18n/it-IT/share-app.ts +++ b/web/i18n/it-IT/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'L\'app non è disponibile', - appUnkonwError: 'L\'app non è disponibile', + appUnknownError: 'L\'app non è disponibile', }, chat: { newChat: 'Nuova chat', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Chat', newChatDefaultName: 'Nuova conversazione', resetChat: 'Reimposta conversazione', - powerBy: 'Powered by', + poweredBy: 'Powered by', prompt: 'Prompt', privatePromptConfigTitle: 'Impostazioni conversazione', publicPromptConfigTitle: 'Prompt iniziale', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index 5936679e13..f5d6fc8bf5 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Cerca variabile', variableNamePlaceholder: 'Nome variabile', setVarValuePlaceholder: 'Imposta variabile', - needConnecttip: 'Questo passaggio non è collegato a nulla', + needConnectTip: 'Questo passaggio non è collegato a nulla', maxTreeDepth: 'Limite massimo di {{depth}} nodi per ramo', needEndNode: 'Deve essere aggiunto il blocco di Fine', needAnswerNode: 'Deve essere aggiunto il blocco di Risposta', @@ -81,6 +81,22 @@ const translation = { overwriteAndImport: 'Sovrascrivi e Importa', importFailure: 'Importazione fallita', importSuccess: 'Importazione riuscita', + parallelTip: { + click: { + title: 'Clic', + desc: 'per aggiungere', + }, + drag: { + title: 'Trascinare', + desc: 'per collegare', + }, + depthLimit: 'Limite di livelli di annidamento parallelo di {{num}} livelli', + limit: 'Il parallelismo è limitato ai rami {{num}}.', + }, + parallelRun: 'Corsa parallela', + disconnect: 'Disconnettere', + jumpToNode: 'Vai a questo nodo', + addParallelNode: 'Aggiungi nodo parallelo', }, env: { envPanelTitle: 'Variabili d\'Ambiente', @@ -191,6 +207,7 @@ const translation = { 'transform': 'Trasforma', 'utilities': 'Utility', 'noResult': 'Nessuna corrispondenza trovata', + 'searchTool': 'Strumento di ricerca', }, blocks: { 'start': 'Inizio', 
@@ -430,6 +447,7 @@ const translation = { 'not empty': 'non è vuoto', 'null': 'è nullo', 'not null': 'non è nullo', + 'regex match': 'Corrispondenza regex', }, enterValue: 'Inserisci valore', addCondition: 'Aggiungi Condizione', diff --git a/web/i18n/ja-JP/app-api.ts b/web/i18n/ja-JP/app-api.ts index 63beaeb45c..721504f9f3 100644 --- a/web/i18n/ja-JP/app-api.ts +++ b/web/i18n/ja-JP/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: '一時停止', playing: '再生中', loading: '読み込み中', - merMaind: { + merMaid: { rerender: '再レンダリング', }, never: 'なし', diff --git a/web/i18n/ja-JP/app-debug.ts b/web/i18n/ja-JP/app-debug.ts index 6049be2406..e9ee765435 100644 --- a/web/i18n/ja-JP/app-debug.ts +++ b/web/i18n/ja-JP/app-debug.ts @@ -295,7 +295,7 @@ const translation = { historyNoBeEmpty: 'プロンプトには会話履歴を設定する必要があります', queryNoBeEmpty: 'プロンプトにクエリを設定する必要があります', }, - variableConig: { + variableConfig: { 'addModalTitle': '入力フィールドを追加', 'editModalTitle': '入力フィールドを編集', 'description': '{{varName}} の変数設定', diff --git a/web/i18n/ja-JP/app-overview.ts b/web/i18n/ja-JP/app-overview.ts index 1eb6fd1795..1fcd9d21c4 100644 --- a/web/i18n/ja-JP/app-overview.ts +++ b/web/i18n/ja-JP/app-overview.ts @@ -30,24 +30,26 @@ const translation = { overview: { title: '概要', appInfo: { - explanation: '使いやすいAI WebApp', + explanation: '使いやすいAI Webアプリ', accessibleAddress: '公開URL', preview: 'プレビュー', regenerate: '再生成', regenerateNotice: '公開URLを再生成しますか?', - preUseReminder: '続行する前にWebAppを有効にしてください。', + preUseReminder: '続行する前にWebアプリを有効にしてください。', settings: { entry: '設定', - title: 'WebApp設定', - webName: 'WebApp名', - webDesc: 'WebAppの説明', + title: 'Webアプリの設定', + webName: 'Webアプリの名前', + webDesc: 'Webアプリの説明', webDescTip: 'このテキストはクライアント側に表示され、アプリケーションの使用方法の基本的なガイダンスを提供します。', - webDescPlaceholder: 'WebAppの説明を入力してください', + webDescPlaceholder: 'Webアプリの説明を入力してください', language: '言語', workflow: { title: 'ワークフローステップ', show: '表示', hide: '非表示', + subTitle: 'ワークフローの詳細', + showDesc: 'Webアプリでワークフローの詳細を表示または非表示にする', }, chatColorTheme: 'チャットボットのカラーテーマ', chatColorThemeDesc: 'チャットボットのカラーテーマを設定します', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: '免責事項を入力してください', customDisclaimerTip: 'アプリケーションの使用に関する免責事項を提供します。', }, + sso: { + title: 'WebアプリのSSO', + tooltip: '管理者に問い合わせて、WebアプリのSSOを有効にします', + label: 'SSO認証', + description: 'すべてのユーザーは、Webアプリを使用する前にSSOでログインする必要があります', + }, }, embedded: { entry: '埋め込み', @@ -83,8 +91,8 @@ const translation = { customize: { way: '方法', entry: 'カスタマイズ', - title: 'AI WebAppのカスタマイズ', - explanation: 'シナリオとスタイルのニーズに合わせてWeb Appのフロントエンドをカスタマイズできます。', + title: 'AI Webアプリのカスタマイズ', + explanation: 'シナリオとスタイルのニーズに合わせてWebアプリのフロントエンドをカスタマイズできます。', way1: { name: 'クライアントコードをフォークして修正し、Vercelにデプロイします(推奨)', step1: 'クライアントコードをフォークして修正します', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'トークン/秒', totalMessages: { title: 'トータルメッセージ数', - explanation: '日次AIインタラクション数;工学的/デバッグ目的のプロンプトは除外されます。', + explanation: '日次AIインタラクション数。', + }, + totalConversations: { + title: '総会話数', + explanation: '日次AI会話数;プロンプトエンジニアリング/デバッグは除外。', }, activeUsers: { title: 'アクティブユーザー数', diff --git a/web/i18n/ja-JP/app.ts b/web/i18n/ja-JP/app.ts index 55f641f4c3..76c7d1c4f4 100644 --- a/web/i18n/ja-JP/app.ts +++ b/web/i18n/ja-JP/app.ts @@ -127,6 +127,12 @@ const translation = { removeConfirmTitle: '{{key}}の設定を削除しますか?', removeConfirmContent: '現在の設定は使用中です。これを削除すると、トレース機能が無効になります。', }, + view: '見る', + }, + answerIcon: { + title: 'Webアプリアイコンを使用して🤖を置き換える', + description: '共有アプリケーションの中で Webアプリアイコンを使用して🤖を置き換えるかどうか', + descriptionInExplore: 'ExploreでWebアプリアイコンを使用して🤖を置き換えるかどうか', 
},
}
diff --git a/web/i18n/ja-JP/common.ts b/web/i18n/ja-JP/common.ts
index fc61141bd3..e2517a619d 100644
--- a/web/i18n/ja-JP/common.ts
+++ b/web/i18n/ja-JP/common.ts
@@ -37,6 +37,7 @@ const translation = {
params: 'パラメータ',
duplicate: '重複',
rename: '名前の変更',
+ audioSourceUnavailable: 'AudioSource が利用できません',
},
errorMsg: {
fieldRequired: '{{field}}は必要です',
@@ -132,7 +133,8 @@ const translation = {
workspace: 'ワークスペース',
createWorkspace: 'ワークスペースを作成',
helpCenter: 'ヘルプ',
- roadmapAndFeedback: 'フィードバック',
+ communityFeedback: 'フィードバック',
+ roadmap: 'ロードマップ',
community: 'コミュニティ',
about: 'Difyについて',
logout: 'ログアウト',
@@ -198,7 +200,7 @@ const translation = {
invitationSent: '招待が送信されました',
invitationSentTip: '招待が送信され、彼らはDifyにサインインしてあなた様のチームデータにアクセスできます。',
invitationLink: '招待リンク',
- failedinvitationEmails: '以下のユーザーは正常に招待されませんでした',
+ failedInvitationEmails: '以下のユーザーは正常に招待されませんでした',
ok: 'OK',
removeFromTeam: 'チームから削除',
removeFromTeamTip: 'チームへのアクセスが削除されます',
@@ -206,7 +208,7 @@ const translation = {
setMember: '通常のメンバーに設定',
setBuilder: 'ビルダーに設定',
setEditor: 'エディターに設定',
- disinvite: '招待をキャンセル',
+ disInvite: '招待をキャンセル',
deleteMember: 'メンバーを削除',
you: '(あなた様)',
},
diff --git a/web/i18n/ja-JP/dataset-creation.ts b/web/i18n/ja-JP/dataset-creation.ts
index 448e208222..e6d204840a 100644
--- a/web/i18n/ja-JP/dataset-creation.ts
+++ b/web/i18n/ja-JP/dataset-creation.ts
@@ -50,7 +50,7 @@ const translation = {
input: 'ナレッジ名',
placeholder: '入力してください',
nameNotEmpty: '名前は空にできません',
- nameLengthInvaild: '名前は1〜40文字である必要があります',
+ nameLengthInvalid: '名前は1〜40文字である必要があります',
cancelButton: 'キャンセル',
confirmButton: '作成',
failed: '作成に失敗しました',
@@ -109,8 +109,8 @@ const translation = {
QATitle: '質問と回答形式でセグメント化',
QATip: 'このオプションを有効にすると、追加のトークンが消費されます',
QALanguage: '使用言語',
- emstimateCost: '見積もり',
- emstimateSegment: '推定チャンク数',
+ estimateCost: '見積もり',
+ estimateSegment: '推定チャンク数',
segmentCount: 'チャンク',
calculating: '計算中...',
fileSource: 'ドキュメントの前処理',
@@ -135,8 +135,8 @@ const translation = {
previewSwitchTipStart: '現在のチャンクプレビューはテキスト形式です。質問と回答形式のプレビューに切り替えると、',
previewSwitchTipEnd: ' 追加のトークンが消費されます',
characters: '文字',
- indexSettedTip: 'インデックス方法を変更するには、',
- retrivalSettedTip: 'インデックス方法を変更するには、',
+ indexSettingTip: 'インデックス方法を変更するには、',
+ retrievalSettingTip: '検索方法を変更するには、',
datasetSettingLink: 'ナレッジ設定',
},
stepThree: {
diff --git a/web/i18n/ja-JP/dataset.ts b/web/i18n/ja-JP/dataset.ts
index d2eaf05276..a765473b7e 100644
--- a/web/i18n/ja-JP/dataset.ts
+++ b/web/i18n/ja-JP/dataset.ts
@@ -37,7 +37,7 @@ const translation = {
recommend: 'おすすめ',
},
invertedIndex: {
- title: '逆インデックス',
+ title: '転置インデックス',
description: '効率的な検索に使用される構造です。各用語が含まれるドキュメントまたはWebページを指すように、用語ごとに整理されています。',
},
change: '変更',
@@ -53,6 +53,7 @@ const translation = {
semantic_search: 'ベクトル検索',
full_text_search: 'フルテキスト検索',
hybrid_search: 'ハイブリッド検索',
+ invertedIndex: '転置インデックス',
},
mixtureHighQualityAndEconomicTip: '高品質なナレッジベースと経済的なナレッジベースを混在させるには、Rerankモデルを構成する必要がある。',
inconsistentEmbeddingModelTip: '選択されたナレッジベースが一貫性のない埋め込みモデルで構成されている場合、Rerankモデルの構成が必要です。',
@@ -70,6 +71,7 @@ const translation = {
nTo1RetrievalLegacy: '製品計画によると、N-to-1 Retrievalは9月に正式に廃止される予定です。それまでは通常通り使用できます。',
nTo1RetrievalLegacyLink: '詳細を見る',
nTo1RetrievalLegacyLinkText: ' N-to-1 retrievalは9月に正式に廃止されます。',
+ defaultRetrievalTip: 'デフォルトでは、マルチパス取得が使用されます。ナレッジは複数のナレッジ ベースから取得され、再ランク付けされます。',
}
export default translation
diff --git a/web/i18n/ja-JP/login.ts b/web/i18n/ja-JP/login.ts
index dd01e3b96a..72a26df3db 100644
--- a/web/i18n/ja-JP/login.ts
+++ b/web/i18n/ja-JP/login.ts
@@ -32,7 +32,7 @@ const translation = 
{ pp: 'プライバシーポリシー', tosDesc: 'サインアップすることで、以下に同意するものとします', goToInit: 'アカウントを初期化していない場合は、初期化ページに移動してください', - donthave: 'お持ちでない場合', + dontHave: 'お持ちでない場合', invalidInvitationCode: '無効な招待コード', accountAlreadyInited: 'アカウントは既に初期化されています', forgotPassword: 'パスワードを忘れましたか?', diff --git a/web/i18n/ja-JP/share-app.ts b/web/i18n/ja-JP/share-app.ts index 503972dc48..6b7615c408 100644 --- a/web/i18n/ja-JP/share-app.ts +++ b/web/i18n/ja-JP/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'アプリが利用できません', - appUnkonwError: 'アプリが利用できません', + appUnknownError: 'アプリが利用できません', }, chat: { newChat: '新しいチャット', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'チャット', newChatDefaultName: '新しい会話', resetChat: '会話をリセット', - powerBy: 'Powered by', + poweredBy: 'Powered by', prompt: 'プロンプト', privatePromptConfigTitle: '会話の設定', publicPromptConfigTitle: '初期プロンプト', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index 8f506bcb46..755061e8f6 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: '変数を検索', variableNamePlaceholder: '変数名', setVarValuePlaceholder: '変数を設定', - needConnecttip: 'このステップは何にも接続されていません', + needConnectTip: 'このステップは何にも接続されていません', maxTreeDepth: 'ブランチごとの最大制限は{{depth}}ノードです', needEndNode: '終了ブロックを追加する必要があります', needAnswerNode: '回答ブロックを追加する必要があります', @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'オーバライトとインポート', importFailure: 'インポート失敗', importSuccess: 'インポート成功', + parallelTip: { + click: { + title: 'クリック', + desc: '追加する', + }, + drag: { + title: 'ドラッグ', + desc: '接続するには', + }, + limit: '並列処理は {{num}} ブランチに制限されています。', + depthLimit: '{{num}}レイヤーの平行ネストレイヤーの制限', + }, + parallelRun: 'パラレルラン', + disconnect: '切る', + jumpToNode: 'このノードにジャンプします', + addParallelNode: '並列ノードを追加', }, env: { envPanelTitle: '環境変数', @@ -412,10 +428,12 @@ const translation = { 'not empty': '空でない', 'null': 'null', 'not null': 'nullでない', + 'regex match': '正規表現マッチ', }, enterValue: '値を入力', addCondition: '条件を追加', conditionNotSetup: '条件が設定されていません', + selectVariable: '変数を選択...', }, variableAssigner: { title: '変数を代入する', diff --git a/web/i18n/ko-KR/app-api.ts b/web/i18n/ko-KR/app-api.ts index fc978cddf4..4708fac7dd 100644 --- a/web/i18n/ko-KR/app-api.ts +++ b/web/i18n/ko-KR/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: '일시 정지', playing: '실행 중', loading: '로드 중', - merMaind: { + merMaid: { rerender: '다시 렌더링', }, never: '없음', diff --git a/web/i18n/ko-KR/app-debug.ts b/web/i18n/ko-KR/app-debug.ts index 0a2488b64c..bafe0bf8d8 100644 --- a/web/i18n/ko-KR/app-debug.ts +++ b/web/i18n/ko-KR/app-debug.ts @@ -259,7 +259,7 @@ const translation = { historyNoBeEmpty: '프롬프트에 대화 기록을 설정해야 합니다', queryNoBeEmpty: '프롬프트에 쿼리를 설정해야 합니다', }, - variableConig: { + variableConfig: { 'addModalTitle': '입력 필드 추가', 'editModalTitle': '입력 필드 편집', 'description': '{{varName}} 변수 설정', diff --git a/web/i18n/ko-KR/app-overview.ts b/web/i18n/ko-KR/app-overview.ts index 47342984b3..b06e84587b 100644 --- a/web/i18n/ko-KR/app-overview.ts +++ b/web/i18n/ko-KR/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: '워크플로 단계', show: '표시', hide: '숨기기', + showDesc: 'WebApp에서 워크플로 세부 정보 표시 또는 숨기기', + subTitle: '워크플로우 세부 정보', }, chatColorTheme: '챗봇 색상 테마', chatColorThemeDesc: '챗봇의 색상 테마를 설정하세요', @@ -60,6 +62,15 @@ const translation = { privacyPolicy: '개인정보 처리방침', privacyPolicyPlaceholder: '개인정보 처리방침 링크를 입력하세요', privacyPolicyTip: '방문자가 애플리케이션이 수집하는 데이터를 이해하고, Dify의 개인정보 처리방침을 참조할 수 있도록 합니다.', + customDisclaimer: '사용자 지정 면책 조항', + 
customDisclaimerPlaceholder: '사용자 지정 면책 조항 텍스트를 입력합니다.',
+ customDisclaimerTip: '사용자 지정 고지 사항 텍스트는 클라이언트 쪽에 표시되어 응용 프로그램에 대한 추가 정보를 제공합니다',
+ },
+ sso: {
+ label: 'SSO 인증',
+ title: '웹앱 SSO',
+ tooltip: '관리자에게 문의하여 WebApp SSO를 사용하도록 설정합니다.',
+ description: '모든 사용자는 WebApp을 사용하기 전에 SSO로 로그인해야 합니다.',
},
},
embedded: {
@@ -116,7 +127,11 @@ const translation = {
tokenPS: '토큰/초',
totalMessages: {
title: '총 메시지 수',
- explanation: '일일 AI 상호작용 수; 엔지니어링/디버깅 목적의 프롬프트는 제외됩니다.',
+ explanation: '일일 AI 상호작용 수.',
+ },
+ totalConversations: {
+ title: '총 대화 수',
+ explanation: '일일 AI 대화 수; 프롬프트 엔지니어링/디버깅 제외.',
},
activeUsers: {
title: '활성 사용자 수',
diff --git a/web/i18n/ko-KR/app.ts b/web/i18n/ko-KR/app.ts
index a1fde30e4d..3f3abd3325 100644
--- a/web/i18n/ko-KR/app.ts
+++ b/web/i18n/ko-KR/app.ts
@@ -118,7 +118,17 @@ const translation = {
removeConfirmTitle: '{{key}} 구성을 제거하시겠습니까?',
removeConfirmContent: '현재 구성이 사용 중입니다. 제거하면 추적 기능이 꺼집니다.',
},
+ view: '보기',
},
+ answerIcon: {
+ description: 'WebApp 아이콘을 사용하여 공유 응용 프로그램에서 바꿀🤖지 여부',
+ title: 'WebApp 아이콘을 사용하여 🤖 교체',
+ descriptionInExplore: 'Explore에서 WebApp 아이콘을 사용하여 바꿀🤖지 여부',
+ },
+ importFromDSL: 'DSL에서 가져오기',
+ importFromDSLFile: 'DSL 파일에서',
+ importFromDSLUrl: 'URL에서',
+ importFromDSLUrlPlaceholder: '여기에 DSL 링크 붙여 넣기',
}
export default translation
diff --git a/web/i18n/ko-KR/billing.ts b/web/i18n/ko-KR/billing.ts
index ca6a361e06..94d557fd4b 100644
--- a/web/i18n/ko-KR/billing.ts
+++ b/web/i18n/ko-KR/billing.ts
@@ -58,6 +58,9 @@ const translation = {
ragAPIRequest: 'RAG API 요청',
agentMode: '에이전트 모드',
workflow: '워크플로우',
+ llmLoadingBalancing: 'LLM 로드 밸런싱',
+ bulkUpload: '문서 대량 업로드',
+ llmLoadingBalancingTooltip: '모델에 여러 API 키를 추가하여 API 속도 제한을 효과적으로 우회할 수 있습니다.',
},
comingSoon: '곧 출시 예정',
member: '멤버',
@@ -72,6 +75,8 @@ const translation = {
},
ragAPIRequestTooltip: 'Dify의 지식베이스 처리 기능을 호출하는 API 호출 수를 나타냅니다.',
receiptInfo: '팀 소유자 및 팀 관리자만 구독 및 청구 정보를 볼 수 있습니다',
+ annotationQuota: '주석 할당량',
+ documentsUploadQuota: '문서 업로드 할당량',
},
plans: {
sandbox: {
diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts
index edd0295b89..8ef55da3f7 100644
--- a/web/i18n/ko-KR/common.ts
+++ b/web/i18n/ko-KR/common.ts
@@ -37,6 +37,7 @@ const translation = {
params: '매개변수',
duplicate: '중복',
rename: '이름 바꾸기',
+ audioSourceUnavailable: '오디오 소스를 사용할 수 없습니다.',
},
placeholder: {
input: '입력해주세요',
@@ -124,7 +125,8 @@ const translation = {
workspace: '작업 공간',
createWorkspace: '작업 공간 만들기',
helpCenter: '도움말 센터',
- roadmapAndFeedback: '로드맵 및 피드백',
+ communityFeedback: '피드백',
+ roadmap: '로드맵',
community: '커뮤니티',
about: 'Dify 소개',
logout: '로그아웃',
@@ -186,16 +188,21 @@ const translation = {
invitationSent: '초대가 전송되었습니다',
invitationSentTip: '초대가 전송되었으며, 그들은 Dify에 로그인하여 당신의 팀 데이터에 액세스할 수 있습니다.',
invitationLink: '초대 링크',
- failedinvitationEmails: '다음 사용자들은 성공적으로 초대되지 않았습니다',
+ failedInvitationEmails: '다음 사용자들은 성공적으로 초대되지 않았습니다',
ok: '확인',
removeFromTeam: '팀에서 제거',
removeFromTeamTip: '팀 액세스가 제거됩니다',
setAdmin: '관리자 설정',
setMember: '일반 멤버 설정',
setEditor: '편집자 설정',
- disinvite: '초대 취소',
+ disInvite: '초대 취소',
deleteMember: '멤버 삭제',
you: '(나)',
+ datasetOperator: '지식 관리자',
+ setBuilder: '빌더로 설정',
+ builder: '빌더',
+ builderTip: '자신의 앱을 구축 및 편집할 수 있습니다.',
+ datasetOperatorTip: '기술 자료만 관리할 수 있습니다.',
},
integrations: {
connected: '연결됨',
@@ -342,6 +349,22 @@ const translation = {
quotaTip: '남은 무료 토큰 사용 가능',
loadPresets: '프리셋 로드',
parameters: '매개변수',
+ apiKey: 'API 키',
+ defaultConfig: '기본 구성',
+ providerManaged: '제공자 관리',
+ loadBalancing: '부하 분산',
+ addConfig: 
'구성 추가',
+ apiKeyStatusNormal: 'APIKey 상태는 정상입니다.',
+ configLoadBalancing: 'Config 로드 밸런싱',
+ editConfig: '구성 편집',
+ loadBalancingHeadline: '로드 밸런싱',
+ modelHasBeenDeprecated: '이 모델은 더 이상 사용되지 않습니다',
+ loadBalancingDescription: '여러 자격 증명 세트로 부담을 줄입니다.',
+ upgradeForLoadBalancing: '로드 밸런싱을 사용하도록 계획을 업그레이드합니다.',
+ apiKeyRateLimit: '속도 제한에 도달했으며, {{seconds}}s 후에 사용할 수 있습니다.',
+ loadBalancingInfo: '기본적으로 부하 분산은 라운드 로빈 전략을 사용합니다. 속도 제한이 트리거되면 1분의 휴지 기간이 적용됩니다.',
+ loadBalancingLeastKeyWarning: '로드 밸런싱을 사용하려면 최소 2개의 키를 사용하도록 설정해야 합니다.',
+ providerManagedDescription: '모델 공급자가 제공하는 단일 자격 증명 집합을 사용합니다.',
},
dataSource: {
add: '데이터 소스 추가하기',
@@ -365,6 +388,15 @@ const translation = {
preview: '미리보기',
},
},
+ website: {
+ inactive: '비활성',
+ title: '웹 사이트',
+ configuredCrawlers: '구성된 크롤러',
+ with: '와',
+ active: '활성',
+ description: '웹 크롤러를 사용하여 웹 사이트에서 콘텐츠를 가져옵니다.',
+ },
+ configure: '구성',
},
plugin: {
serpapi: {
@@ -533,6 +565,10 @@ const translation = {
created: '태그가 성공적으로 생성되었습니다',
failed: '태그 생성에 실패했습니다',
},
+ errorMsg: {
+ urlError: 'URL은 http:// 또는 https:// 로 시작해야 합니다.',
+ fieldRequired: '{{field}}는 필수입니다.',
+ },
}
export default translation
diff --git a/web/i18n/ko-KR/dataset-creation.ts b/web/i18n/ko-KR/dataset-creation.ts
index 3039f69d6d..e8851acd2f 100644
--- a/web/i18n/ko-KR/dataset-creation.ts
+++ b/web/i18n/ko-KR/dataset-creation.ts
@@ -45,11 +45,35 @@ const translation = {
input: '지식 이름',
placeholder: '입력하세요',
nameNotEmpty: '이름은 비워둘 수 없습니다',
- nameLengthInvaild: '이름은 1~40자여야 합니다',
+ nameLengthInvalid: '이름은 1~40자여야 합니다',
cancelButton: '취소',
confirmButton: '생성',
failed: '생성에 실패했습니다',
},
+ website: {
+ firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website',
+ limit: '한계',
+ options: '옵션',
+ firecrawlDoc: 'Firecrawl 문서',
+ selectAll: '모두 선택',
+ maxDepth: '최대 깊이',
+ includeOnlyPaths: '경로만 포함',
+ excludePaths: '경로 제외',
+ preview: '미리 보기',
+ run: '실행',
+ fireCrawlNotConfigured: 'Firecrawl이 구성되지 않았습니다.',
+ firecrawlTitle: 'Firecrawl로 🔥웹 콘텐츠 추출',
+ configure: '구성',
+ resetAll: '모두 재설정',
+ crawlSubPage: '하위 페이지 크롤링',
+ exceptionErrorTitle: 'Firecrawl 작업을 실행하는 동안 예외가 발생했습니다.',
+ scrapTimeInfo: '{{time}}s 내에 총 {{total}} 페이지를 스크랩했습니다.',
+ unknownError: '알 수 없는 오류',
+ totalPageScraped: '스크랩한 총 페이지 수:',
+ fireCrawlNotConfiguredDescription: 'API 키로 Firecrawl을 구성하여 사용합니다.',
+ extractOnlyMainContent: '기본 콘텐츠만 추출합니다(머리글, 탐색, 바닥글 등 없음).',
+ maxDepthTooltip: '입력한 URL을 기준으로 크롤링할 최대 수준입니다. 깊이 0은 입력 된 url의 페이지를 긁어 내고, 깊이 1은 url과 enteredURL + one / 이후의 모든 것을 긁어 모으는 식입니다.',
+ },
},
stepTwo: {
segmentation: '청크 설정',
@@ -80,8 +104,8 @@ const translation = {
QATitle: '질문과 답변 형식으로 세그먼트화',
QATip: '이 옵션을 활성화하면 추가 토큰이 소비됩니다',
QALanguage: '사용 언어',
- emstimateCost: '예상 비용',
- emstimateSegment: '예상 청크 수',
+ estimateCost: '예상 비용',
+ estimateSegment: '예상 청크 수',
segmentCount: '청크',
calculating: '계산 중...',
fileSource: '문서 전처리',
@@ -104,9 +128,11 @@ const translation = {
previewSwitchTipStart: '현재 청크 미리보기는 텍스트 형식입니다. 
질문과 답변 형식 미리보기로 전환하면', previewSwitchTipEnd: ' 추가 토큰이 소비됩니다', characters: '문자', - indexSettedTip: '인덱스 방식을 변경하려면,', - retrivalSettedTip: '인덱스 방식을 변경하려면,', + indexSettingTip: '인덱스 방식을 변경하려면,', + retrievalSettingTip: '인덱스 방식을 변경하려면,', datasetSettingLink: '지식 설정', + webpageUnit: '페이지', + websiteSource: '웹 사이트 전처리', }, stepThree: { creationTitle: '🎉 지식이 생성되었습니다', @@ -126,6 +152,11 @@ const translation = { modelButtonConfirm: '확인', modelButtonCancel: '취소', }, + firecrawl: { + getApiKeyLinkText: 'firecrawl.dev 에서 API 키 가져오기', + apiKeyPlaceholder: 'firecrawl.dev 의 API 키', + configFirecrawl: 'Firecrawl 구성 🔥', + }, } export default translation diff --git a/web/i18n/ko-KR/dataset-documents.ts b/web/i18n/ko-KR/dataset-documents.ts index 8e7db58a6d..22c0330134 100644 --- a/web/i18n/ko-KR/dataset-documents.ts +++ b/web/i18n/ko-KR/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: '상태', action: '동작', }, + name: '이름', + rename: '이름 바꾸기', }, action: { uploadFile: '새 파일 업로드', @@ -74,6 +76,7 @@ const translation = { error: '가져오기 오류', ok: '확인', }, + addUrl: 'URL 추가', }, metadata: { title: '메타데이터', diff --git a/web/i18n/ko-KR/dataset-settings.ts b/web/i18n/ko-KR/dataset-settings.ts index 7dac64fce5..ef451ee866 100644 --- a/web/i18n/ko-KR/dataset-settings.ts +++ b/web/i18n/ko-KR/dataset-settings.ts @@ -27,6 +27,8 @@ const translation = { longDescription: ' 검색 방법에 대한 자세한 내용은 언제든지 지식 설정에서 변경할 수 있습니다.', }, save: '저장', + permissionsInvitedMembers: '부분 팀 구성원', + me: '(당신)', }, } diff --git a/web/i18n/ko-KR/dataset.ts b/web/i18n/ko-KR/dataset.ts index 907a1f21b6..9fd0dd16a7 100644 --- a/web/i18n/ko-KR/dataset.ts +++ b/web/i18n/ko-KR/dataset.ts @@ -70,6 +70,7 @@ const translation = { nTo1RetrievalLegacy: 'N-대-1 검색은 9월부터 공식적으로 더 이상 사용되지 않습니다. 더 나은 결과를 얻으려면 최신 다중 경로 검색을 사용하는 것이 좋습니다.', nTo1RetrievalLegacyLink: '자세히 알아보기', nTo1RetrievalLegacyLinkText: 'N-대-1 검색은 9월에 공식적으로 더 이상 사용되지 않습니다.', + defaultRetrievalTip: '다중 경로 검색이 기본적으로 사용됩니다. 
지식은 여러 기술 자료에서 검색된 다음 순위가 다시 매겨집니다.', } export default translation diff --git a/web/i18n/ko-KR/login.ts b/web/i18n/ko-KR/login.ts index 01d1f538fe..ceddeb7b9b 100644 --- a/web/i18n/ko-KR/login.ts +++ b/web/i18n/ko-KR/login.ts @@ -31,7 +31,7 @@ const translation = { pp: '개인정보 처리 방침', tosDesc: '가입함으로써 다음 내용에 동의하게 됩니다.', goToInit: '계정이 초기화되지 않았다면 초기화 페이지로 이동하세요.', - donthave: '계정이 없으신가요?', + dontHave: '계정이 없으신가요?', invalidInvitationCode: '유효하지 않은 초대 코드입니다.', accountAlreadyInited: '계정은 이미 초기화되었습니다.', forgotPassword: '비밀번호를 잊으셨나요?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: '사용자 이름을 입력하세요.', passwordEmpty: '비밀번호를 입력하세요.', passwordInvalid: '비밀번호는 문자와 숫자를 포함하고 8자 이상이어야 합니다.', + passwordLengthInValid: '비밀번호는 8자 이상이어야 합니다.', }, license: { tip: 'Dify Community Edition을 시작하기 전에 GitHub의', @@ -68,6 +69,7 @@ const translation = { activated: '지금 로그인하세요', adminInitPassword: '관리자 초기화 비밀번호', validate: '확인', + sso: 'SSO로 계속하기', } export default translation diff --git a/web/i18n/ko-KR/share-app.ts b/web/i18n/ko-KR/share-app.ts index 9c0738b3d7..be2e34a5fc 100644 --- a/web/i18n/ko-KR/share-app.ts +++ b/web/i18n/ko-KR/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: '앱을 사용할 수 없습니다', - appUnkonwError: '앱을 사용할 수 없습니다', + appUnknownError: '앱을 사용할 수 없습니다', }, chat: { newChat: '새 채팅', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: '채팅', newChatDefaultName: '새 대화', resetChat: '대화 재설정', - powerBy: 'Powered by', + poweredBy: 'Powered by', prompt: '프롬프트', privatePromptConfigTitle: '채팅 설정', publicPromptConfigTitle: '초기 프롬프트', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index d225d97141..8fed0e0417 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: '변수 검색', variableNamePlaceholder: '변수 이름', setVarValuePlaceholder: '변수 값 설정', - needConnecttip: '이 단계는 아무것도 연결되어 있지 않습니다', + needConnectTip: '이 단계는 아무것도 연결되어 있지 않습니다', maxTreeDepth: '분기당 최대 {{depth}} 노드 제한', needEndNode: '종료 블록을 추가해야 합니다', needAnswerNode: '답변 블록을 추가해야 합니다', @@ -69,6 +69,30 @@ const translation = { manageInTools: '도구에서 관리', workflowAsToolTip: '워크플로우 업데이트 후 도구 재구성이 필요합니다.', viewDetailInTracingPanel: '세부 정보 보기', + importDSL: 'DSL 가져오기', + importFailure: '가져오기 실패', + chooseDSL: 'DSL(yml) 파일 선택', + backupCurrentDraft: '현재 초안 백업', + overwriteAndImport: '덮어쓰기 및 가져오기', + importSuccess: '가져오기 성공', + syncingData: '단 몇 초 만에 데이터를 동기화할 수 있습니다.', + importDSLTip: '현재 초안을 덮어씁니다. 
가져오기 전에 워크플로를 백업으로 내보냅니다.',
+ parallelTip: {
+ click: {
+ title: '클릭',
+ desc: '추가하려면',
+ },
+ drag: {
+ title: '드래그',
+ desc: '연결하려면',
+ },
+ depthLimit: '병렬 중첩 레이어 {{num}}개 레이어의 제한',
+ limit: '병렬 처리는 {{num}}개의 분기로 제한됩니다.',
+ },
+ parallelRun: '병렬 실행',
+ disconnect: '연결 해제',
+ jumpToNode: '이 노드로 이동',
+ addParallelNode: '병렬 노드 추가',
},
env: {
envPanelTitle: '환경 변수',
@@ -178,6 +202,7 @@ const translation = {
'transform': '변환',
'utilities': '유틸리티',
'noResult': '일치하는 결과 없음',
+ 'searchTool': '도구 검색',
},
blocks: {
'start': '시작',
@@ -403,10 +428,12 @@ const translation = {
'not empty': '비어 있지 않음',
'null': 'null임',
'not null': 'null이 아님',
+ 'regex match': '정규식 일치',
},
enterValue: '값 입력',
addCondition: '조건 추가',
conditionNotSetup: '조건이 설정되지 않음',
+ selectVariable: '변수 선택...',
},
variableAssigner: {
title: '변수 할당',
@@ -502,6 +529,25 @@ const translation = {
iteration_other: '{{count}} 반복',
currentIteration: '현재 반복',
},
+ note: {
+ editor: {
+ medium: '보통',
+ showAuthor: '작성자 표시',
+ link: '링크',
+ unlink: '해제',
+ small: '작게',
+ large: '크게',
+ placeholder: '메모 쓰기...',
+ bold: '굵게',
+ enterUrl: 'URL 입력...',
+ openLink: '열기',
+ italic: '이탤릭체',
+ invalidUrl: '잘못된 URL',
+ strikethrough: '취소선',
+ bulletList: '글머리 기호 목록',
+ },
+ addNote: '메모 추가',
+ },
},
tracing: {
stopBy: '{{user}}에 의해 중지됨',
diff --git a/web/i18n/language.ts b/web/i18n/language.ts
index e65d34d0ff..fde69328bd 100644
--- a/web/i18n/language.ts
+++ b/web/i18n/language.ts
@@ -49,6 +49,7 @@ export const NOTICE_I18N = {
ko_KR: '중요 공지',
pl_PL: 'Ważne ogłoszenie',
uk_UA: 'Важливе повідомлення',
+ ru_RU: 'Важное Уведомление',
vi_VN: 'Thông báo quan trọng',
it_IT: 'Avviso Importante',
fa_IR: 'هشدار مهم',
@@ -74,6 +75,8 @@ export const NOTICE_I18N = {
'Nasz system będzie niedostępny od 19:00 do 24:00 UTC 28 sierpnia w celu aktualizacji. W przypadku pytań prosimy o kontakt z naszym zespołem wsparcia (support@dify.ai). Doceniamy Twoją cierpliwość.',
uk_UA:
'Наша система буде недоступна з 19:00 до 24:00 UTC 28 серпня для оновлення. Якщо у вас виникнуть запитання, будь ласка, зв’яжіться з нашою службою підтримки (support@dify.ai). Дякуємо за терпіння.',
+ ru_RU:
+ 'Наша система будет недоступна с 19:00 до 24:00 UTC 28 августа для обновления. По вопросам, пожалуйста, обращайтесь в нашу службу поддержки (support@dify.ai). Спасибо за ваше терпение',
vi_VN:
'Hệ thống của chúng tôi sẽ ngừng hoạt động từ 19:00 đến 24:00 UTC vào ngày 28 tháng 8 để nâng cấp. Nếu có thắc mắc, vui lòng liên hệ với nhóm hỗ trợ của chúng tôi (support@dify.ai). 
Chúng tôi đánh giá cao sự kiên nhẫn của bạn.', tr_TR: diff --git a/web/i18n/languages.json b/web/i18n/languages.json index d819e49089..a70963b067 100644 --- a/web/i18n/languages.json +++ b/web/i18n/languages.json @@ -68,7 +68,7 @@ "name": "Русский (Россия)", "prompt_name": "Russian", "example": " Привет, Dify!", - "supported": false + "supported": true }, { "value": "it-IT", diff --git a/web/i18n/pl-PL/app-api.ts b/web/i18n/pl-PL/app-api.ts index 46f9cbb454..05cad56fe2 100644 --- a/web/i18n/pl-PL/app-api.ts +++ b/web/i18n/pl-PL/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'Pauza', playing: 'Gra', loading: 'Ładowanie', - merMaind: { + merMaid: { rerender: 'Przerób Renderowanie', }, never: 'Nigdy', diff --git a/web/i18n/pl-PL/app-debug.ts b/web/i18n/pl-PL/app-debug.ts index afb412f264..7cf6c77cb4 100644 --- a/web/i18n/pl-PL/app-debug.ts +++ b/web/i18n/pl-PL/app-debug.ts @@ -289,7 +289,7 @@ const translation = { historyNoBeEmpty: 'Historia konwersacji musi być ustawiona w monicie', queryNoBeEmpty: 'Zapytanie musi być ustawione w monicie', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Dodaj Pole Wejściowe', 'editModalTitle': 'Edytuj Pole Wejściowe', 'description': 'Ustawienia dla zmiennej {{varName}}', diff --git a/web/i18n/pl-PL/app-overview.ts b/web/i18n/pl-PL/app-overview.ts index 95e8aadb70..9f8bcd34dc 100644 --- a/web/i18n/pl-PL/app-overview.ts +++ b/web/i18n/pl-PL/app-overview.ts @@ -52,6 +52,8 @@ const translation = { title: 'Kroki przepływu pracy', show: 'Pokaż', hide: 'Ukryj', + subTitle: 'Szczegóły przepływu pracy', + showDesc: 'Pokazywanie lub ukrywanie szczegółów przepływu pracy w aplikacji internetowej', }, chatColorTheme: 'Motyw kolorystyczny czatu', chatColorThemeDesc: 'Ustaw motyw kolorystyczny czatu', @@ -69,6 +71,12 @@ const translation = { customDisclaimerPlaceholder: 'Wprowadź oświadczenie o ochronie danych', customDisclaimerTip: 'Niestandardowy tekst oświadczenia będzie wyświetlany po stronie klienta, dostarczając dodatkowych informacji o aplikacji.', }, + sso: { + tooltip: 'Skontaktuj się z administratorem, aby włączyć logowanie jednokrotne w aplikacji internetowej', + title: 'Logowanie jednokrotne w aplikacji internetowej', + label: 'Uwierzytelnianie logowania jednokrotnego', + description: 'Wszyscy użytkownicy muszą zalogować się za pomocą logowania jednokrotnego przed użyciem aplikacji internetowej', + }, }, embedded: { entry: 'Osadzone', @@ -130,8 +138,11 @@ const translation = { tokenPS: 'Tokeny/s', totalMessages: { title: 'Łączna liczba wiadomości', - explanation: - 'Dzienna liczba interakcji z AI; inżynieria i debugowanie monitów wykluczone.', + explanation: 'Liczba dziennych interakcji z AI.', + }, + totalConversations: { + title: 'Całkowita liczba rozmów', + explanation: 'Liczba dziennych rozmów z AI; inżynieria/debugowanie promptów wykluczone.', }, activeUsers: { title: 'Aktywni użytkownicy', diff --git a/web/i18n/pl-PL/app.ts b/web/i18n/pl-PL/app.ts index 6a47d43798..e672b7cd4f 100644 --- a/web/i18n/pl-PL/app.ts +++ b/web/i18n/pl-PL/app.ts @@ -129,7 +129,17 @@ const translation = { removeConfirmTitle: 'Usunąć konfigurację {{key}}?', removeConfirmContent: 'Obecna konfiguracja jest w użyciu, jej usunięcie wyłączy funkcję Śledzenia.', }, + view: 'Widok', }, + answerIcon: { + description: 'Czy w aplikacji udostępnionej ma być używana ikona aplikacji internetowej do zamiany 🤖.', + title: 'Użyj ikony WebApp, aby zastąpić 🤖', + descriptionInExplore: 'Czy używać ikony aplikacji internetowej do zastępowania 🤖 w Eksploruj', + }, + 
importFromDSL: 'Importowanie z DSL', + importFromDSLUrl: 'Z adresu URL', + importFromDSLFile: 'Z pliku DSL', + importFromDSLUrlPlaceholder: 'Wklej tutaj link DSL', } export default translation diff --git a/web/i18n/pl-PL/billing.ts b/web/i18n/pl-PL/billing.ts index 40ddc1f732..cff567e162 100644 --- a/web/i18n/pl-PL/billing.ts +++ b/web/i18n/pl-PL/billing.ts @@ -65,6 +65,8 @@ const translation = { bulkUpload: 'Masowe przesyłanie dokumentów', agentMode: 'Tryb agenta', workflow: 'Przepływ pracy', + llmLoadingBalancing: 'Równoważenie obciążenia LLM', + llmLoadingBalancingTooltip: 'Dodaj wiele kluczy API do modeli, skutecznie omijając limity szybkości interfejsu API.', }, comingSoon: 'Wkrótce dostępne', member: 'Członek', @@ -83,6 +85,7 @@ const translation = { 'Odnosi się do liczby wywołań API wykorzystujących tylko zdolności przetwarzania bazy wiedzy Dify.', receiptInfo: 'Tylko właściciel zespołu i administrator zespołu mogą subskrybować i przeglądać informacje o rozliczeniach', + annotationQuota: 'Przydział adnotacji', }, plans: { sandbox: { diff --git a/web/i18n/pl-PL/common.ts b/web/i18n/pl-PL/common.ts index 1f41abe154..91f5fb2899 100644 --- a/web/i18n/pl-PL/common.ts +++ b/web/i18n/pl-PL/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parametry', duplicate: 'Duplikuj', rename: 'Zmień nazwę', + audioSourceUnavailable: 'AudioSource jest niedostępny', }, placeholder: { input: 'Proszę wprowadzić', @@ -133,7 +134,8 @@ const translation = { workspace: 'Przestrzeń robocza', createWorkspace: 'Utwórz przestrzeń roboczą', helpCenter: 'Pomoc', - roadmapAndFeedback: 'Opinie', + communityFeedback: 'Opinie', + roadmap: 'Plan działania', community: 'Społeczność', about: 'O', logout: 'Wyloguj się', @@ -198,16 +200,21 @@ const translation = { invitationSentTip: 'Zaproszenie zostało wysłane, a oni mogą zalogować się do Dify, aby uzyskać dostęp do danych Twojego zespołu.', invitationLink: 'Link zaproszenia', - failedinvitationEmails: 'Poniższe osoby nie zostały pomyślnie zaproszone', + failedInvitationEmails: 'Poniższe osoby nie zostały pomyślnie zaproszone', ok: 'OK', removeFromTeam: 'Usuń z zespołu', removeFromTeamTip: 'Usunie dostęp do zespołu', setAdmin: 'Ustaw jako administratora', setMember: 'Ustaw jako zwykłego członka', setEditor: 'Ustaw jako edytora', - disinvite: 'Anuluj zaproszenie', + disInvite: 'Anuluj zaproszenie', deleteMember: 'Usuń członka', you: '(Ty)', + datasetOperatorTip: 'Może zarządzać tylko bazą wiedzy', + setBuilder: 'Ustaw jako budowniczego', + builder: 'Budowniczy', + builderTip: 'Może tworzyć i edytować własne aplikacje', + datasetOperator: 'Wiedza Admin', }, integrations: { connected: 'Połączony', @@ -359,6 +366,22 @@ const translation = { quotaTip: 'Pozostałe dostępne darmowe tokeny', loadPresets: 'Załaduj ustawienia wstępne', parameters: 'PARAMETRY', + apiKey: 'KLUCZ-API', + loadBalancing: 'Równoważenie obciążenia', + defaultConfig: 'Domyślna konfiguracja', + providerManagedDescription: 'Użyj pojedynczego zestawu poświadczeń dostarczonych przez dostawcę modelu.', + loadBalancingHeadline: 'Równoważenie obciążenia', + modelHasBeenDeprecated: 'Ten model jest przestarzały', + loadBalancingDescription: 'Zmniejsz presję dzięki wielu zestawom poświadczeń.', + providerManaged: 'Zarządzany przez dostawcę', + upgradeForLoadBalancing: 'Uaktualnij swój plan, aby włączyć równoważenie obciążenia.', + apiKeyStatusNormal: 'Stan APIKey jest normalny', + loadBalancingLeastKeyWarning: 'Aby włączyć równoważenie obciążenia, muszą być włączone co najmniej 2 klucze.', + 
loadBalancingInfo: 'Domyślnie równoważenie obciążenia używa strategii działania okrężnego. Jeśli zostanie uruchomione ograniczenie szybkości, zostanie zastosowany 1-minutowy okres odnowienia.', + configLoadBalancing: 'Konfiguruj równoważenie obciążenia', + editConfig: 'Edytuj konfigurację', + addConfig: 'Dodaj konfigurację', + apiKeyRateLimit: 'Osiągnięto limit szybkości, dostępny po {{seconds}}s', }, dataSource: { add: 'Dodaj źródło danych', @@ -382,6 +405,15 @@ const translation = { preview: 'PODGLĄD', }, }, + website: { + active: 'Aktywny', + with: 'Z', + title: 'Strona internetowa', + description: 'Importuj zawartość ze stron internetowych za pomocą robota indeksującego.', + configuredCrawlers: 'Skonfigurowane roboty indeksujące', + inactive: 'Nieaktywny', + }, + configure: 'Konfiguruj', }, plugin: { serpapi: { @@ -555,6 +587,10 @@ const translation = { created: 'Tag został pomyślnie utworzony', failed: 'Nie udało się utworzyć tagu', }, + errorMsg: { + fieldRequired: '{{field}} jest wymagane', + urlError: 'Adres URL powinien zaczynać się od http:// lub https://', + }, } export default translation diff --git a/web/i18n/pl-PL/dataset-creation.ts b/web/i18n/pl-PL/dataset-creation.ts index 1b12e51b05..64e50c6b33 100644 --- a/web/i18n/pl-PL/dataset-creation.ts +++ b/web/i18n/pl-PL/dataset-creation.ts @@ -46,11 +46,35 @@ const translation = { input: 'Nazwa Wiedzy', placeholder: 'Proszę wpisz', nameNotEmpty: 'Nazwa nie może być pusta', - nameLengthInvaild: 'Nazwa musi zawierać od 1 do 40 znaków', + nameLengthInvalid: 'Nazwa musi zawierać od 1 do 40 znaków', cancelButton: 'Anuluj', confirmButton: 'Utwórz', failed: 'Utworzenie nie powiodło się', }, + website: { + limit: 'Limit', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + firecrawlDoc: 'Dokumentacja Firecrawl', + unknownError: 'Nieznany błąd', + fireCrawlNotConfiguredDescription: 'Skonfiguruj Firecrawl z kluczem API, aby z niego korzystać.', + run: 'Uruchom', + configure: 'Konfiguruj', + resetAll: 'Zresetuj wszystko', + preview: 'Podgląd', + exceptionErrorTitle: 'Wystąpił wyjątek podczas uruchamiania zadania Firecrawl:', + maxDepth: 'Maksymalna głębokość', + crawlSubPage: 'Przeszukiwanie podstron', + options: 'Opcje', + scrapTimeInfo: 'Zeskrobano {{total}} stron w sumie w ciągu {{time}}s', + totalPageScraped: 'Łączna liczba zeskrobanych stron:', + extractOnlyMainContent: 'Wyodrębnij tylko główną zawartość (bez nagłówków, nawigacji, stopek itp.)', + excludePaths: 'Wykluczanie ścieżek', + includeOnlyPaths: 'Uwzględnij tylko ścieżki', + selectAll: 'Zaznacz wszystko', + firecrawlTitle: 'Wyodrębnij zawartość internetową za pomocą 🔥Firecrawl', + fireCrawlNotConfigured: 'Firecrawl nie jest skonfigurowany', + maxDepthTooltip: 'Maksymalna głębokość przeszukiwania względem wprowadzonego adresu URL.
Głębokość 0 po prostu zeskrobuje stronę z wprowadzonego adresu URL, głębokość 1 zeskrobuje adres URL i wszystko po wprowadzonym adresie URL + jeden / i tak dalej.', + }, }, stepTwo: { segmentation: 'Ustawienia bloków tekstu', @@ -88,8 +112,8 @@ const translation = { QATitle: 'Segmentacja w formacie pytania i odpowiedzi', QATip: 'Włączenie tej opcji spowoduje zużycie większej liczby tokenów', QALanguage: 'Segmentacja przy użyciu', - emstimateCost: 'Oszacowanie', - emstimateSegment: 'Oszacowane bloki', + estimateCost: 'Oszacowanie', + estimateSegment: 'Oszacowane bloki', segmentCount: 'bloki', calculating: 'Obliczanie...', fileSource: 'Przetwarzaj dokumenty', @@ -117,9 +141,11 @@ const translation = { 'Aktulany podgląd bloku jest w formacie tekstu, przełączenie na podgląd w formacie pytania i odpowiedzi spowoduje', previewSwitchTipEnd: ' dodatkowe zużycie tokenów', characters: 'znaki', - indexSettedTip: 'Aby zmienić metodę indeksowania, przejdź do ', - retrivalSettedTip: 'Aby zmienić metodę indeksowania, przejdź do ', + indexSettingTip: 'Aby zmienić metodę indeksowania, przejdź do ', + retrievalSettingTip: 'Aby zmienić metodę indeksowania, przejdź do ', datasetSettingLink: 'ustawień Wiedzy.', + webpageUnit: 'Stron', + websiteSource: 'Przetwarzaj witrynę internetową', }, stepThree: { creationTitle: '🎉 Utworzono Wiedzę', @@ -141,6 +167,11 @@ const translation = { modelButtonConfirm: 'Potwierdź', modelButtonCancel: 'Anuluj', }, + firecrawl: { + apiKeyPlaceholder: 'Klucz API od firecrawl.dev', + configFirecrawl: 'Konfiguracja 🔥Firecrawla', + getApiKeyLinkText: 'Pobierz klucz API z firecrawl.dev', + }, } export default translation diff --git a/web/i18n/pl-PL/dataset-documents.ts b/web/i18n/pl-PL/dataset-documents.ts index f8617a29cf..7152c3e9d6 100644 --- a/web/i18n/pl-PL/dataset-documents.ts +++ b/web/i18n/pl-PL/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: 'STATUS', action: 'AKCJA', }, + name: 'Nazwa', + rename: 'Zmień nazwę', }, action: { uploadFile: 'Wgraj nowy plik', @@ -75,6 +77,7 @@ const translation = { error: 'Błąd importu', ok: 'OK', }, + addUrl: 'Dodaj adres URL', }, metadata: { title: 'Metadane', diff --git a/web/i18n/pl-PL/dataset-settings.ts b/web/i18n/pl-PL/dataset-settings.ts index b7d1738c54..3f069ab9b0 100644 --- a/web/i18n/pl-PL/dataset-settings.ts +++ b/web/i18n/pl-PL/dataset-settings.ts @@ -32,6 +32,8 @@ const translation = { ' dotyczące metody doboru, możesz to zmienić w dowolnym momencie w ustawieniach wiedzy.', }, save: 'Zapisz', + permissionsInvitedMembers: 'Częściowi członkowie zespołu', + me: '(Ty)', }, } diff --git a/web/i18n/pl-PL/dataset.ts b/web/i18n/pl-PL/dataset.ts index 14de4eaf40..efd8b75c95 100644 --- a/web/i18n/pl-PL/dataset.ts +++ b/web/i18n/pl-PL/dataset.ts @@ -77,6 +77,7 @@ const translation = { nTo1RetrievalLegacy: 'Wyszukiwanie N-do-1 zostanie oficjalnie wycofane od września. Zaleca się korzystanie z najnowszego wyszukiwania wielościeżkowego, aby uzyskać lepsze wyniki.', nTo1RetrievalLegacyLink: 'Dowiedz się więcej', nTo1RetrievalLegacyLinkText: 'Wyszukiwanie N-do-1 zostanie oficjalnie wycofane we wrześniu.', + defaultRetrievalTip: 'Pobieranie wielu ścieżek jest używane domyślnie.
Wiedza jest pobierana z wielu baz wiedzy, a następnie ponownie klasyfikowana.', } export default translation diff --git a/web/i18n/pl-PL/login.ts b/web/i18n/pl-PL/login.ts index 075b79b913..be9e74f37d 100644 --- a/web/i18n/pl-PL/login.ts +++ b/web/i18n/pl-PL/login.ts @@ -36,7 +36,7 @@ const translation = { pp: 'Polityka prywatności', tosDesc: 'Założeniem konta zgadzasz się z naszymi', goToInit: 'Jeśli nie zainicjowałeś konta, przejdź do strony inicjalizacji', - donthave: 'Nie masz?', + dontHave: 'Nie masz?', invalidInvitationCode: 'Niewłaściwy kod zaproszenia', accountAlreadyInited: 'Konto już zainicjowane', forgotPassword: 'Zapomniałeś hasła?', @@ -59,6 +59,7 @@ const translation = { passwordEmpty: 'Hasło jest wymagane', passwordInvalid: 'Hasło musi zawierać litery i cyfry, a jego długość musi być większa niż 8', + passwordLengthInValid: 'Hasło musi składać się z co najmniej 8 znaków', }, license: { tip: 'Przed rozpoczęciem wersji społecznościowej Dify, przeczytaj GitHub', diff --git a/web/i18n/pl-PL/share-app.ts b/web/i18n/pl-PL/share-app.ts index eb5573c1a1..90b6ca1929 100644 --- a/web/i18n/pl-PL/share-app.ts +++ b/web/i18n/pl-PL/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'Aplikacja jest niedostępna', - appUnkonwError: 'Aplikacja jest niedostępna', + appUnknownError: 'Aplikacja jest niedostępna', }, chat: { newChat: 'Nowy czat', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Czaty', newChatDefaultName: 'Nowa rozmowa', resetChat: 'Resetuj rozmowę', - powerBy: 'Działany przez', + poweredBy: 'Obsługiwane przez', prompt: 'Podpowiedź', privatePromptConfigTitle: 'Ustawienia rozmowy', publicPromptConfigTitle: 'Początkowa podpowiedź', diff --git a/web/i18n/pl-PL/tools.ts b/web/i18n/pl-PL/tools.ts index f5d3226e8c..f34825b049 100644 --- a/web/i18n/pl-PL/tools.ts +++ b/web/i18n/pl-PL/tools.ts @@ -5,6 +5,7 @@ const translation = { all: 'Wszystkie', builtIn: 'Wbudowane', custom: 'Niestandardowe', + workflow: 'Przepływ pracy', }, contribute: { line1: 'Interesuje mnie ', @@ -77,6 +78,27 @@ const translation = { customDisclaimerPlaceholder: 'Proszę wprowadzić oświadczenie niestandardowe', deleteToolConfirmTitle: 'Skasuj ten przyrząd?', deleteToolConfirmContent: 'Usunięcie narzędzia jest nieodwracalne. Użytkownicy nie będą mieli już dostępu do Twojego narzędzia.', + toolInput: { + name: 'Nazwa', + required: 'Wymagane', + descriptionPlaceholder: 'Opis znaczenia parametru', + methodParameter: 'Parametr', + label: 'Tagi', + methodSetting: 'Ustawienie', + description: 'Opis', + method: 'Metoda', + methodParameterTip: 'Wypełniane przez LLM podczas wnioskowania', + labelPlaceholder: 'Wybierz tagi (opcjonalnie)', + methodSettingTip: 'Użytkownik wypełnia konfigurację narzędzia', + title: 'Dane wejściowe narzędzia', + }, + nameForToolCall: 'Nazwa wywołania narzędzia', + description: 'Opis', + descriptionPlaceholder: 'Krótki opis przeznaczenia narzędzia, np.
zmierz temperaturę dla konkretnej lokalizacji.', + nameForToolCallTip: 'Obsługuje tylko cyfry, litery i podkreślenia.', + nameForToolCallPlaceHolder: 'Służy do rozpoznawania przez maszyny, np. getCurrentWeather, list_pets', + confirmTip: 'Będzie to miało wpływ na aplikacje korzystające z tego narzędzia', + confirmTitle: 'Potwierdź, aby zapisać?', }, test: { title: 'Test', @@ -118,6 +140,18 @@ const translation = { toolRemoved: 'Narzędzie usunięte', notAuthorized: 'Narzędzie nieautoryzowane', howToGet: 'Jak uzyskać', + addToolModal: { + manageInTools: 'Zarządzaj w Narzędziach', + added: 'Dodane', + type: 'typ', + category: 'kategoria', + add: 'dodaj', + emptyTitle: 'Brak dostępnego narzędzia do przepływu pracy', + emptyTip: 'Przejdź do "Przepływ pracy -> Opublikuj jako narzędzie"', + }, + openInStudio: 'Otwórz w Studio', + customToolTip: 'Dowiedz się więcej o niestandardowych narzędziach Dify', + toolNameUsageTip: 'Nazwa wywołania narzędzia do wnioskowania i podpowiadania agentowi' } export default translation diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index 62defed019..de05ee7169 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Szukaj zmiennej', variableNamePlaceholder: 'Nazwa zmiennej', setVarValuePlaceholder: 'Ustaw zmienną', - needConnecttip: 'Ten krok nie jest połączony z niczym', + needConnectTip: 'Ten krok nie jest połączony z niczym', maxTreeDepth: 'Maksymalny limit {{depth}} węzłów na gałąź', needEndNode: 'Należy dodać blok końcowy', needAnswerNode: 'Należy dodać blok odpowiedzi', @@ -69,6 +69,30 @@ const translation = { manageInTools: 'Zarządzaj w narzędziach', workflowAsToolTip: 'Wymagana rekonfiguracja narzędzia po aktualizacji przepływu pracy.', viewDetailInTracingPanel: 'Zobacz szczegóły', + importDSLTip: 'Bieżąca wersja robocza zostanie nadpisana.
Eksportuj przepływ pracy jako kopię zapasową przed zaimportowaniem.', + syncingData: 'Synchronizacja danych zajmie tylko kilka sekund.', + importSuccess: 'Import zakończony pomyślnie', + importDSL: 'Importuj DSL', + overwriteAndImport: 'Nadpisz i importuj', + chooseDSL: 'Wybierz plik DSL (yml)', + backupCurrentDraft: 'Utwórz kopię zapasową bieżącej wersji roboczej', + importFailure: 'Niepowodzenie importu', + parallelTip: { + click: { + title: 'Kliknij', + desc: ', aby dodać', + }, + drag: { + title: 'Przeciągnij', + desc: ', aby połączyć', + }, + limit: 'Równoległość jest ograniczona do {{num}} gałęzi.', + depthLimit: 'Limit zagnieżdżenia równoległego: {{num}} warstw', + }, + parallelRun: 'Uruchomienie równoległe', + jumpToNode: 'Przejdź do tego węzła', + disconnect: 'Odłącz', + addParallelNode: 'Dodaj węzeł równoległy', }, env: { envPanelTitle: 'Zmienne Środowiskowe', @@ -178,6 +202,7 @@ const translation = { 'transform': 'Transformacja', 'utilities': 'Narzędzia pomocnicze', 'noResult': 'Nie znaleziono dopasowań', + 'searchTool': 'Szukaj narzędzia', }, blocks: { 'start': 'Start', @@ -403,10 +428,12 @@ const translation = { 'not empty': 'nie jest pusty', 'null': 'jest null', 'not null': 'nie jest null', + 'regex match': 'Dopasowanie wyrażenia regularnego', }, enterValue: 'Wpisz wartość', addCondition: 'Dodaj warunek', conditionNotSetup: 'Warunek NIE został ustawiony', + selectVariable: 'Wybierz zmienną...', }, variableAssigner: { title: 'Przypisz zmienne', @@ -502,6 +529,25 @@ const translation = { iteration_other: '{{count}} Iteracje', currentIteration: 'Bieżąca iteracja', }, + note: { + editor: { + link: 'Łącze', + medium: 'Średni', + small: 'Mały', + italic: 'Kursywa', + enterUrl: 'Wpisz adres URL...', + showAuthor: 'Pokaż autora', + bold: 'Pogrubienie', + unlink: 'Usuń łącze', + bulletList: 'Lista punktowana', + large: 'Duży', + openLink: 'Otwórz', + strikethrough: 'Przekreślenie', + invalidUrl: 'Nieprawidłowy adres URL', + placeholder: 'Napisz swoją notatkę...', + }, + addNote: 'Dodaj notatkę', + }, }, tracing: { stopBy: 'Zatrzymane przez {{user}}', diff --git a/web/i18n/pt-BR/app-api.ts b/web/i18n/pt-BR/app-api.ts index 95a9c84a3e..7bbd25695d 100644 --- a/web/i18n/pt-BR/app-api.ts +++ b/web/i18n/pt-BR/app-api.ts @@ -6,7 +6,7 @@ const translation = { ok: 'Em Serviço', copy: 'Copiar', copied: 'Copiado', - merMaind: { + merMaid: { rerender: 'Refazer Rerender', }, never: 'Nunca', @@ -74,6 +74,10 @@ const translation = { pathParams: 'Parâmetros de Caminho', query: 'Consulta', }, + play: 'Reproduzir', + loading: 'Carregando', + pause: 'Pausa', + playing: 'Reproduzindo', } export default translation diff --git a/web/i18n/pt-BR/app-debug.ts b/web/i18n/pt-BR/app-debug.ts index 9605bd5d95..df4312f887 100644 --- a/web/i18n/pt-BR/app-debug.ts +++ b/web/i18n/pt-BR/app-debug.ts @@ -265,7 +265,7 @@ const translation = { historyNoBeEmpty: 'O histórico da conversa deve ser definido na solicitação', queryNoBeEmpty: 'A consulta deve ser definida na solicitação', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Adicionar Campo de Entrada', 'editModalTitle': 'Editar Campo de Entrada', 'description': 'Configuração para a variável {{varName}}', diff --git a/web/i18n/pt-BR/app-log.ts b/web/i18n/pt-BR/app-log.ts index a61d4204d4..96e604c49e 100644 --- a/web/i18n/pt-BR/app-log.ts +++ b/web/i18n/pt-BR/app-log.ts @@ -90,6 +90,13 @@ const translation = { iteração: 'Iteração', finalProcessing: 'Processamento Final', }, + agentLogDetail: { + iterations: 'Iterações', + agentMode: 'Modo Agente', + finalProcessing:
'Processamento final', + iteration: 'Iteração', + toolUsed: 'Ferramenta usada', + }, } export default translation diff --git a/web/i18n/pt-BR/app-overview.ts b/web/i18n/pt-BR/app-overview.ts index d288e331b3..a717ca259c 100644 --- a/web/i18n/pt-BR/app-overview.ts +++ b/web/i18n/pt-BR/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Etapas do fluxo de trabalho', show: 'Mostrar', hide: 'Ocultar', + subTitle: 'Detalhes do fluxo de trabalho', + showDesc: 'Mostrar ou ocultar detalhes do fluxo de trabalho no WebApp', }, chatColorTheme: 'Tema de cor do chatbot', chatColorThemeDesc: 'Defina o tema de cor do chatbot', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Insira o texto do aviso legal', customDisclaimerTip: 'O texto do aviso legal personalizado será exibido no lado do cliente, fornecendo informações adicionais sobre o aplicativo', }, + sso: { + tooltip: 'Entre em contato com o administrador para habilitar o SSO do WebApp', + label: 'Autenticação SSO', + title: 'WebApp SSO', + description: 'Todos os usuários devem fazer login com SSO antes de usar o WebApp', + }, }, embedded: { entry: 'Embutido', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Total de Mensagens', - explanation: 'Contagem diária de interações AI; engenharia/de depuração excluída.', + explanation: 'Contagem diária de interações com IA.', + }, + totalConversations: { + title: 'Total de Conversas', + explanation: 'Contagem diária de conversas com IA; engenharia/depuração de prompts excluída.', }, activeUsers: { title: 'Usuários Ativos', diff --git a/web/i18n/pt-BR/app.ts b/web/i18n/pt-BR/app.ts index ef9122b86c..cf0c987eb2 100644 --- a/web/i18n/pt-BR/app.ts +++ b/web/i18n/pt-BR/app.ts @@ -122,7 +122,17 @@ const translation = { removeConfirmTitle: 'Remover configuração de {{key}}?', removeConfirmContent: 'A configuração atual está em uso, removê-la desligará o recurso de Rastreamento.', }, + view: 'Vista', }, + answerIcon: { + descriptionInExplore: 'Se o ícone do WebApp deve ser usado para substituir 🤖 no Explore', + description: 'Se o ícone WebApp deve ser usado para substituir 🤖 no aplicativo compartilhado', + title: 'Use o ícone do WebApp para substituir 🤖', + }, + importFromDSLUrlPlaceholder: 'Cole o link DSL aqui', + importFromDSLUrl: 'Do URL', + importFromDSLFile: 'Do arquivo DSL', + importFromDSL: 'Importar de DSL', } export default translation diff --git a/web/i18n/pt-BR/billing.ts b/web/i18n/pt-BR/billing.ts index 1ac082ec38..0a7a964376 100644 --- a/web/i18n/pt-BR/billing.ts +++ b/web/i18n/pt-BR/billing.ts @@ -1,108 +1,118 @@ const translation = { - currentPlan: 'Current Plan', + currentPlan: 'Plano Atual', upgradeBtn: { - plain: 'Upgrade Plan', - encourage: 'Upgrade Now', + plain: 'Fazer Upgrade do Plano', + encourage: 'Fazer Upgrade Agora', encourageShort: 'Upgrade', }, - viewBilling: 'View billing information', - buyPermissionDeniedTip: 'Please contact your enterprise administrator to subscribe', + viewBilling: 'Ver informações de cobrança', + buyPermissionDeniedTip: 'Por favor, entre em contato com o administrador da sua empresa para assinar', plansCommon: { - title: 'Choose a plan that’s right for you', - yearlyTip: 'Get 2 months for free by subscribing yearly!', - mostPopular: 'Most Popular', + title: 'Escolha o plano que melhor atende você', + yearlyTip: 'Receba 2 meses grátis assinando anualmente!', + mostPopular: 'Mais Popular', planRange: { - monthly: 'Monthly', - yearly: 'Yearly', + monthly: 'Mensalmente', + yearly: 'Anualmente', 
}, - month: 'month', - year: 'year', - save: 'Save ', - free: 'Free', - currentPlan: 'Current Plan', - contractOwner: 'Contact team manager', - startForFree: 'Start for free', - getStartedWith: 'Get started with ', - contactSales: 'Contact Sales', - talkToSales: 'Talk to Sales', - modelProviders: 'Model Providers', - teamMembers: 'Team Members', - buildApps: 'Build Apps', - vectorSpace: 'Vector Space', - vectorSpaceBillingTooltip: 'Each 1MB can store about 1.2million characters of vectorized data(estimated using OpenAI Embeddings, varies across models).', - vectorSpaceTooltip: 'Vector Space is the long-term memory system required for LLMs to comprehend your data.', - documentProcessingPriority: 'Document Processing Priority', - documentProcessingPriorityTip: 'For higher document processing priority, please upgrade your plan.', - documentProcessingPriorityUpgrade: 'Process more data with higher accuracy at faster speeds.', + month: 'mês', + year: 'ano', + save: 'Economize ', + free: 'Grátis', + currentPlan: 'Plano Atual', + contractOwner: 'Entre em contato com o gerente da equipe', + startForFree: 'Comece de graça', + getStartedWith: 'Comece com', + contactSales: 'Fale com a equipe de Vendas', + talkToSales: 'Fale com a equipe de Vendas', + modelProviders: 'Fornecedores de Modelos', + teamMembers: 'Membros da Equipe', + buildApps: 'Construir Aplicações', + vectorSpace: 'Espaço Vetorial', + vectorSpaceBillingTooltip: 'Cada 1MB pode armazenar cerca de 1,2 milhão de caracteres de dados vetorizados (estimado usando OpenAI Embeddings, varia entre os modelos).', + vectorSpaceTooltip: 'O Espaço Vetorial é o sistema de memória de longo prazo necessário para que LLMs compreendam seus dados.', + documentProcessingPriority: 'Prioridade no Processamento de Documentos', + documentProcessingPriorityTip: 'Para maior prioridade no processamento de documentos, faça o upgrade do seu plano.', + documentProcessingPriorityUpgrade: 'Processe mais dados com maior precisão e velocidade.', priority: { - 'standard': 'Standard', - 'priority': 'Priority', - 'top-priority': 'Top Priority', + 'standard': 'Padrão', + 'priority': 'Prioridade', + 'top-priority': 'Prioridade Máxima', }, - logsHistory: 'Logs history', - days: 'days', - unlimited: 'Unlimited', - support: 'Support', + logsHistory: 'Histórico de Logs', + days: 'dias', + unlimited: 'Ilimitado', + support: 'Suporte', supportItems: { - communityForums: 'Community forums', - emailSupport: 'Email support', - priorityEmail: 'Priority email & chat support', - logoChange: 'Logo change', - SSOAuthentication: 'SSO authentication', - personalizedSupport: 'Personalized support', - dedicatedAPISupport: 'Dedicated API support', - customIntegration: 'Custom integration and support', - ragAPIRequest: 'RAG API Requests', - agentModel: 'Agent Model', + communityForums: 'Fóruns da Comunidade', + emailSupport: 'Suporte por E-mail', + priorityEmail: 'Suporte prioritário por e-mail e chat', + logoChange: 'Mudança de logo', + SSOAuthentication: 'Autenticação SSO', + personalizedSupport: 'Suporte personalizado', + dedicatedAPISupport: 'Suporte dedicado à API', + customIntegration: 'Integração e suporte personalizados', + ragAPIRequest: 'Solicitações API RAG', + agentModel: 'Modelo de Agente', + workflow: 'Fluxo de trabalho', + llmLoadingBalancing: 'Balanceamento de carga LLM', + bulkUpload: 'Upload em massa de documentos', + llmLoadingBalancingTooltip: 'Adicione várias chaves de API aos modelos, efetivamente ignorando os limites de taxa da API. 
', + agentMode: 'Modo Agente', }, - comingSoon: 'Coming soon', - member: 'Member', - memberAfter: 'Member', + comingSoon: 'Em breve', + member: 'Membro', + memberAfter: 'Membro', messageRequest: { - title: 'Message Credits', - tooltip: 'Message invocation quotas for various plans using OpenAI models (except gpt4).Messages over the limit will use your OpenAI API Key.', + title: 'Créditos de Mensagem', + tooltip: 'Cotas de invocação de mensagens para vários planos usando modelos da OpenAI (exceto gpt4). Mensagens além do limite usarão sua Chave de API da OpenAI.', }, annotatedResponse: { - title: 'Annotation Quota Limits', - tooltip: 'Manual editing and annotation of responses provides customizable high-quality question-answering abilities for apps. (Applicable only in chat apps)', + title: 'Limites de Cota de Anotação', + tooltip: 'A edição manual e anotação de respostas oferece habilidades personalizadas de perguntas e respostas de alta qualidade para aplicativos. (Aplicável apenas em aplicativos de chat)', }, - ragAPIRequestTooltip: 'Refers to the number of API calls invoking only the knowledge base processing capabilities of Dify.', + ragAPIRequestTooltip: 'Refere-se ao número de chamadas de API que invocam apenas as capacidades de processamento da base de conhecimento do Dify.', receiptInfo: 'Somente proprietários e administradores de equipe podem se inscrever e visualizar informações de cobrança', + customTools: 'Ferramentas personalizadas', + documentsUploadQuota: 'Cota de upload de documentos', + annotationQuota: 'Cota de anotação', + contractSales: 'Entre em contato com a equipe de vendas', + unavailable: 'Indisponível', }, plans: { sandbox: { name: 'Sandbox', - description: '200 times GPT free trial', - includesTitle: 'Includes:', + description: '200 vezes GPT de teste gratuito', + includesTitle: 'Inclui:', }, professional: { - name: 'Professional', - description: 'For individuals and small teams to unlock more power affordably.', - includesTitle: 'Everything in free plan, plus:', + name: 'Profissional', + description: 'Para indivíduos e pequenas equipes desbloquearem mais poder de forma acessível.', + includesTitle: 'Tudo no plano gratuito, além de:', }, team: { - name: 'Team', - description: 'Collaborate without limits and enjoy top-tier performance.', - includesTitle: 'Everything in Professional plan, plus:', + name: 'Equipe', + description: 'Colabore sem limites e aproveite o desempenho de primeira linha.', + includesTitle: 'Tudo no plano Profissional, além de:', }, enterprise: { - name: 'Enterprise', - description: 'Get full capabilities and support for large-scale mission-critical systems.', - includesTitle: 'Everything in Team plan, plus:', + name: 'Empresa', + description: 'Obtenha capacidades completas e suporte para sistemas críticos em larga escala.', + includesTitle: 'Tudo no plano Equipe, além de:', }, }, vectorSpace: { - fullTip: 'Vector Space is full.', - fullSolution: 'Upgrade your plan to get more space.', + fullTip: 'O Espaço Vetorial está cheio.', + fullSolution: 'Faça o upgrade do seu plano para obter mais espaço.', }, apps: { - fullTipLine1: 'Upgrade your plan to', - fullTipLine2: 'build more apps.', + fullTipLine1: 'Faça o upgrade do seu plano para', + fullTipLine2: 'construir mais aplicativos.', }, annotatedResponse: { - fullTipLine1: 'Upgrade your plan to', - fullTipLine2: 'annotate more conversations.', - quotaTitle: 'Annotation Reply Quota', + fullTipLine1: 'Faça o upgrade do seu plano para', + fullTipLine2: 'anotar mais conversas.', + quotaTitle: 'Cota de 
Respostas Anotadas', }, } diff --git a/web/i18n/pt-BR/common.ts b/web/i18n/pt-BR/common.ts index f93979404b..f9e9eb7888 100644 --- a/web/i18n/pt-BR/common.ts +++ b/web/i18n/pt-BR/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parâmetros', duplicate: 'Duplicada', rename: 'Renomear', + audioSourceUnavailable: 'AudioSource não está disponível', }, placeholder: { input: 'Por favor, insira', @@ -128,7 +129,8 @@ const translation = { workspace: 'Espaço de trabalho', createWorkspace: 'Criar Espaço de Trabalho', helpCenter: 'Ajuda', - roadmapAndFeedback: 'Feedback', + communityFeedback: 'Feedback', + roadmap: 'Roteiro', community: 'Comunidade', about: 'Sobre', logout: 'Sair', @@ -190,16 +192,21 @@ const translation = { invitationSent: 'Convite enviado', invitationSentTip: 'Convite enviado e eles podem fazer login no Dify para acessar os dados da sua equipe.', invitationLink: 'Link do Convite', - failedinvitationEmails: 'Os seguintes usuários não foram convidados com sucesso', + failedInvitationEmails: 'Os seguintes usuários não foram convidados com sucesso', ok: 'OK', removeFromTeam: 'Remover da equipe', removeFromTeamTip: 'Removerá o acesso da equipe', setAdmin: 'Definir como administrador', setMember: 'Definir como membro comum', setEditor: 'Definir como editor', - disinvite: 'Cancelar o convite', + disInvite: 'Cancelar o convite', deleteMember: 'Excluir Membro', you: '(Você)', + datasetOperatorTip: 'Só pode gerenciar a base de dados de conhecimento', + builder: 'Construtor', + setBuilder: 'Definir como construtor', + builderTip: 'Pode criar e editar seus próprios aplicativos', + datasetOperator: 'Administrador de conhecimento', }, integrations: { connected: 'Conectado', @@ -346,6 +353,22 @@ const translation = { quotaTip: 'Tokens gratuitos disponíveis restantes', loadPresets: 'Carregar Predefinições', parameters: 'PARÂMETROS', + loadBalancingDescription: 'Reduza a pressão com vários conjuntos de credenciais.', + configLoadBalancing: 'Balanceamento de carga de configuração', + upgradeForLoadBalancing: 'Atualize seu plano para habilitar o balanceamento de carga.', + providerManaged: 'Gerenciado pelo provedor', + apiKeyStatusNormal: 'O status do APIKey é normal', + loadBalancing: 'Balanceamento de carga', + addConfig: 'Adicionar configuração', + providerManagedDescription: 'Use o único conjunto de credenciais fornecido pelo provedor de modelo.', + apiKey: 'CHAVE DE API', + loadBalancingLeastKeyWarning: 'Para habilitar o balanceamento de carga, pelo menos 2 chaves devem estar habilitadas.', + editConfig: 'Editar configuração', + defaultConfig: 'Configuração padrão', + modelHasBeenDeprecated: 'Este modelo foi preterido', + loadBalancingInfo: 'Por padrão, o balanceamento de carga usa a estratégia Round-robin. 
Se a limitação de taxa for acionada, um período de espera de 1 minuto será aplicado.', + apiKeyRateLimit: 'O limite de taxa foi atingido, disponível após {{seconds}}s', + loadBalancingHeadline: 'Balanceamento de carga', }, dataSource: { add: 'Adicionar uma fonte de dados', @@ -369,6 +392,15 @@ const translation = { preview: 'PRÉ-VISUALIZAÇÃO', }, }, + website: { + inactive: 'Inativo', + active: 'Ativo', + title: 'Site', + with: 'Com', + configuredCrawlers: 'Rastreadores configurados', + description: 'Importe conteúdo de sites usando o rastreador da Web.', + }, + configure: 'Configurar', }, plugin: { serpapi: { @@ -537,6 +569,10 @@ const translation = { created: 'Tag criada com sucesso', failed: 'Falha na criação da tag', }, + errorMsg: { + fieldRequired: '{{field}} é obrigatório', + urlError: 'URL deve começar com http:// ou https://', + }, } export default translation diff --git a/web/i18n/pt-BR/dataset-creation.ts b/web/i18n/pt-BR/dataset-creation.ts index b721f2177b..4ab78a50c7 100644 --- a/web/i18n/pt-BR/dataset-creation.ts +++ b/web/i18n/pt-BR/dataset-creation.ts @@ -45,11 +45,35 @@ const translation = { input: 'Nome do Conhecimento', placeholder: 'Por favor, insira', nameNotEmpty: 'O nome não pode estar vazio', - nameLengthInvaild: 'O nome deve ter entre 1 e 40 caracteres', + nameLengthInvalid: 'O nome deve ter entre 1 e 40 caracteres', cancelButton: 'Cancelar', confirmButton: 'Criar', failed: 'Falha na criação', }, + website: { + fireCrawlNotConfiguredDescription: 'Configure o Firecrawl com a chave de API para usá-lo.', + run: 'Executar', + unknownError: 'Erro desconhecido', + crawlSubPage: 'Rastrear subpáginas', + selectAll: 'Selecionar tudo', + resetAll: 'Redefinir tudo', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + includeOnlyPaths: 'Incluir apenas caminhos', + configure: 'Configurar', + limit: 'Limite', + firecrawlDoc: 'Documentos do Firecrawl', + preview: 'Visualizar', + options: 'Opções', + scrapTimeInfo: '{{total}} páginas raspadas no total em {{time}}s', + exceptionErrorTitle: 'Ocorreu uma exceção durante a execução do trabalho Firecrawl:', + fireCrawlNotConfigured: 'O Firecrawl não está configurado', + maxDepthTooltip: 'Profundidade máxima para rastrear em relação ao URL inserido.
A profundidade 0 apenas raspa a página do url inserido, a profundidade 1 raspa o url e tudo depois do URL inserido + um / e assim por diante.', + firecrawlTitle: 'Extraia conteúdo da web com 🔥Firecrawl', + maxDepth: 'Profundidade máxima', + totalPageScraped: 'Total de páginas raspadas:', + excludePaths: 'Excluir caminhos', + extractOnlyMainContent: 'Extraia apenas o conteúdo principal (sem cabeçalhos, navs, rodapés, etc.)', + }, }, stepTwo: { segmentation: 'Configurações de fragmentação', @@ -80,8 +104,8 @@ const translation = { QATitle: 'Fragmentação no formato de Perguntas e Respostas', QATip: 'Habilitar esta opção consumirá mais tokens', QALanguage: 'Fragmentar usando', - emstimateCost: 'Estimativa', - emstimateSegment: 'Fragmentos estimados', + estimateCost: 'Estimativa', + estimateSegment: 'Fragmentos estimados', segmentCount: 'fragmentos', calculating: 'Calculando...', fileSource: 'Pré-processar documentos', @@ -104,9 +128,11 @@ const translation = { previewSwitchTipStart: 'A visualização atual do fragmento está no formato de texto, alternar para uma visualização no formato de Perguntas e Respostas irá', previewSwitchTipEnd: ' consumir tokens adicionais', characters: 'caracteres', - indexSettedTip: 'Para alterar o método de índice, por favor vá para as ', - retrivalSettedTip: 'Para alterar o método de índice, por favor vá para as ', + indexSettingTip: 'Para alterar o método de índice, por favor vá para as ', + retrievalSettingTip: 'Para alterar o método de índice, por favor vá para as ', datasetSettingLink: 'configurações do Conhecimento.', + websiteSource: 'Pré-processar site', + webpageUnit: 'Páginas', }, stepThree: { creationTitle: '🎉 Conhecimento criado', @@ -125,6 +151,11 @@ const translation = { modelButtonConfirm: 'Confirmar', modelButtonCancel: 'Cancelar', }, + firecrawl: { + apiKeyPlaceholder: 'Chave de API do firecrawl.dev', + configFirecrawl: 'Configurar o 🔥Firecrawl', + getApiKeyLinkText: 'Obtenha sua chave de API do firecrawl.dev', + }, } export default translation diff --git a/web/i18n/pt-BR/dataset-documents.ts b/web/i18n/pt-BR/dataset-documents.ts index a7265a9cff..ded46c8a14 100644 --- a/web/i18n/pt-BR/dataset-documents.ts +++ b/web/i18n/pt-BR/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: 'STATUS', action: 'AÇÃO', }, + name: 'Nome', + rename: 'Renomear', }, action: { uploadFile: 'Enviar novo arquivo', @@ -74,6 +76,7 @@ const translation = { error: 'Erro na importação', ok: 'OK', }, + addUrl: 'Adicionar URL', }, metadata: { title: 'Metadados', diff --git a/web/i18n/pt-BR/dataset-settings.ts b/web/i18n/pt-BR/dataset-settings.ts index fff9d9c7d5..cfedbff337 100644 --- a/web/i18n/pt-BR/dataset-settings.ts +++ b/web/i18n/pt-BR/dataset-settings.ts @@ -27,6 +27,8 @@ const translation = { longDescription: ' sobre o método de recuperação, você pode alterar isso a qualquer momento nas configurações do conhecimento.', }, save: 'Salvar', + permissionsInvitedMembers: 'Membros parciais da equipe', + me: '(Você)', }, } diff --git a/web/i18n/pt-BR/dataset.ts b/web/i18n/pt-BR/dataset.ts index 8710879149..3feeb99a8b 100644 --- a/web/i18n/pt-BR/dataset.ts +++ b/web/i18n/pt-BR/dataset.ts @@ -70,6 +70,8 @@ const translation = { nTo1RetrievalLegacy: 'A recuperação N-para-1 será oficialmente descontinuada a partir de setembro.
Recomenda-se usar a recuperação de múltiplos caminhos mais recente para obter melhores resultados.', nTo1RetrievalLegacyLink: 'Saiba mais', nTo1RetrievalLegacyLinkText: 'A recuperação N-para-1 será oficialmente descontinuada em setembro.', + intro6: 'como um plug-in de índice ChatGPT autônomo para publicar', + defaultRetrievalTip: 'A recuperação de vários caminhos é usada por padrão. O conhecimento é recuperado de várias bases de dados de conhecimento e, em seguida, reclassificado.', } export default translation diff --git a/web/i18n/pt-BR/login.ts b/web/i18n/pt-BR/login.ts index 88312778c3..a1f278dba8 100644 --- a/web/i18n/pt-BR/login.ts +++ b/web/i18n/pt-BR/login.ts @@ -31,7 +31,7 @@ const translation = { pp: 'Política de Privacidade', tosDesc: 'Ao se inscrever, você concorda com nossos', goToInit: 'Se você não inicializou a conta, vá para a página de inicialização', - donthave: 'Não tem?', + dontHave: 'Não tem?', invalidInvitationCode: 'Código de convite inválido', accountAlreadyInited: 'Conta já iniciada', forgotPassword: 'Esqueceu sua senha?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: 'O nome é obrigatório', passwordEmpty: 'A senha é obrigatória', passwordInvalid: 'A senha deve conter letras e números e ter um comprimento maior que 8', + passwordLengthInValid: 'A senha deve ter pelo menos 8 caracteres', }, license: { tip: 'Antes de começar a usar a Edição Comunitária do Dify, leia a', @@ -66,6 +67,9 @@ const translation = { activatedTipStart: 'Você se juntou à equipe', activatedTipEnd: '', activated: 'Entrar agora', + adminInitPassword: 'Senha de inicialização do administrador', + validate: 'Validar', + sso: 'Continuar com SSO', } export default translation diff --git a/web/i18n/pt-BR/share-app.ts b/web/i18n/pt-BR/share-app.ts index 27baf35275..1e1861e01b 100644 --- a/web/i18n/pt-BR/share-app.ts +++ b/web/i18n/pt-BR/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'O aplicativo não está disponível', - appUnkonwError: 'O aplicativo encontrou um erro desconhecido', + appUnknownError: 'O aplicativo encontrou um erro desconhecido', }, chat: { newChat: 'Nova conversa', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Conversas', newChatDefaultName: 'Nova conversa', resetChat: 'Redefinir conversa', - powerBy: 'Desenvolvido por', + poweredBy: 'Desenvolvido por', prompt: 'Prompt', privatePromptConfigTitle: 'Configurações da conversa', publicPromptConfigTitle: 'Prompt inicial', diff --git a/web/i18n/pt-BR/tools.ts b/web/i18n/pt-BR/tools.ts index 28b6e82176..1b20715328 100644 --- a/web/i18n/pt-BR/tools.ts +++ b/web/i18n/pt-BR/tools.ts @@ -5,6 +5,7 @@ const translation = { all: 'Todas', builtIn: 'Integradas', custom: 'Personalizadas', + workflow: 'Fluxo de trabalho', }, contribute: { line1: 'Estou interessado em ', @@ -75,6 +76,27 @@ const translation = { customDisclaimerPlaceholder: 'Digite o aviso personalizado', deleteToolConfirmTitle: 'Excluir esta ferramenta?', deleteToolConfirmContent: 'Excluir a ferramenta é irreversível. 
Os usuários não poderão mais acessar sua ferramenta.', + toolInput: { + label: 'Tags', + methodSetting: 'Configuração', + methodParameterTip: 'Preenchimentos de LLM durante a inferência', + methodSettingTip: 'O usuário preenche a configuração da ferramenta', + methodParameter: 'Parâmetro', + name: 'Nome', + description: 'Descrição', + method: 'Método', + required: 'Necessário', + title: 'Entrada de ferramenta', + labelPlaceholder: 'Escolha tags (opcional)', + descriptionPlaceholder: 'Descrição do significado do parâmetro', + }, + description: 'Descrição', + nameForToolCall: 'Nome da chamada da ferramenta', + confirmTip: 'Os aplicativos que usam essa ferramenta serão afetados', + confirmTitle: 'Confirme para salvar?', + nameForToolCallTip: 'Suporta apenas números, letras e sublinhados.', + descriptionPlaceholder: 'Breve descrição da finalidade da ferramenta, por exemplo, obter a temperatura para um local específico.', + nameForToolCallPlaceHolder: 'Usado para reconhecimento de máquina, como getCurrentWeather, list_pets', }, test: { title: 'Testar', @@ -114,6 +136,18 @@ const translation = { toolRemoved: 'Ferramenta removida', notAuthorized: 'Ferramenta não autorizada', howToGet: 'Como obter', + addToolModal: { + category: 'categoria', + type: 'tipo', + emptyTip: 'Vá para "Fluxo de trabalho -> Publicar como ferramenta"', + add: 'adicionar', + emptyTitle: 'Nenhuma ferramenta de fluxo de trabalho disponível', + added: 'Adicionado', + manageInTools: 'Gerenciar em Ferramentas', + }, + openInStudio: 'Abrir no Studio', + customToolTip: 'Saiba mais sobre as ferramentas personalizadas da Dify', + toolNameUsageTip: 'Nome da chamada da ferramenta para raciocínio e solicitação do agente' } export default translation diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index 071f6e99f8..ef589c0bde 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Buscar variável', variableNamePlaceholder: 'Nome da variável', setVarValuePlaceholder: 'Definir valor da variável', - needConnecttip: 'Este passo não está conectado a nada', + needConnectTip: 'Este passo não está conectado a nada', maxTreeDepth: 'Limite máximo de {{depth}} nós por ramo', needEndNode: 'O bloco de fim deve ser adicionado', needAnswerNode: 'O bloco de resposta deve ser adicionado', @@ -69,6 +69,30 @@ const translation = { manageInTools: 'Gerenciar nas ferramentas', workflowAsToolTip: 'É necessária a reconfiguração da ferramenta após a atualização do fluxo de trabalho.', viewDetailInTracingPanel: 'Ver detalhes', + importSuccess: 'Importação bem-sucedida', + chooseDSL: 'Escolha o arquivo DSL(yml)', + importFailure: 'Falha na importação', + syncingData: 'Sincronizando dados, apenas alguns segundos.', + overwriteAndImport: 'Substituir e importar', + importDSLTip: 'O rascunho atual será substituído.
Exporte o fluxo de trabalho como backup antes de importar.', + backupCurrentDraft: 'Fazer backup do rascunho atual', + importDSL: 'Importar DSL', + parallelTip: { + click: { + title: 'Clique', + desc: 'para adicionar', + }, + drag: { + title: 'Arrastar', + desc: 'para conectar', + }, + limit: 'O paralelismo é limitado a {{num}} ramificações.', + depthLimit: 'Limite de camada de aninhamento paralelo de {{num}} camadas', + }, + parallelRun: 'Execução paralela', + disconnect: 'Desconectar', + jumpToNode: 'Ir para este nó', + addParallelNode: 'Adicionar nó paralelo', }, env: { envPanelTitle: 'Variáveis de Ambiente', @@ -178,6 +202,7 @@ const translation = { 'transform': 'Transformar', 'utilities': 'Utilitários', 'noResult': 'Nenhum resultado encontrado', + 'searchTool': 'Pesquisar ferramenta', }, blocks: { 'start': 'Iniciar', @@ -403,10 +428,12 @@ const translation = { 'not empty': 'não está vazio', 'null': 'é nulo', 'not null': 'não é nulo', + 'regex match': 'correspondência de regex', }, enterValue: 'Digite o valor', addCondition: 'Adicionar condição', conditionNotSetup: 'Condição NÃO configurada', + selectVariable: 'Selecione a variável...', }, variableAssigner: { title: 'Atribuir variáveis', @@ -502,6 +529,25 @@ const translation = { iteration_other: '{{count}} Iterações', currentIteration: 'Iteração atual', }, + note: { + editor: { + small: 'Pequeno', + bold: 'Negrito', + openLink: 'Abrir', + strikethrough: 'Tachado', + italic: 'Itálico', + invalidUrl: 'URL inválido', + placeholder: 'Escreva sua nota...', + bulletList: 'Lista de marcadores', + link: 'Link', + enterUrl: 'Digite o URL...', + medium: 'Médio', + large: 'Grande', + unlink: 'Desvincular', + showAuthor: 'Mostrar autor', + }, + addNote: 'Adicionar nota', + }, }, tracing: { stopBy: 'Parado por {{user}}', diff --git a/web/i18n/ro-RO/app-api.ts b/web/i18n/ro-RO/app-api.ts index 0b86ec6976..e6a52ade42 100644 --- a/web/i18n/ro-RO/app-api.ts +++ b/web/i18n/ro-RO/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'Pauză', playing: 'În redare', loading: 'Se încarcă', - merMaind: { + merMaid: { rerender: 'Reprocesare', }, never: 'Niciodată', diff --git a/web/i18n/ro-RO/app-debug.ts b/web/i18n/ro-RO/app-debug.ts index 7363f2954f..bafeee8bb0 100644 --- a/web/i18n/ro-RO/app-debug.ts +++ b/web/i18n/ro-RO/app-debug.ts @@ -265,7 +265,7 @@ const translation = { historyNoBeEmpty: 'Istoricul conversației trebuie setat în prompt', queryNoBeEmpty: 'Interogația trebuie setată în prompt', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Adăugați câmp de intrare', 'editModalTitle': 'Editați câmpul de intrare', 'description': 'Setare pentru variabila {{varName}}', diff --git a/web/i18n/ro-RO/app-overview.ts b/web/i18n/ro-RO/app-overview.ts index 007a76a10e..9a9c9be35b 100644 --- a/web/i18n/ro-RO/app-overview.ts +++ b/web/i18n/ro-RO/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Pași flux de lucru', show: 'Afișați', hide: 'Ascundeți', + subTitle: 'Detalii despre fluxul de lucru', + showDesc: 'Afișarea sau ascunderea detaliilor fluxului de lucru în WebApp', }, chatColorTheme: 'Tema de culoare a chatului', chatColorThemeDesc: 'Setați tema de culoare a chatbotului', @@ -60,6 +62,15 @@ const translation = { privacyPolicy: 'Politica de confidențialitate', privacyPolicyPlaceholder: 'Introduceți link-ul politicii de confidențialitate', privacyPolicyTip: 'Ajută vizitatorii să înțeleagă datele pe care le colectează aplicația, consultați Politica de confidențialitate a Dify.', + customDisclaimerPlaceholder: 'Introduceți textul
personalizat de declinare a responsabilității', + customDisclaimerTip: 'Textul personalizat de declinare a responsabilității va fi afișat pe partea clientului, oferind informații suplimentare despre aplicație', + customDisclaimer: 'Declinarea responsabilității personalizate', + }, + sso: { + label: 'Autentificare SSO', + title: 'WebApp SSO', + description: 'Toți utilizatorii trebuie să se conecteze cu SSO înainte de a utiliza WebApp', + tooltip: 'Contactați administratorul pentru a activa WebApp SSO', }, }, embedded: { @@ -116,7 +127,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Mesaje totale', - explanation: 'Număr de interacțiuni AI zilnice; exclud proiectarea și depanarea promptelor.', + explanation: 'Numărul de interacțiuni zilnice cu IA.', + }, + totalConversations: { + title: 'Total Conversații', + explanation: 'Numărul de conversații zilnice cu IA; ingineria/depanarea prompturilor exclusă.', }, activeUsers: { title: 'Utilizatori activi', diff --git a/web/i18n/ro-RO/app.ts b/web/i18n/ro-RO/app.ts index 2d13dd4e66..9baaabdd07 100644 --- a/web/i18n/ro-RO/app.ts +++ b/web/i18n/ro-RO/app.ts @@ -122,7 +122,17 @@ const translation = { removeConfirmTitle: 'Eliminați configurația {{key}}?', removeConfirmContent: 'Configurația curentă este în uz, eliminarea acesteia va dezactiva funcția de Urmărire.', }, + view: 'Vedere', }, + answerIcon: { + descriptionInExplore: 'Dacă să utilizați pictograma WebApp pentru a înlocui 🤖 în Explore', + description: 'Dacă se utilizează pictograma WebApp pentru a înlocui 🤖 în aplicația partajată', + title: 'Utilizați pictograma WebApp pentru a înlocui 🤖', + }, + importFromDSL: 'Import din DSL', + importFromDSLUrl: 'De la URL', + importFromDSLUrlPlaceholder: 'Lipiți linkul DSL aici', + importFromDSLFile: 'Din fișierul DSL', } export default translation diff --git a/web/i18n/ro-RO/billing.ts b/web/i18n/ro-RO/billing.ts index 57b9986889..707d892047 100644 --- a/web/i18n/ro-RO/billing.ts +++ b/web/i18n/ro-RO/billing.ts @@ -60,6 +60,8 @@ const translation = { bulkUpload: 'Încărcare în bloc a documentelor', agentMode: 'Mod agent', workflow: 'Flux de lucru', + llmLoadingBalancing: 'Echilibrarea sarcinii LLM', + llmLoadingBalancingTooltip: 'Adăugați mai multe chei API la modele, ocolind efectiv limitele de rată API.', }, comingSoon: 'Vine în curând', member: 'Membru', @@ -74,6 +76,7 @@ const translation = { }, ragAPIRequestTooltip: 'Se referă la numărul de apeluri API care invocă doar capacitățile de procesare a bazei de cunoștințe a Dify.', receiptInfo: 'Doar proprietarul echipei și administratorul echipei pot să se aboneze și să vizualizeze informațiile de facturare', + annotationQuota: 'Cota de adnotare', }, plans: { sandbox: { diff --git a/web/i18n/ro-RO/common.ts b/web/i18n/ro-RO/common.ts index 34ca1c4671..1fd8778106 100644 --- a/web/i18n/ro-RO/common.ts +++ b/web/i18n/ro-RO/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parametri', duplicate: 'Duplică', rename: 'Redenumește', + audioSourceUnavailable: 'Sursa audio nu este disponibilă', }, placeholder: { input: 'Vă rugăm să introduceți', @@ -63,6 +64,7 @@ const translation = { hiIN: 'Hindi', trTR: 'Turcă', faIR: 'Persană', + plPL: 'Poloneză', }, }, unit: { @@ -127,7 +129,8 @@ const translation = { workspace: 'Spațiu de lucru', createWorkspace: 'Creează Spațiu de lucru', helpCenter: 'Ajutor', - roadmapAndFeedback: 'Feedback', + communityFeedback: 'Feedback', + roadmap: 'Plan de acțiune', community: 'Comunitate', about: 'Despre', logout: 'Deconectare', @@ -189,16 +192,21 @@ 
const translation = { invitationSent: 'Invitație trimisă', invitationSentTip: 'Invitația a fost trimisă și pot să se autentifice în Dify pentru a accesa datele echipei dvs.', invitationLink: 'Link de invitație', - failedinvitationEmails: 'Următorii utilizatori nu au fost invitați cu succes', + failedInvitationEmails: 'Următorii utilizatori nu au fost invitați cu succes', ok: 'OK', removeFromTeam: 'Elimină din echipă', removeFromTeamTip: 'Va elimina accesul la echipă', setAdmin: 'Setează ca administrator', setMember: 'Setează ca membru obișnuit', setEditor: 'Setează ca editor', - disinvite: 'Anulează invitația', + disInvite: 'Anulează invitația', deleteMember: 'Șterge membru', you: '(Dvs.)', + datasetOperatorTip: 'Poate gestiona doar baza de cunoștințe', + builder: 'Constructor', + datasetOperator: 'Administrator de cunoștințe', + setBuilder: 'Setare ca constructor', + builderTip: 'Poate construi și edita propriile aplicații', }, integrations: { connected: 'Conectat', @@ -345,6 +353,22 @@ const translation = { quotaTip: 'Jetoane gratuite disponibile rămase', loadPresets: 'Încarcă presetări', parameters: 'PARAMETRI', + loadBalancingHeadline: 'Echilibrarea încărcării', + loadBalancingInfo: 'În mod implicit, echilibrarea încărcării utilizează strategia Round-robin. Dacă se declanșează limitarea ratei, se va aplica o perioadă de reactivare de 1 minut.', + loadBalancing: 'Echilibrarea încărcării', + apiKeyRateLimit: 'Limita de viteză a fost atinsă, disponibilă după {{seconds}}s', + providerManaged: 'Gestionat de furnizor', + providerManagedDescription: 'Utilizați setul unic de acreditări furnizat de furnizorul de modele.', + defaultConfig: 'Configurație implicită', + addConfig: 'Adăugați configurație', + apiKey: 'CHEIE API', + modelHasBeenDeprecated: 'Acest model a fost depreciat', + loadBalancingDescription: 'Reduceți presiunea cu mai multe seturi de acreditări.', + apiKeyStatusNormal: 'Starea APIKey este normală', + loadBalancingLeastKeyWarning: 'Pentru a activa echilibrarea încărcării trebuie activate cel puțin 2 chei.', + editConfig: 'Editați configurația', + configLoadBalancing: 'Configurați echilibrarea încărcării', + upgradeForLoadBalancing: 'Actualizați-vă planul pentru a activa Load Balancing.', }, dataSource: { add: 'Adăugați o sursă de date', @@ -368,6 +392,15 @@ const translation = { preview: 'PREVIZUALIZARE', }, }, + website: { + inactive: 'Inactiv', + description: 'Importați conținut de pe site-uri web folosind crawlerul web.', + active: 'Activ', + with: 'Cu', + title: 'Site-ul web', + configuredCrawlers: 'Crawlere configurate', + }, + configure: 'Configurați', }, plugin: { serpapi: { @@ -536,6 +569,10 @@ const translation = { created: 'Etichetă creată cu succes', failed: 'Crearea etichetei a eșuat', }, + errorMsg: { + fieldRequired: '{{field}} este obligatoriu', + urlError: 'URL-ul ar trebui să înceapă cu http:// sau https://', + }, } export default translation diff --git a/web/i18n/ro-RO/dataset-creation.ts b/web/i18n/ro-RO/dataset-creation.ts index 89e614e00c..efe3bb246c 100644 --- a/web/i18n/ro-RO/dataset-creation.ts +++ b/web/i18n/ro-RO/dataset-creation.ts @@ -45,11 +45,35 @@ const translation = { input: 'Numele Cunoștinței', placeholder: 'Vă rugăm să introduceți', nameNotEmpty: 'Numele nu poate fi gol', - nameLengthInvaild: 'Numele trebuie să fie între 1 și 40 de caractere', + nameLengthInvalid: 'Numele trebuie să fie între 1 și 40 de caractere', cancelButton: 'Anulează', confirmButton: 'Creează', failed: 'Crearea a eșuat', }, + website: { + crawlSubPage: 'Accesarea cu crawlere a subpaginilor',
limit: 'Limită', + selectAll: 'Selectează tot', + configure: 'Configurați', + preview: 'Previzualizare', + run: 'Rulează', + maxDepth: 'Adâncime maximă', + firecrawlDoc: 'Documente Firecrawl', + options: 'Opțiuni', + exceptionErrorTitle: 'A apărut o excepție în timpul rulării lucrării Firecrawl:', + firecrawlTitle: 'Extrageți conținut web cu 🔥Firecrawl', + unknownError: 'Eroare necunoscută', + scrapTimeInfo: 'Pagini răzuite {{total}} în total în {{time}}s', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + excludePaths: 'Excluderea căilor', + resetAll: 'Resetați toate', + extractOnlyMainContent: 'Extrageți doar conținutul principal (fără anteturi, navigări, subsoluri etc.)', + fireCrawlNotConfiguredDescription: 'Configurați Firecrawl cu cheia API pentru a-l utiliza.', + fireCrawlNotConfigured: 'Firecrawl nu este configurat', + includeOnlyPaths: 'Includeți numai căi', + totalPageScraped: 'Total pagini răzuite:', + maxDepthTooltip: 'Adâncimea maximă de accesat cu crawlere în raport cu adresa URL introdusă. Adâncimea 0 doar răzuiește pagina URL-ului introdus, adâncimea 1 răzuiește url-ul și tot ceea ce urmează după URL-ul introdus + un / și așa mai departe.', + }, }, stepTwo: { segmentation: 'Setări de segmentare', @@ -80,8 +104,8 @@ const translation = { QATitle: 'Segmentarea în format Întrebare și Răspuns', QATip: 'Activarea acestei opțiuni va consuma mai multe jetoane', QALanguage: 'Segmentează folosind', - emstimateCost: 'Estimare', - emstimateSegment: 'Segmente estimate', + estimateCost: 'Estimare', + estimateSegment: 'Segmente estimate', segmentCount: 'segmente', calculating: 'Se calculează...', fileSource: 'Prelucrează documente', @@ -104,9 +128,11 @@ const translation = { previewSwitchTipStart: 'Previzualizarea curentă a segmentului este în format text, comutarea la o previzualizare în format întrebare și răspuns va', previewSwitchTipEnd: ' consuma jetoane suplimentare', characters: 'caractere', - indexSettedTip: 'Pentru a modifica metoda de indexare, vă rugăm să mergeți la ', - retrivalSettedTip: 'Pentru a modifica metoda de indexare, vă rugăm să mergeți la ', + indexSettingTip: 'Pentru a modifica metoda de indexare, vă rugăm să mergeți la ', + retrievalSettingTip: 'Pentru a modifica metoda de indexare, vă rugăm să mergeți la ', datasetSettingLink: 'setările Cunoștinței.', + webpageUnit: 'Pagini', + websiteSource: 'Preprocesare site web', }, stepThree: { creationTitle: '🎉 Cunoștință creată', @@ -125,6 +151,11 @@ const translation = { modelButtonConfirm: 'Confirmă', modelButtonCancel: 'Anulează', }, + firecrawl: { + configFirecrawl: 'Configurați 🔥Firecrawl', + getApiKeyLinkText: 'Obțineți cheia API de la firecrawl.dev', + apiKeyPlaceholder: 'Cheie API de la firecrawl.dev', + }, } export default translation diff --git a/web/i18n/ro-RO/dataset-documents.ts b/web/i18n/ro-RO/dataset-documents.ts index a7c0bf5d51..ed8720e35a 100644 --- a/web/i18n/ro-RO/dataset-documents.ts +++ b/web/i18n/ro-RO/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: 'STARE', action: 'ACȚIUNE', }, + name: 'Nume', + rename: 'Redenumire', }, action: { uploadFile: 'Încarcă un fișier nou', @@ -74,6 +76,7 @@ const translation = { error: 'Eroare la import', ok: 'OK', }, + addUrl: 'Adăugați adresa URL', }, metadata: { title: 'Metadate', diff --git a/web/i18n/ro-RO/dataset-settings.ts b/web/i18n/ro-RO/dataset-settings.ts index c0f9e76aeb..54780f9c16 100644 --- a/web/i18n/ro-RO/dataset-settings.ts +++ b/web/i18n/ro-RO/dataset-settings.ts @@ -27,6 +27,8 @@ const translation
= { longDescription: ' despre metoda de recuperare, o puteți schimba în orice moment în setările cunoștințelor.', }, save: 'Salvare', + permissionsInvitedMembers: 'Membri parțiali ai echipei', + me: '(Tu)', }, } diff --git a/web/i18n/ro-RO/dataset.ts b/web/i18n/ro-RO/dataset.ts index 3e605baf92..baf54f68e6 100644 --- a/web/i18n/ro-RO/dataset.ts +++ b/web/i18n/ro-RO/dataset.ts @@ -71,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'Recuperarea N-la-1 va fi oficial depreciată din septembrie. Se recomandă utilizarea celei mai recente recuperări cu căi multiple pentru a obține rezultate mai bune.', nTo1RetrievalLegacyLink: 'Află mai multe', nTo1RetrievalLegacyLinkText: 'Recuperarea N-la-1 va fi oficial depreciată în septembrie.', + defaultRetrievalTip: 'Recuperarea pe mai multe căi este utilizată în mod implicit. Cunoștințele sunt preluate din mai multe baze de cunoștințe și apoi reclasificate.', } export default translation diff --git a/web/i18n/ro-RO/login.ts b/web/i18n/ro-RO/login.ts index c8a0fad91c..6a8d899b33 100644 --- a/web/i18n/ro-RO/login.ts +++ b/web/i18n/ro-RO/login.ts @@ -32,7 +32,7 @@ const translation = { pp: 'Politica de confidențialitate', tosDesc: 'Prin înregistrarea, ești de acord cu', goToInit: 'Dacă nu ai inițializat încă contul, te rugăm să mergi la pagina de inițializare', - donthave: 'Nu ai?', + dontHave: 'Nu ai?', invalidInvitationCode: 'Cod de invitație invalid', accountAlreadyInited: 'Contul este deja inițializat', forgotPassword: 'Ați uitat parola?', @@ -54,6 +54,7 @@ const translation = { nameEmpty: 'Numele este obligatoriu', passwordEmpty: 'Parola este obligatorie', passwordInvalid: 'Parola trebuie să conțină litere și cifre, iar lungimea trebuie să fie mai mare de 8 caractere', + passwordLengthInValid: 'Parola trebuie să aibă cel puțin 8 caractere', }, license: { tip: 'Înainte de a începe Dify Community Edition, citește', diff --git a/web/i18n/ro-RO/share-app.ts b/web/i18n/ro-RO/share-app.ts index 06cf083a04..c9ec36ab03 100644 --- a/web/i18n/ro-RO/share-app.ts +++ b/web/i18n/ro-RO/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'Aplicația nu este disponibilă', - appUnkonwError: 'Aplicația nu este disponibilă', + appUnknownError: 'Aplicația nu este disponibilă', }, chat: { newChat: 'Chat nou', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Conversații', newChatDefaultName: 'Conversație nouă', resetChat: 'Resetează conversația', - powerBy: 'Furnizat de', + poweredBy: 'Furnizat de', prompt: 'Sugestie', privatePromptConfigTitle: 'Setări conversație', publicPromptConfigTitle: 'Sugestie inițială', diff --git a/web/i18n/ro-RO/tools.ts b/web/i18n/ro-RO/tools.ts index e878162426..165bdb26ed 100644 --- a/web/i18n/ro-RO/tools.ts +++ b/web/i18n/ro-RO/tools.ts @@ -5,6 +5,7 @@ const translation = { all: 'Toate', builtIn: 'Incorporat', custom: 'Personalizat', + workflow: 'Flux de lucru', }, contribute: { line1: 'Sunt interesat să ', @@ -73,6 +74,29 @@ const translation = { privacyPolicyPlaceholder: 'Vă rugăm să introduceți politica de confidențialitate', deleteToolConfirmTitle: 'Ștergeți această unealtă?', deleteToolConfirmContent: ' Ștergerea uneltă este irreversibilă. 
Utilizatorii nu vor mai putea accesa uneltă dvs.', + toolInput: { + methodParameter: 'Parametru', + description: 'Descriere', + methodSetting: 'Setare', + methodSettingTip: 'Utilizatorul completează configurația instrumentului', + methodParameterTip: 'Completări LLM în timpul inferenței', + name: 'Nume', + descriptionPlaceholder: 'Descrierea semnificației parametrului', + label: 'Tags', + required: 'Necesar', + method: 'Metodă', + title: 'Intrare instrument', + labelPlaceholder: 'Alegeți etichetele (opțional)', + }, + descriptionPlaceholder: 'Scurtă descriere a scopului instrumentului, de exemplu, obțineți temperatura pentru o anumită locație.', + nameForToolCall: 'Numele apelului instrumentului', + description: 'Descriere', + confirmTip: 'Aplicațiile care folosesc acest instrument vor fi afectate', + nameForToolCallPlaceHolder: 'Utilizat pentru recunoașterea mașinii, cum ar fi getCurrentWeather, list_pets', + customDisclaimer: 'Declinarea responsabilității personalizate', + confirmTitle: 'Confirmați pentru a salva?', + customDisclaimerPlaceholder: 'Vă rugăm să introduceți declinarea responsabilității personalizate', + nameForToolCallTip: 'Acceptă doar numere, litere și caractere de subliniere.', }, test: { title: 'Testează', @@ -112,6 +136,18 @@ const translation = { toolRemoved: 'Instrument eliminat', notAuthorized: 'Instrument neautorizat', howToGet: 'Cum să obții', + addToolModal: { + added: 'adăugat', + category: 'categorie', + manageInTools: 'Gestionați în Instrumente', + add: 'adăuga', + type: 'tip', + emptyTitle: 'Nu este disponibil niciun instrument de flux de lucru', + emptyTip: 'Accesați "Flux de lucru -> Publicați ca instrument"', + }, + openInStudio: 'Deschide în Studio', + customToolTip: 'Aflați mai multe despre instrumentele personalizate Dify', + toolNameUsageTip: 'Numele de apel al instrumentului pentru raționamentul și solicitarea agentului', } export default translation diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index 9cbd2c0d7e..689ebdead9 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Caută variabilă', variableNamePlaceholder: 'Nume variabilă', setVarValuePlaceholder: 'Setează valoarea variabilei', - needConnecttip: 'Acest pas nu este conectat la nimic', + needConnectTip: 'Acest pas nu este conectat la nimic', maxTreeDepth: 'Limită maximă de {{depth}} noduri pe ramură', needEndNode: 'Trebuie adăugat blocul de sfârșit', needAnswerNode: 'Trebuie adăugat blocul de răspuns', @@ -69,6 +69,30 @@ const translation = { manageInTools: 'Gestionează în instrumente', workflowAsToolTip: 'Reconfigurarea instrumentului este necesară după actualizarea fluxului de lucru.', viewDetailInTracingPanel: 'Vezi detalii', + overwriteAndImport: 'Suprascriere și import', + chooseDSL: 'Alegeți fișierul DSL(yml)', + syncingData: 'Sincronizarea datelor, doar câteva secunde.', + importDSL: 'Importați DSL', + importFailure: 'Eșecul importului', + importSuccess: 'Succesul importului', + backupCurrentDraft: 'Backup curent draft', + importDSLTip: 'Proiectul curent va fi suprascris. 
Exportați fluxul de lucru ca backup înainte de import.', + parallelTip: { + click: { + title: 'Clic', + desc: 'pentru a adăuga', + }, + drag: { + title: 'Glisa', + desc: 'pentru a vă conecta', + }, + depthLimit: 'Limita straturilor de imbricare paralelă a {{num}} straturi', + limit: 'Paralelismul este limitat la {{num}} ramuri.', + }, + parallelRun: 'Rulare paralelă', + disconnect: 'Deconecta', + jumpToNode: 'Sari la acest nod', + addParallelNode: 'Adăugare nod paralel', }, env: { envPanelTitle: 'Variabile de Mediu', @@ -178,6 +202,7 @@ const translation = { 'transform': 'Transformare', 'utilities': 'Utilități', 'noResult': 'Niciun rezultat găsit', + 'searchTool': 'Instrument de căutare', }, blocks: { 'start': 'Începe', @@ -403,10 +428,12 @@ const translation = { 'not empty': 'nu este gol', 'null': 'este null', 'not null': 'nu este null', + 'regex match': 'potrivire regex', }, enterValue: 'Introduceți valoarea', addCondition: 'Adăugați condiție', conditionNotSetup: 'Condiția NU este setată', + selectVariable: 'Selectați variabila...', }, variableAssigner: { title: 'Atribuie variabile', @@ -502,6 +529,25 @@ const translation = { iteration_other: '{{count}} Iterații', currentIteration: 'Iterație curentă', }, + note: { + editor: { + small: 'Mic', + bold: 'Îndrăzneț', + unlink: 'Deconecta', + strikethrough: 'Tăiere', + invalidUrl: 'URL nevalidă', + medium: 'Medie', + openLink: 'Deschide', + large: 'Mare', + enterUrl: 'Introduceți adresa URL...', + italic: 'Cursiv', + placeholder: 'Scrie-ți notița...', + link: 'Legătură', + bulletList: 'Lista de marcatori', + showAuthor: 'Afișați autorul', + }, + addNote: 'Adăugați o notă', + }, }, tracing: { stopBy: 'Oprit de {{user}}', diff --git a/web/i18n/ru-RU/app-annotation.ts b/web/i18n/ru-RU/app-annotation.ts new file mode 100644 index 0000000000..18f2ae4a11 --- /dev/null +++ b/web/i18n/ru-RU/app-annotation.ts @@ -0,0 +1,87 @@ +const translation = { + title: 'Аннотации', + name: 'Ответить на аннотацию', + editBy: 'Ответ отредактирован {{author}}', + noData: { + title: 'Нет аннотаций', + description: 'Вы можете редактировать аннотации во время отладки приложения или импортировать их массово здесь для получения качественного ответа.', + }, + table: { + header: { + question: 'вопрос', + answer: 'ответ', + createdAt: 'создано', + hits: 'попаданий', + actions: 'действия', + addAnnotation: 'Добавить аннотацию', + bulkImport: 'Массовый импорт', + bulkExport: 'Массовый экспорт', + clearAll: 'Очистить все аннотации', + }, + }, + editModal: { + title: 'Редактировать ответ аннотации', + queryName: 'Запрос пользователя', + answerName: 'Storyteller Bot', + yourAnswer: 'Ваш ответ', + answerPlaceholder: 'Введите ваш ответ здесь', + yourQuery: 'Ваш запрос', + queryPlaceholder: 'Введите ваш запрос здесь', + removeThisCache: 'Удалить эту аннотацию', + createdAt: 'Создано', + }, + addModal: { + title: 'Добавить ответ аннотации', + queryName: 'Вопрос', + answerName: 'Ответ', + answerPlaceholder: 'Введите ответ здесь', + queryPlaceholder: 'Введите вопрос здесь', + createNext: 'Добавить еще один аннотированный ответ', + }, + batchModal: { + title: 'Массовый импорт', + csvUploadTitle: 'Перетащите сюда ваш CSV-файл или ', + browse: 'выберите файл', + tip: 'CSV-файл должен соответствовать следующей структуре:', + question: 'вопрос', + answer: 'ответ', + contentTitle: 'содержимое фрагмента', + content: 'содержимое', + template: 'Скачать шаблон здесь', + cancel: 'Отмена', + run: 'Запустить пакет', + runError: 'Ошибка запуска пакета', + processing: 'В процессе пакетной обработки', 
+ completed: 'Импорт завершен', + error: 'Ошибка импорта', + ok: 'ОК', + }, + errorMessage: { + answerRequired: 'Ответ обязателен', + queryRequired: 'Вопрос обязателен', + }, + viewModal: { + annotatedResponse: 'Ответ аннотации', + hitHistory: 'История попаданий', + hit: 'Попадание', + hits: 'Попадания', + noHitHistory: 'Нет истории попаданий', + }, + hitHistoryTable: { + query: 'Запрос', + match: 'Совпадение', + response: 'Ответ', + source: 'Источник', + score: 'Оценка', + time: 'Время', + }, + initSetup: { + title: 'Начальная настройка ответа аннотации', + configTitle: 'Настройка ответа аннотации', + confirmBtn: 'Сохранить и включить', + configConfirmBtn: 'Сохранить', + }, + embeddingModelSwitchTip: 'Модель векторизации текста аннотаций, переключение между моделями будет осуществлено повторно, что приведет к дополнительным затратам.', +} + +export default translation diff --git a/web/i18n/ru-RU/app-api.ts b/web/i18n/ru-RU/app-api.ts new file mode 100644 index 0000000000..0c56d7a501 --- /dev/null +++ b/web/i18n/ru-RU/app-api.ts @@ -0,0 +1,83 @@ +const translation = { + apiServer: 'API Сервер', + apiKey: 'API Ключ', + status: 'Статус', + disabled: 'Отключено', + ok: 'В работе', + copy: 'Копировать', + copied: 'Скопировано', + play: 'Запустить', + pause: 'Приостановить', + playing: 'Запущено', + loading: 'Загрузка', + merMaid: { + rerender: 'Перезапустить рендеринг', + }, + never: 'Никогда', + apiKeyModal: { + apiSecretKey: 'Секретный ключ API', + apiSecretKeyTips: 'Чтобы предотвратить злоупотребление API, защитите свой API ключ. Избегайте использования его в виде plain-текста во фронтенд-коде. :)', + createNewSecretKey: 'Создать новый секретный ключ', + secretKey: 'Секретный ключ', + created: 'СОЗДАН', + lastUsed: 'ПОСЛЕДНЕЕ ИСПОЛЬЗОВАНИЕ', + generateTips: 'Храните этот ключ в безопасном и доступном месте.', + }, + actionMsg: { + deleteConfirmTitle: 'Удалить этот секретный ключ?', + deleteConfirmTips: 'Это действие необратимо.', + ok: 'ОК', + }, + completionMode: { + title: 'API приложения', + info: 'Для высококачественной генерации текста, такой как статьи, резюме и переводы, используйте API completion-messages с пользовательским вводом. Генерация текста основана на параметрах модели и шаблонах подсказок, установленных в Dify Prompt Engineering.', + createCompletionApi: 'Создать completion-message', + createCompletionApiTip: 'Создайте completion-message для поддержки режима вопросов и ответов.', + inputsTips: '(Необязательно) Укажите поля пользовательского ввода в виде пар ключ-значение, соответствующих переменным в Prompt Eng. Ключ - это имя переменной, Значение - это значение параметра. Если тип поля - Выбор, отправленное Значение должно быть одним из предустановленных вариантов.', + queryTips: 'Текстовое содержимое пользовательского ввода.', + blocking: 'Блокирующий тип, ожидает завершения выполнения и возвращает результаты. (Запросы могут быть прерваны, если процесс длительный)', + streaming: ' Ответ в рамках потока. Реализация потоковой передачи ответов на основе SSE (Server-Sent Events).', + messageFeedbackApi: 'Обратная связь по сообщению (лайк)', + messageFeedbackApiTip: 'Оцените полученные сообщения от имени конечных пользователей с помощью лайков или дизлайков. 
Эти данные видны на странице Журналы и аннотации и используются для будущей тонкой настройки модели.', + messageIDTip: 'Идентификатор сообщения', + ratingTip: 'лайк или дизлайк, null - отмена', + parametersApi: 'Получить информацию о параметрах приложения', + parametersApiTip: 'Получить настроенные входные параметры, включая имена переменных, имена полей, типы и значения по умолчанию. Обычно используется для отображения этих полей в форме или заполнения значений по умолчанию после загрузки клиента.', + }, + chatMode: { + title: 'API приложения чата', + info: 'Для универсальных диалоговых приложений, использующих формат вопросов и ответов, вызовите API chat-messages, чтобы начать диалог. Поддерживайте текущие разговоры, передавая возвращенный conversation_id. Параметры ответа и шаблоны зависят от настроек Dify Prompt Eng.', + createChatApi: 'Создать сообщение чата', + createChatApiTip: 'Создайте новое сообщение разговора или продолжите существующий диалог.', + inputsTips: '(Необязательно) Укажите поля пользовательского ввода в виде пар ключ-значение, соответствующих переменным в Prompt Eng. Ключ - это имя переменной, Значение - это значение параметра. Если тип поля - Выбор, отправленное Значение должно быть одним из предустановленных вариантов.', + queryTips: 'Содержимое пользовательского ввода/вопроса', + blocking: 'Блокирующий тип, ожидает завершения выполнения и возвращает результаты. (Запросы могут быть прерваны, если процесс длительный)', + streaming: 'потоковая передача возвращает. Реализация потоковой передачи возврата на основе SSE (Server-Sent Events).', + conversationIdTip: '(Необязательно) Идентификатор разговора: оставьте пустым для первого разговора; передайте conversation_id из контекста, чтобы продолжить диалог.', + messageFeedbackApi: 'Обратная связь конечного пользователя по сообщению, лайк', + messageFeedbackApiTip: 'Оцените полученные сообщения от имени конечных пользователей с помощью лайков или дизлайков. Эти данные видны на странице Журналы и аннотации и используются для будущей тонкой настройки модели.', + messageIDTip: 'Идентификатор сообщения', + ratingTip: 'лайк или дизлайк, null - отмена', + chatMsgHistoryApi: 'Получить историю сообщений чата', + chatMsgHistoryApiTip: 'Первая страница возвращает последние `limit` строк, которые находятся в обратном порядке.', + chatMsgHistoryConversationIdTip: 'Идентификатор разговора', + chatMsgHistoryFirstId: 'Идентификатор первой записи чата на текущей странице. По умолчанию - нет.', + chatMsgHistoryLimit: 'Сколько чатов возвращается за один запрос', + conversationsListApi: 'Получить список разговоров', + conversationsListApiTip: 'Получает список сеансов текущего пользователя. По умолчанию возвращаются последние 20 сеансов.', + conversationsListFirstIdTip: 'Идентификатор последней записи на текущей странице, по умолчанию - нет.', + conversationsListLimitTip: 'Сколько чатов возвращается за один запрос', + conversationRenamingApi: 'Переименование разговора', + conversationRenamingApiTip: 'Переименовать разговоры; имя отображается в многосессионных клиентских интерфейсах.', + conversationRenamingNameTip: 'Новое имя', + parametersApi: 'Получить информацию о параметрах приложения', + parametersApiTip: 'Получить настроенные входные параметры, включая имена переменных, имена полей, типы и значения по умолчанию. 
Обычно используется для отображения этих полей в форме или заполнения значений по умолчанию после загрузки клиента.',
+  },
+  develop: {
+    requestBody: 'Тело запроса',
+    pathParams: 'Параметры пути',
+    query: 'Запрос',
+  },
+}
+
+export default translation
diff --git a/web/i18n/ru-RU/app-debug.ts b/web/i18n/ru-RU/app-debug.ts
new file mode 100644
index 0000000000..038165301e
--- /dev/null
+++ b/web/i18n/ru-RU/app-debug.ts
@@ -0,0 +1,463 @@
+const translation = {
+  pageTitle: {
+    line1: 'PROMPT',
+    line2: 'Engineering',
+  },
+  orchestrate: 'Оркестрация',
+  promptMode: {
+    simple: 'Переключиться в экспертный режим для редактирования всего ПРОМПТА',
+    advanced: 'Экспертный режим',
+    switchBack: 'Переключиться обратно',
+    advancedWarning: {
+      title: 'Вы переключились в экспертный режим, и после изменения ПРОМПТА вы НЕ СМОЖЕТЕ вернуться в базовый режим.',
+      description: 'В экспертном режиме вы можете редактировать весь ПРОМПТ.',
+      learnMore: 'Узнать больше',
+      ok: 'ОК',
+    },
+    operation: {
+      addMessage: 'Добавить сообщение',
+    },
+    contextMissing: 'Отсутствует компонент контекста, эффективность промпта может быть невысокой.',
+  },
+  operation: {
+    applyConfig: 'Опубликовать',
+    resetConfig: 'Сбросить',
+    debugConfig: 'Отладка',
+    addFeature: 'Добавить функцию',
+    automatic: 'Сгенерировать',
+    stopResponding: 'Остановить ответ',
+    agree: 'лайк',
+    disagree: 'дизлайк',
+    cancelAgree: 'Отменить лайк',
+    cancelDisagree: 'Отменить дизлайк',
+    userAction: 'Пользователь ',
+  },
+  notSetAPIKey: {
+    title: 'Ключ поставщика LLM не установлен',
+    trailFinished: 'Пробный период закончен',
+    description: 'Ключ поставщика LLM не установлен, его необходимо установить перед отладкой.',
+    settingBtn: 'Перейти к настройкам',
+  },
+  trailUseGPT4Info: {
+    title: 'В настоящее время не поддерживается gpt-4',
+    description: 'Чтобы использовать gpt-4, пожалуйста, установите API ключ.',
+  },
+  feature: {
+    groupChat: {
+      title: 'Улучшение чата',
+      description: 'Добавление настроек предварительного разговора для приложений может улучшить пользовательский опыт.',
+    },
+    groupExperience: {
+      title: 'Улучшение опыта',
+    },
+    conversationOpener: {
+      title: 'Начальное сообщение',
+      description: 'В чат-приложении первое предложение, которое ИИ активно говорит пользователю, обычно используется в качестве приветствия.',
+    },
+    suggestedQuestionsAfterAnswer: {
+      title: 'Последующие вопросы',
+      description: 'Настройка предложения следующих вопросов может улучшить чат для пользователей.',
+      resDes: '3 предложения для следующего вопроса пользователя.',
+      tryToAsk: 'Попробуйте спросить',
+    },
+    moreLikeThis: {
+      title: 'Больше похожего',
+      description: 'Сгенерируйте несколько текстов одновременно, а затем отредактируйте и продолжайте генерировать',
+      generateNumTip: 'Количество генерируемых каждый раз',
+      tip: 'Использование этой функции приведет к дополнительным расходам токенов',
+    },
+    speechToText: {
+      title: 'Преобразование речи в текст',
+      description: 'После включения вы можете использовать голосовой ввод.',
+      resDes: 'Голосовой ввод включен',
+    },
+    textToSpeech: {
+      title: 'Преобразование текста в речь',
+      description: 'После включения текст можно преобразовать в речь.',
+      resDes: 'Преобразование текста в аудио включено',
+    },
+    citation: {
+      title: 'Цитаты и ссылки',
+      description: 'После включения отображается исходный документ и атрибутированная часть сгенерированного контента.',
+      resDes: 'Цитаты и ссылки включены',
+    },
+    annotation: {
+      title: 'Ответ аннотации',
+      description: 'Вы можете вручную добавить высококачественный ответ в кэш для приоритетного сопоставления с похожими вопросами пользователей.',
+      resDes: 'Ответ аннотации включен',
+      scoreThreshold: {
+        title: 'Порог оценки',
+        description: 'Используется для установки порога сходства для ответа аннотации.',
+        easyMatch: 'Простое совпадение',
+        accurateMatch: 'Точное совпадение',
+      },
+      matchVariable: {
+        title: 'Переменная соответствия',
+        choosePlaceholder: 'Выберите переменную соответствия',
+      },
+      cacheManagement: 'Аннотации',
+      cached: 'Аннотировано',
+      remove: 'Удалить',
+      removeConfirm: 'Удалить эту аннотацию?',
+      add: 'Добавить аннотацию',
+      edit: 'Редактировать аннотацию',
+    },
+    dataSet: {
+      title: 'Контекст',
+      noData: 'Вы можете импортировать знания в качестве контекста',
+      words: 'Слова',
+      textBlocks: 'Текстовые блоки',
+      selectTitle: 'Выберите справочные знания',
+      selected: 'Знания выбраны',
+      noDataSet: 'Знания не найдены',
+      toCreate: 'Перейти к созданию',
+      notSupportSelectMulti: 'В настоящее время поддерживаются только одни знания',
+      queryVariable: {
+        title: 'Переменная запроса',
+        tip: 'Эта переменная будет использоваться в качестве входных данных запроса для поиска контекста, получая информацию о контексте, связанную с вводом этой переменной.',
+        choosePlaceholder: 'Выберите переменную запроса',
+        noVar: 'Нет переменных',
+        noVarTip: 'пожалуйста, создайте переменную в разделе Переменные',
+        unableToQueryDataSet: 'Невозможно запросить знания',
+        unableToQueryDataSetTip: 'Не удалось успешно запросить знания, пожалуйста, выберите переменную запроса контекста в разделе контекста.',
+        ok: 'ОК',
+        contextVarNotEmpty: 'переменная запроса контекста не может быть пустой',
+        deleteContextVarTitle: 'Удалить переменную "{{varName}}"?',
+        deleteContextVarTip: 'Эта переменная была установлена в качестве переменной запроса контекста, и ее удаление повлияет на нормальное использование знаний.
Если вам все еще нужно удалить ее, пожалуйста, выберите ее заново в разделе контекста.', + }, + }, + tools: { + title: 'Инструменты', + tips: 'Инструменты предоставляют стандартный метод вызова API, принимая пользовательский ввод или переменные в качестве параметров запроса для запроса внешних данных в качестве контекста.', + toolsInUse: '{{count}} инструментов используется', + modal: { + title: 'Инструмент', + toolType: { + title: 'Тип инструмента', + placeholder: 'Пожалуйста, выберите тип инструмента', + }, + name: { + title: 'Имя', + placeholder: 'Пожалуйста, введите имя', + }, + variableName: { + title: 'Имя переменной', + placeholder: 'Пожалуйста, введите имя переменной', + }, + }, + }, + conversationHistory: { + title: 'История разговоров', + description: 'Установить префиксы имен для ролей разговора', + tip: 'История разговоров не включена, пожалуйста, добавьте в промпт выше.', + learnMore: 'Узнать больше', + editModal: { + title: 'Редактировать имена ролей разговора', + userPrefix: 'Префикс пользователя', + assistantPrefix: 'Префикс помощника', + }, + }, + toolbox: { + title: 'НАБОР ИНСТРУМЕНТОВ', + }, + moderation: { + title: 'Модерация контента', + description: 'Обеспечьте безопасность выходных данных модели, используя API модерации или поддерживая список чувствительных слов.', + allEnabled: 'ВХОДНОЙ/ВЫХОДНОЙ контент включен', + inputEnabled: 'ВХОДНОЙ контент включен', + outputEnabled: 'ВЫХОДНОЙ контент включен', + modal: { + title: 'Настройки модерации контента', + provider: { + title: 'Поставщик', + openai: 'Модерация OpenAI', + openaiTip: { + prefix: 'Для модерации OpenAI требуется ключ API OpenAI, настроенный в ', + suffix: '.', + }, + keywords: 'Ключевые слова', + }, + keywords: { + tip: 'По одному на строку, разделенные разрывами строк. До 100 символов на строку.', + placeholder: 'По одному на строку, разделенные разрывами строк', + line: 'Строка', + }, + content: { + input: 'Модерировать ВХОДНОЙ контент', + output: 'Модерировать ВЫХОДНОЙ контент', + preset: 'Предустановленные ответы', + placeholder: 'Здесь содержимое предустановленных ответов', + condition: 'Модерация ВХОДНОГО и ВЫХОДНОГО контента включена хотя бы одна', + fromApi: 'Предустановленные ответы возвращаются API', + errorMessage: 'Предустановленные ответы не могут быть пустыми', + supportMarkdown: 'Markdown поддерживается', + }, + openaiNotConfig: { + before: 'Для модерации OpenAI требуется ключ API OpenAI, настроенный в', + after: '', + }, + }, + }, + }, + generate: { + title: 'Генератор промпта', + description: 'Генератор промпта использует настроенную модель для оптимизации промпта для повышения качества и улучшения структуры. 
Пожалуйста, напишите четкие и подробные инструкции.', + tryIt: 'Попробуйте', + instruction: 'Инструкции', + instructionPlaceHolder: 'Напишите четкие и конкретные инструкции.', + generate: 'Сгенерировать', + resTitle: 'Сгенерированный промпт', + noDataLine1: 'Опишите свой случай использования слева,', + noDataLine2: 'предварительный просмотр оркестрации будет показан здесь.', + apply: 'Применить', + loading: 'Оркестрация приложения для вас...', + overwriteTitle: 'Перезаписать существующую конфигурацию?', + overwriteMessage: 'Применение этого промпта перезапишет существующую конфигурацию.', + template: { + pythonDebugger: { + name: 'Отладчик Python', + instruction: 'Бот, который может генерировать и отлаживать ваш код на основе ваших инструкций', + }, + translation: { + name: 'Переводчик', + instruction: 'Переводчик, который может переводить на несколько языков', + }, + professionalAnalyst: { + name: 'Профессиональный аналитик', + instruction: 'Извлекайте информацию, выявляйте риски и извлекайте ключевую информацию из длинных отчетов в одну записку', + }, + excelFormulaExpert: { + name: 'Эксперт по формулам Excel', + instruction: 'Чат-бот, который может помочь начинающим пользователям понять, использовать и создавать формулы Excel на основе инструкций пользователя', + }, + travelPlanning: { + name: 'Планировщик путешествий', + instruction: 'Помощник по планированию путешествий - это интеллектуальный инструмент, разработанный, чтобы помочь пользователям без труда планировать свои поездки', + }, + SQLSorcerer: { + name: 'SQL-ассистент', + instruction: 'Преобразуйте повседневный язык в SQL-запросы', + }, + GitGud: { + name: 'Git gud', + instruction: 'Генерируйте соответствующие команды Git на основе описанных пользователем действий по управлению версиями', + }, + meetingTakeaways: { + name: 'Итоги совещания', + instruction: 'Извлекайте из совещаний краткие резюме, включая темы обсуждения, ключевые выводы и элементы действий', + }, + writingsPolisher: { + name: 'Редактор', + instruction: 'Используйте LLM, чтобы улучшить свои письменные работы', + }, + }, + }, + resetConfig: { + title: 'Подтвердить сброс?', + message: + 'Сброс отменяет изменения, восстанавливая последнюю опубликованную конфигурацию.', + }, + errorMessage: { + nameOfKeyRequired: 'имя ключа: {{key}} обязательно', + valueOfVarRequired: 'значение {{key}} не может быть пустым', + queryRequired: 'Требуется текст запроса.', + waitForResponse: + 'Пожалуйста, дождитесь завершения ответа на предыдущее сообщение.', + waitForBatchResponse: + 'Пожалуйста, дождитесь завершения ответа на пакетное задание.', + notSelectModel: 'Пожалуйста, выберите модель', + waitForImgUpload: 'Пожалуйста, дождитесь загрузки изображения', + }, + chatSubTitle: 'Инструкции', + completionSubTitle: 'Префикс Промпта', + promptTip: + 'Промпт направляют ответы ИИ с помощью инструкций и ограничений. Вставьте переменные, такие как {{input}}. Этот Промпт не будет видна пользователям.', + formattingChangedTitle: 'Форматирование изменено', + formattingChangedText: + 'Изменение форматирования приведет к сбросу области отладки, вы уверены?', + variableTitle: 'Переменные', + variableTip: + 'Пользователи заполняют переменные в форме, автоматически заменяя переменные в промпте.', + notSetVar: 'Переменные позволяют пользователям вводить промпты или вступительные замечания при заполнении форм. 
Вы можете попробовать ввести "{{input}}" в промптах.', + autoAddVar: 'В предварительной промпте упоминаются неопределенные переменные, хотите ли вы добавить их в форму пользовательского ввода?', + variableTable: { + key: 'Ключ переменной', + name: 'Имя поля пользовательского ввода', + optional: 'Необязательно', + type: 'Тип ввода', + action: 'Действия', + typeString: 'Строка', + typeSelect: 'Выбор', + }, + varKeyError: { + canNoBeEmpty: '{{key}} обязательно', + tooLong: '{{key}} слишком длинное. Не может быть длиннее 30 символов', + notValid: '{{key}} недействительно. Может содержать только буквы, цифры и подчеркивания', + notStartWithNumber: '{{key}} не может начинаться с цифры', + keyAlreadyExists: '{{key}} уже существует', + }, + otherError: { + promptNoBeEmpty: 'Промпт не может быть пустой', + historyNoBeEmpty: 'История разговоров должна быть установлена в промпте', + queryNoBeEmpty: 'Запрос должен быть установлен в промпте', + }, + variableConfig: { + 'addModalTitle': 'Добавить поле ввода', + 'editModalTitle': 'Редактировать поле ввода', + 'description': 'Настройка для переменной {{varName}}', + 'fieldType': 'Тип поля', + 'string': 'Короткий текст', + 'text-input': 'Короткий текст', + 'paragraph': 'Абзац', + 'select': 'Выбор', + 'number': 'Число', + 'notSet': 'Не задано, попробуйте ввести {{input}} в префикс промпта', + 'stringTitle': 'Параметры текстового поля формы', + 'maxLength': 'Максимальная длина', + 'options': 'Варианты', + 'addOption': 'Добавить вариант', + 'apiBasedVar': 'Переменная на основе API', + 'varName': 'Имя переменной', + 'labelName': 'Имя метки', + 'inputPlaceholder': 'Пожалуйста, введите', + 'content': 'Содержимое', + 'required': 'Обязательно', + 'errorMsg': { + labelNameRequired: 'Имя метки обязательно', + varNameCanBeRepeat: 'Имя переменной не может повторяться', + atLeastOneOption: 'Требуется хотя бы один вариант', + optionRepeat: 'Есть повторяющиеся варианты', + }, + }, + vision: { + name: 'Зрение', + description: 'Включение зрения позволит модели принимать изображения и отвечать на вопросы о них.', + settings: 'Настройки', + visionSettings: { + title: 'Настройки зрения', + resolution: 'Разрешение', + resolutionTooltip: `Низкое разрешение позволит модели получать версию изображения с низким разрешением 512 x 512 и представлять изображение с бюджетом 65 токенов. Это позволяет API возвращать ответы быстрее и потреблять меньше входных токенов для случаев использования, не требующих высокой детализации. + \n + Высокое разрешение сначала позволит модели увидеть изображение с низким разрешением, а затем создаст детальные фрагменты входных изображений в виде квадратов 512 пикселей на основе размера входного изображения. 
Каждый из детальных фрагментов использует вдвое больший бюджет токенов, в общей сложности 129 токенов.`, + high: 'Высокое', + low: 'Низкое', + uploadMethod: 'Метод загрузки', + both: 'Оба', + localUpload: 'Локальная загрузка', + url: 'URL', + uploadLimit: 'Лимит загрузки', + }, + }, + voice: { + name: 'Голос', + defaultDisplay: 'Голос по умолчанию', + description: 'Настройки преобразования текста в речь', + settings: 'Настройки', + voiceSettings: { + title: 'Настройки голоса', + language: 'Язык', + resolutionTooltip: 'Язык, поддерживаемый преобразованием текста в речь.', + voice: 'Голос', + autoPlay: 'Автовоспроизведение', + autoPlayEnabled: 'Включить', + autoPlayDisabled: 'Выключить', + }, + }, + openingStatement: { + title: 'Начальное сообщение', + add: 'Добавить', + writeOpener: 'Написать начальное сообщение', + placeholder: 'Напишите здесь свое начальное сообщение, вы можете использовать переменные, попробуйте ввести {{variable}}.', + openingQuestion: 'Начальные вопросы', + noDataPlaceHolder: + 'Начало разговора с пользователем может помочь ИИ установить более тесную связь с ним в диалоговых приложениях.', + varTip: 'Вы можете использовать переменные, попробуйте ввести {{variable}}', + tooShort: 'Для генерации вступительного замечания к разговору требуется не менее 20 слов начального промпта.', + notIncludeKey: 'Начальный промпт не включает переменную: {{key}}. Пожалуйста, добавьте её в начальную промпт.', + }, + modelConfig: { + model: 'Модель', + setTone: 'Установить тон ответов', + title: 'Модель и параметры', + modeType: { + chat: 'Чат', + completion: 'Завершение', + }, + }, + inputs: { + title: 'Отладка и предварительный просмотр', + noPrompt: 'Попробуйте написать промпт во входных данных предварительного промпта', + userInputField: 'Поле пользовательского ввода', + noVar: 'Заполните значение переменной, которое будет автоматически заменяться в промпте каждый раз при запуске нового сеанса.', + chatVarTip: + 'Заполните значение переменной, которое будет автоматически заменяться в промпте каждый раз при запуске нового сеанса', + completionVarTip: + 'Заполните значение переменной, которое будет автоматически заменяться в промпте каждый раз при отправке вопроса.', + previewTitle: 'Предварительный просмотр промпта', + queryTitle: 'Содержимое запроса', + queryPlaceholder: 'Пожалуйста, введите текст запроса.', + run: 'ЗАПУСТИТЬ', + }, + result: 'Выходной текст', + datasetConfig: { + settingTitle: 'Настройки поиска', + knowledgeTip: 'Нажмите кнопку "+", чтобы добавить знания', + retrieveOneWay: { + title: 'Поиск N-к-1', + description: 'На основе намерения пользователя и описаний знаний агент автономно выбирает наилучшие знания для запроса. Лучше всего подходит для приложений с различными, ограниченными знаниями.', + }, + retrieveMultiWay: { + title: 'Многопутный поиск', + description: 'На основе намерения пользователя выполняет запросы по всем знаниям, извлекает соответствующий текст из нескольких источников и выбирает наилучшие результаты, соответствующие запросу пользователя, после повторного ранжирования.', + }, + rerankModelRequired: 'Требуется rerank-модель ', + params: 'Параметры', + top_k: 'Top K', + top_kTip: 'Используется для фильтрации фрагментов, наиболее похожих на вопросы пользователей. 
Система также будет динамически корректировать значение Top K в зависимости от max_tokens выбранной модели.', + score_threshold: 'Порог оценки', + score_thresholdTip: 'Используется для установки порога сходства для фильтрации фрагментов.', + retrieveChangeTip: 'Изменение режима индексации и режима поиска может повлиять на приложения, связанные с этими знаниями.', + }, + debugAsSingleModel: 'Отладка как одной модели', + debugAsMultipleModel: 'Отладка как нескольких моделей', + duplicateModel: 'Дублировать', + publishAs: 'Опубликовать как', + assistantType: { + name: 'Тип помощника', + chatAssistant: { + name: 'Базовый помощник', + description: 'Создайте помощника на основе чата, используя большую языковую модель', + }, + agentAssistant: { + name: 'Агент-помощник', + description: 'Создайте интеллектуального агента, который может автономно выбирать инструменты для выполнения задач', + }, + }, + agent: { + agentMode: 'Режим агента', + agentModeDes: 'Установите тип режима вывода для агента', + agentModeType: { + ReACT: 'ReAct', + functionCall: 'Вызов функции', + }, + setting: { + name: 'Настройки агента', + description: 'Настройки агента-помощника позволяют установить режим агента и расширенные функции, такие как встроенные промпты, доступные только в типе агента.', + maximumIterations: { + name: 'Максимальное количество итераций', + description: 'Ограничьте количество итераций, которые может выполнить агент-помощник', + }, + }, + buildInPrompt: 'Встроенный промпт', + firstPrompt: 'Первый промпт', + nextIteration: 'Следующая итерация', + promptPlaceholder: 'Напишите здесь свой первый промпт', + tools: { + name: 'Инструменты', + description: 'Использование инструментов может расширить возможности LLM, такие как поиск в Интернете или выполнение научных расчетов', + enabled: 'Включено', + }, + }, +} + +export default translation diff --git a/web/i18n/ru-RU/app-log.ts b/web/i18n/ru-RU/app-log.ts new file mode 100644 index 0000000000..c6c54ef178 --- /dev/null +++ b/web/i18n/ru-RU/app-log.ts @@ -0,0 +1,95 @@ +const translation = { + title: 'Логирование', + description: 'В логах записывается состояние работы приложения, включая пользовательский ввод и ответы ИИ.', + dateTimeFormat: 'DD.MM.YYYY HH:mm', + table: { + header: { + updatedTime: 'Время обновления', + time: 'Время создания', + endUser: 'Конечный пользователь или аккаунт', + input: 'Ввод', + output: 'Вывод', + summary: 'Заголовок', + messageCount: 'Количество сообщений', + userRate: 'Оценка пользователя', + adminRate: 'Оценка оп.', + startTime: 'ВРЕМЯ НАЧАЛА', + status: 'СТАТУС', + runtime: 'ВРЕМЯ ВЫПОЛНЕНИЯ', + tokens: 'ТОКЕНЫ', + user: 'Конечный пользователь или аккаунт', + version: 'ВЕРСИЯ', + }, + pagination: { + previous: 'Предыдущий', + next: 'Следующий', + }, + empty: { + noChat: 'Еще нет чатов', + noOutput: 'Нет вывода', + element: { + title: 'Есть кто-нибудь?', + content: 'Наблюдайте и аннотируйте взаимодействия между конечными пользователями и приложениями ИИ здесь, чтобы постоянно повышать точность ИИ. 
Вы можете попробовать поделиться или протестировать веб-приложение самостоятельно, а затем вернуться на эту страницу.', + }, + }, + }, + detail: { + time: 'Время', + conversationId: 'Идентификатор разговора', + promptTemplate: 'Шаблон подсказки', + promptTemplateBeforeChat: 'Шаблон подсказки перед чатом · Как системное сообщение', + annotationTip: 'Улучшения, отмеченные {{user}}', + timeConsuming: '', + second: 'с', + tokenCost: 'Потрачено токенов', + loading: 'загрузка', + operation: { + like: 'лайк', + dislike: 'дизлайк', + addAnnotation: 'Добавить улучшение', + editAnnotation: 'Редактировать улучшение', + annotationPlaceholder: 'Введите ожидаемый ответ, который вы хотите получить от ИИ, который может быть использован для тонкой настройки модели и постоянного улучшения качества генерации текста в будущем.', + }, + variables: 'Переменные', + uploadImages: 'Загруженные изображения', + }, + filter: { + period: { + today: 'Сегодня', + last7days: 'Последние 7 дней', + last4weeks: 'Последние 4 недели', + last3months: 'Последние 3 месяца', + last12months: 'Последние 12 месяцев', + monthToDate: 'С начала месяца', + quarterToDate: 'С начала квартала', + yearToDate: 'С начала года', + allTime: 'Все время', + }, + annotation: { + all: 'Все', + annotated: 'Аннотированные улучшения ({{count}} элементов)', + not_annotated: 'Не аннотировано', + }, + sortBy: 'Сортировать по:', + descending: 'по убыванию', + ascending: 'по возрастанию', + }, + workflowTitle: 'Журналы рабочих процессов', + workflowSubtitle: 'Журнал записал работу Automate.', + runDetail: { + title: 'Журнал разговоров', + workflowTitle: 'Подробная информация о журнале', + }, + promptLog: 'Журнал подсказок', + agentLog: 'Журнал агента', + viewLog: 'Просмотреть журнал', + agentLogDetail: { + agentMode: 'Режим агента', + toolUsed: 'Использованный инструмент', + iterations: 'Итерации', + iteration: 'Итерация', + finalProcessing: 'Окончательная обработка', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/app-overview.ts b/web/i18n/ru-RU/app-overview.ts new file mode 100644 index 0000000000..2969ddd6e7 --- /dev/null +++ b/web/i18n/ru-RU/app-overview.ts @@ -0,0 +1,168 @@ +const translation = { + welcome: { + firstStepTip: 'Чтобы начать,', + enterKeyTip: 'введите свой ключ API OpenAI ниже', + getKeyTip: 'Получите свой ключ API на панели инструментов OpenAI', + placeholder: 'Ваш ключ API OpenAI (например, sk-xxxx)', + }, + apiKeyInfo: { + cloud: { + trial: { + title: 'Вы используете пробную квоту {{providerName}}.', + description: 'Пробная квота предоставляется для тестирования. Прежде чем пробная квота будет исчерпана, пожалуйста, настройте своего собственного поставщика модели или приобретите дополнительную квоту.', + }, + exhausted: { + title: 'Ваша пробная квота была исчерпана, пожалуйста, настройте свой APIKey.', + description: 'Вы исчерпали свою пробную квоту. 
Пожалуйста, настройте своего собственного поставщика модели или приобретите дополнительную квоту.', + }, + }, + selfHost: { + title: { + row1: 'Чтобы начать,', + row2: 'сначала настройте своего поставщика модели.', + }, + }, + callTimes: 'Количество вызовов', + usedToken: 'Использованные токены', + setAPIBtn: 'Перейти к настройке поставщика модели', + tryCloud: 'Или попробуйте облачную версию Dify с бесплатной квотой', + }, + overview: { + title: 'Обзор', + appInfo: { + explanation: 'Готовое к использованию веб-приложение ИИ', + accessibleAddress: 'Публичный URL', + preview: 'Предварительный просмотр', + regenerate: 'Перегенерировать', + regenerateNotice: 'Вы хотите перегенерировать публичный URL?', + preUseReminder: 'Пожалуйста, включите веб-приложение перед продолжением.', + settings: { + entry: 'Настройки', + title: 'Настройки веб-приложения', + webName: 'Название веб-приложения', + webDesc: 'Описание веб-приложения', + webDescTip: 'Этот текст будет отображаться на стороне клиента, предоставляя базовые инструкции по использованию приложения', + webDescPlaceholder: 'Введите описание веб-приложения', + language: 'Язык', + workflow: { + title: 'Рабочий процесс', + subTitle: 'Подробности рабочего процесса', + show: 'Показать', + hide: 'Скрыть', + showDesc: 'Показать или скрыть подробности рабочего процесса в веб-приложении', + }, + chatColorTheme: 'Цветовая тема чата', + chatColorThemeDesc: 'Установите цветовую тему чат-бота', + chatColorThemeInverted: 'Инвертированные цвета', + invalidHexMessage: 'Неверное HEX-значение', + sso: { + label: 'SSO аутентификация', + title: 'WebApp SSO', + description: 'Все пользователи должны войти в систему с помощью SSO перед использованием WebApp', + tooltip: 'Обратитесь к администратору, чтобы включить WebApp SSO', + }, + more: { + entry: 'Показать больше настроек', + copyright: 'Авторские права', + copyRightPlaceholder: 'Введите имя автора или организации', + privacyPolicy: 'Политика конфиденциальности', + privacyPolicyPlaceholder: 'Введите ссылку на политику конфиденциальности', + privacyPolicyTip: 'Помогает посетителям понять, какие данные собирает приложение, см. 
Политику конфиденциальности Dify.', + customDisclaimer: 'Пользовательский отказ от ответственности', + customDisclaimerPlaceholder: 'Введите текст пользовательского отказа от ответственности', + customDisclaimerTip: 'Текст пользовательского отказа от ответственности будет отображаться на стороне клиента, предоставляя дополнительную информацию о приложении', + }, + }, + embedded: { + entry: 'Встраивание', + title: 'Встроить на веб-сайт', + explanation: 'Выберите способ встраивания чат-приложения на свой веб-сайт', + iframe: 'Чтобы добавить чат-приложение в любое место на вашем веб-сайте, добавьте этот iframe в свой HTML-код.', + scripts: 'Чтобы добавить чат-приложение в правый нижний угол вашего веб-сайта, добавьте этот код в свой HTML.', + chromePlugin: 'Установите расширение Dify Chatbot для Chrome', + copied: 'Скопировано', + copy: 'Копировать', + }, + qrcode: { + title: 'QR-код ссылки', + scan: 'Сканировать, чтобы поделиться', + download: 'Скачать QR-код', + }, + customize: { + way: 'способ', + entry: 'Настроить', + title: 'Настроить веб-приложение ИИ', + explanation: 'Вы можете настроить внешний интерфейс веб-приложения в соответствии со своими потребностями.', + way1: { + name: 'Создайте форк клиентского кода, измените его и разверните на Vercel (рекомендуется)', + step1: 'Создайте форк клиентского кода и измените его', + step1Tip: 'Нажмите здесь, чтобы создать форк исходного кода в своей учетной записи GitHub и изменить код', + step1Operation: 'Dify-WebClient', + step2: 'Развернуть на Vercel', + step2Tip: 'Нажмите здесь, чтобы импортировать репозиторий в Vercel и развернуть', + step2Operation: 'Импортировать репозиторий', + step3: 'Настроить переменные среды', + step3Tip: 'Добавьте следующие переменные среды в Vercel', + }, + way2: { + name: 'Напишите клиентский код для вызова API и разверните его на сервере', + operation: 'Документация', + }, + }, + }, + apiInfo: { + title: 'API серверной части', + explanation: 'Легко интегрируется в ваше приложение', + accessibleAddress: 'Конечная точка API сервиса', + doc: 'Справочник по API', + }, + status: { + running: 'В работе', + disable: 'Отключено', + }, + }, + analysis: { + title: 'Анализ', + ms: 'мс', + tokenPS: 'Токен/с', + totalMessages: { + title: 'Всего сообщений', + explanation: 'Ежедневное количество взаимодействий с ИИ.', + }, + totalConversations: { + title: 'Всего чатов', + explanation: 'Ежедневное количество чатов с LLM; проектирование/отладка не учитываются.', + }, + activeUsers: { + title: 'Активные пользователи', + explanation: 'Уникальные пользователи, участвующие в вопросах и ответах с LLM; проектирование/отладка не учитываются.', + }, + tokenUsage: { + title: 'Использование токенов', + explanation: 'Отражает ежедневное использование токенов языковой модели для приложения, полезно для целей контроля затрат.', + consumed: 'Потрачено', + }, + avgSessionInteractions: { + title: 'Среднее количество взаимодействий за сеанс', + explanation: 'Количество непрерывных взаимодействий пользователя с LLM; для приложений на основе чатов.', + }, + avgUserInteractions: { + title: 'Среднее количество взаимодействий пользователя', + explanation: 'Отражает ежедневную частоту использования пользователями. Эта метрика отражает активность пользователей.', + }, + userSatisfactionRate: { + title: 'Уровень удовлетворенности пользователей', + explanation: 'Количество лайков на 1000 сообщений. 
Это указывает на долю ответов, которыми пользователи довольны.', + }, + avgResponseTime: { + title: 'Среднее время ответа', + explanation: 'Время (мс) для обработки/ответа LLM; для текстовых приложений.', + }, + tps: { + title: 'Скорость вывода токенов', + explanation: 'Измерьте производительность LLM. Подсчитайте скорость вывода токенов LLM от начала запроса до завершения вывода.', + }, + }, +} + +export default translation diff --git a/web/i18n/ru-RU/app.ts b/web/i18n/ru-RU/app.ts new file mode 100644 index 0000000000..f5f45e65f1 --- /dev/null +++ b/web/i18n/ru-RU/app.ts @@ -0,0 +1,138 @@ +const translation = { + createApp: 'СОЗДАТЬ ПРИЛОЖЕНИЕ', + types: { + all: 'Все', + chatbot: 'Чат-бот', + agent: 'Агент', + workflow: 'Рабочий процесс', + completion: 'Завершение', + }, + duplicate: 'Дублировать', + duplicateTitle: 'Дублировать приложение', + export: 'Экспортировать DSL', + exportFailed: 'Ошибка экспорта DSL.', + importDSL: 'Импортировать файл DSL', + createFromConfigFile: 'Создать из файла DSL', + importFromDSL: 'Импортировать из DSL', + importFromDSLFile: 'Из файла DSL', + importFromDSLUrl: 'Из URL', + importFromDSLUrlPlaceholder: 'Вставьте ссылку DSL сюда', + deleteAppConfirmTitle: 'Удалить это приложение?', + deleteAppConfirmContent: + 'Удаление приложения необратимо. Пользователи больше не смогут получить доступ к вашему приложению, и все настройки подсказок и журналы будут безвозвратно удалены.', + appDeleted: 'Приложение удалено', + appDeleteFailed: 'Не удалось удалить приложение', + join: 'Присоединяйтесь к сообществу', + communityIntro: + 'Общайтесь с членами команды, участниками и разработчиками на разных каналах.', + roadmap: 'Посмотреть наш roadmap', + newApp: { + startFromBlank: 'Создать с нуля', + startFromTemplate: 'Создать из шаблона', + captionAppType: 'Какой тип приложения вы хотите создать?', + chatbotDescription: 'Создайте приложение на основе чата. Это приложение использует формат вопросов и ответов, позволяя общаться непрерывно.', + completionDescription: 'Создайте приложение, которое генерирует высококачественный текст на основе подсказок, например, генерирует статьи, резюме, переводы и многое другое.', + completionWarning: 'Этот тип приложения больше не будет поддерживаться.', + agentDescription: 'Создайте интеллектуального агента, который может автономно выбирать инструменты для выполнения задач', + workflowDescription: 'Создайте приложение, которое генерирует высококачественный текст на основе рабочего процесса, организованного с высокой степенью настройки. Подходит для опытных пользователей.', + workflowWarning: 'В настоящее время находится в бета-версии', + chatbotType: 'Метод организации чат-бота', + basic: 'Базовый', + basicTip: 'Для начинающих, можно переключиться на Chatflow позже', + basicFor: 'ДЛЯ НАЧИНАЮЩИХ', + basicDescription: 'Базовый конструктор позволяет создать приложение чат-бота с помощью простых настроек, без возможности изменять встроенные подсказки. Подходит для начинающих.', + advanced: 'Chatflow', + advancedFor: 'Для продвинутых пользователей', + advancedDescription: 'Организация рабочего процесса организует чат-ботов в виде рабочих процессов, предлагая высокую степень настройки, включая возможность редактирования встроенных подсказок. 
Подходит для опытных пользователей.', + captionName: 'Значок и название приложения', + appNamePlaceholder: 'Дайте вашему приложению имя', + captionDescription: 'Описание', + appDescriptionPlaceholder: 'Введите описание приложения', + useTemplate: 'Использовать этот шаблон', + previewDemo: 'Предварительный просмотр', + chatApp: 'Ассистент', + chatAppIntro: + 'Я хочу создать приложение на основе чата. Это приложение использует формат вопросов и ответов, позволяя общаться непрерывно.', + agentAssistant: 'Новый Ассистент Агента', + completeApp: 'Генератор текста', + completeAppIntro: + 'Я хочу создать приложение, которое генерирует высококачественный текст на основе подсказок, например, генерирует статьи, резюме, переводы и многое другое.', + showTemplates: 'Я хочу выбрать из шаблона', + hideTemplates: 'Вернуться к выбору режима', + Create: 'Создать', + Cancel: 'Отмена', + nameNotEmpty: 'Имя не может быть пустым', + appTemplateNotSelected: 'Пожалуйста, выберите шаблон', + appTypeRequired: 'Пожалуйста, выберите тип приложения', + appCreated: 'Приложение создано', + appCreateFailed: 'Не удалось создать приложение', + }, + editApp: 'Редактировать информацию', + editAppTitle: 'Редактировать информацию о приложении', + editDone: 'Информация о приложении обновлена', + editFailed: 'Не удалось обновить информацию о приложении', + iconPicker: { + ok: 'ОК', + cancel: 'Отмена', + emoji: 'Эмодзи', + image: 'Изображение', + }, + switch: 'Переключиться на Workflow', + switchTipStart: 'Для вас будет создана новая копия Workflow. Новая копия ', + switchTip: 'не позволит', + switchTipEnd: ' переключиться обратно на базовую организацию.', + switchLabel: 'Копия приложения, которая будет создана', + removeOriginal: 'Удалить исходное приложение', + switchStart: 'Переключиться', + typeSelector: { + all: 'ВСЕ типы', + chatbot: 'Чат-бот', + agent: 'Агент', + workflow: 'Рабочий процесс', + completion: 'Завершение', + }, + tracing: { + title: 'Отслеживание производительности приложения', + description: 'Настройка стороннего поставщика LLMOps и отслеживание производительности приложения.', + config: 'Настройка', + view: 'Просмотр', + collapse: 'Свернуть', + expand: 'Развернуть', + tracing: 'Отслеживание', + disabled: 'Отключено', + disabledTip: 'Пожалуйста, сначала настройте провайдера LLM', + enabled: 'В работе', + tracingDescription: 'Запись полного контекста выполнения приложения, включая вызовы LLM, контекст, подсказки, HTTP-запросы и многое другое, на стороннюю платформу трассировки.', + configProviderTitle: { + configured: 'Настроено', + notConfigured: 'Настройте провайдера, чтобы включить трассировку', + moreProvider: 'Больше провайдеров', + }, + langsmith: { + title: 'LangSmith', + description: 'Универсальная платформа для разработчиков для каждого этапа жизненного цикла приложения на базе LLM.', + }, + langfuse: { + title: 'Langfuse', + description: 'Трассировка, оценка, управление подсказками и метрики для отладки и улучшения вашего приложения LLM.', + }, + inUse: 'Используется', + configProvider: { + title: 'Настройка ', + placeholder: 'Введите ваш {{key}}', + project: 'Проект', + publicKey: 'Публичный ключ', + secretKey: 'Секретный ключ', + viewDocsLink: 'Посмотреть документацию {{key}}', + removeConfirmTitle: 'Удалить конфигурацию {{key}}?', + removeConfirmContent: 'Текущая конфигурация используется, ее удаление отключит функцию трассировки.', + }, + }, + answerIcon: { + title: 'Использование значка WebApp для замены 🤖', + description: 'Следует ли использовать значок WebApp для замены 🤖 в общем приложении', 
+ descriptionInExplore: 'Следует ли использовать значок WebApp для замены 🤖 в разделе "Обзор"', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/billing.ts b/web/i18n/ru-RU/billing.ts new file mode 100644 index 0000000000..e7760d9ac6 --- /dev/null +++ b/web/i18n/ru-RU/billing.ts @@ -0,0 +1,118 @@ +const translation = { + currentPlan: 'Текущий тарифный план', + upgradeBtn: { + plain: 'Обновить тарифный план', + encourage: 'Обновить сейчас', + encourageShort: 'Обновить', + }, + viewBilling: 'Управление счетами и подписками', + buyPermissionDeniedTip: 'Пожалуйста, свяжитесь с администратором вашей организации, чтобы подписаться', + plansCommon: { + title: 'Выберите тарифный план, который подходит именно вам', + yearlyTip: 'Получите 2 месяца бесплатно, подписавшись на год!', + mostPopular: 'Самый популярный', + planRange: { + monthly: 'Ежемесячно', + yearly: 'Ежегодно', + }, + month: 'месяц', + year: 'год', + save: 'Сэкономить ', + free: 'Бесплатно', + currentPlan: 'Текущий тарифный план', + contractSales: 'Связаться с отделом продаж', + contractOwner: 'Связаться с руководителем команды', + startForFree: 'Начать бесплатно', + getStartedWith: 'Начать с ', + contactSales: 'Связаться с отделом продаж', + talkToSales: 'Поговорить с отделом продаж', + modelProviders: 'Поставщики моделей', + teamMembers: 'Участники команды', + annotationQuota: 'Квота аннотаций', + buildApps: 'Создать приложения', + vectorSpace: 'Векторное пространство', + vectorSpaceBillingTooltip: 'Каждый 1 МБ может хранить около 1,2 миллиона символов векторизованных данных (оценка с использованием Embeddings OpenAI, варьируется в зависимости от модели).', + vectorSpaceTooltip: 'Векторное пространство - это система долговременной памяти, необходимая LLM для понимания ваших данных.', + documentsUploadQuota: 'Квота загрузки документов', + documentProcessingPriority: 'Приоритет обработки документов', + documentProcessingPriorityTip: 'Для более высокого приоритета обработки документов, пожалуйста, обновите свой тарифный план.', + documentProcessingPriorityUpgrade: 'Обрабатывайте больше данных с большей точностью и на более высоких скоростях.', + priority: { + 'standard': 'Стандартный', + 'priority': 'Приоритетный', + 'top-priority': 'Высокий приоритет', + }, + logsHistory: 'История журналов', + customTools: 'Пользовательские инструменты', + unavailable: 'Недоступно', + days: 'дней', + unlimited: 'Неограниченно', + support: 'Поддержка', + supportItems: { + communityForums: 'Форумы сообщества', + emailSupport: 'Поддержка по электронной почте', + priorityEmail: 'Приоритетная поддержка по электронной почте и в чате', + logoChange: 'Изменение логотипа', + SSOAuthentication: 'SSO аутентификация', + personalizedSupport: 'Персональная поддержка', + dedicatedAPISupport: 'Выделенная поддержка API', + customIntegration: 'Пользовательская интеграция и поддержка', + ragAPIRequest: 'Запросы RAG API', + bulkUpload: 'Массовая загрузка документов', + agentMode: 'Режим агента', + workflow: 'Рабочий процесс', + llmLoadingBalancing: 'Балансировка нагрузки LLM', + llmLoadingBalancingTooltip: 'Добавьте несколько ключей API к моделям, эффективно обходя ограничения скорости API.', + }, + comingSoon: 'Скоро', + member: 'Участник', + memberAfter: 'Участник', + messageRequest: { + title: 'Кредиты на сообщения', + tooltip: 'Квоты вызова сообщений для различных тарифных планов, использующих модели OpenAI (кроме gpt4). 
Сообщения, превышающие лимит, будут использовать ваш ключ API OpenAI.',
+    },
+    annotatedResponse: {
+      title: 'Ограничения квоты аннотаций',
+      tooltip: 'Ручное редактирование и аннотирование ответов обеспечивает настраиваемые высококачественные возможности ответов на вопросы для приложений. (Применимо только в чат-приложениях)',
+    },
+    ragAPIRequestTooltip: 'Относится к количеству вызовов API, вызывающих только возможности обработки базы знаний Dify.',
+    receiptInfo: 'Только владелец команды и администратор команды могут подписываться и просматривать информацию о выставлении счетов',
+  },
+  plans: {
+    sandbox: {
+      name: 'Песочница',
+      description: '200 бесплатных пробных использований GPT',
+      includesTitle: 'Включает:',
+    },
+    professional: {
+      name: 'Профессиональный',
+      description: 'Для частных лиц и небольших команд, чтобы разблокировать больше возможностей по доступной цене.',
+      includesTitle: 'Все в бесплатном плане, плюс:',
+    },
+    team: {
+      name: 'Команда',
+      description: 'Сотрудничайте без ограничений и наслаждайтесь высочайшей производительностью.',
+      includesTitle: 'Все в профессиональном плане, плюс:',
+    },
+    enterprise: {
+      name: 'Корпоративный',
+      description: 'Получите полный набор возможностей и поддержку для крупномасштабных критически важных систем.',
+      includesTitle: 'Все в командном плане, плюс:',
+    },
+  },
+  vectorSpace: {
+    fullTip: 'Векторное пространство заполнено.',
+    fullSolution: 'Обновите свой тарифный план, чтобы получить больше места.',
+  },
+  apps: {
+    fullTipLine1: 'Обновите свой тарифный план, чтобы',
+    fullTipLine2: 'создавать больше приложений.',
+  },
+  annotatedResponse: {
+    fullTipLine1: 'Обновите свой тарифный план, чтобы',
+    fullTipLine2: 'аннотировать больше разговоров.',
+    quotaTitle: 'Квота ответов аннотаций',
+  },
+}
+
+export default translation
diff --git a/web/i18n/ru-RU/common.ts b/web/i18n/ru-RU/common.ts
new file mode 100644
index 0000000000..82e3471e60
--- /dev/null
+++ b/web/i18n/ru-RU/common.ts
@@ -0,0 +1,578 @@
+const translation = {
+  api: {
+    success: 'Успешно',
+    actionSuccess: 'Действие выполнено успешно',
+    saved: 'Сохранено',
+    create: 'Создано',
+    remove: 'Удалено',
+  },
+  operation: {
+    create: 'Создать',
+    confirm: 'Подтвердить',
+    cancel: 'Отмена',
+    clear: 'Очистить',
+    save: 'Сохранить',
+    saveAndEnable: 'Сохранить и включить',
+    edit: 'Редактировать',
+    add: 'Добавить',
+    added: 'Добавлено',
+    refresh: 'Перезапустить',
+    reset: 'Сбросить',
+    search: 'Поиск',
+    change: 'Изменить',
+    remove: 'Удалить',
+    send: 'Отправить',
+    copy: 'Копировать',
+    lineBreak: 'Разрыв строки',
+    sure: 'Я уверен',
+    download: 'Скачать',
+    delete: 'Удалить',
+    settings: 'Настройки',
+    setup: 'Настроить',
+    getForFree: 'Получить бесплатно',
+    reload: 'Перезагрузить',
+    ok: 'ОК',
+    log: 'Журнал',
+    learnMore: 'Узнать больше',
+    params: 'Параметры',
+    duplicate: 'Дублировать',
+    rename: 'Переименовать',
+    audioSourceUnavailable: 'AudioSource недоступен',
+  },
+  errorMsg: {
+    fieldRequired: '{{field}} обязательно',
+    urlError: 'URL должен начинаться с http:// или https://',
+  },
+  placeholder: {
+    input: 'Пожалуйста, введите',
+    select: 'Пожалуйста, выберите',
+  },
+  voice: {
+    language: {
+      zhHans: 'Китайский',
+      zhHant: 'Традиционный китайский',
+      enUS: 'Английский',
+      deDE: 'Немецкий',
+      frFR: 'Французский',
+      esES: 'Испанский',
+      itIT: 'Итальянский',
+      thTH: 'Тайский',
+      idID: 'Индонезийский',
+      jaJP: 'Японский',
+      koKR: 'Корейский',
+      ptBR: 'Португальский',
+      ruRU: 'Русский',
+      ukUA: 'Украинский',
+      viVN: 'Вьетнамский',
+      plPL: 'Польский',
+      roRO: 'Румынский',
+      hiIN: 'Хинди',
+      trTR: 'Турецкий',
+      faIR: 'Персидский',
+    },
+  },
+  unit: {
+    char: 'символов',
+  },
+  actionMsg: {
+    noModification: 'На данный момент нет изменений.',
+    modifiedSuccessfully: 'Изменено успешно',
+    modifiedUnsuccessfully: 'Изменено неудачно',
+    copySuccessfully: 'Скопировано успешно',
+    paySucceeded: 'Оплата прошла успешно',
+    payCancelled: 'Оплата отменена',
+    generatedSuccessfully: 'Сгенерировано успешно',
+    generatedUnsuccessfully: 'Сгенерировано неудачно',
+  },
+  model: {
+    params: {
+      temperature: 'Temperature',
+      temperatureTip:
+        'Контролирует случайность: более низкое значение приводит к менее случайным завершениям. По мере приближения температуры к нулю модель станет детерминированной и повторяющейся.',
+      top_p: 'Top P',
+      top_pTip:
+        'Контролирует разнообразие с помощью ядерной выборки: 0,5 означает, что рассматривается половина всех вариантов, взвешенных по вероятности.',
+      presence_penalty: 'Presence penalty',
+      presence_penaltyTip:
+        'Насколько штрафовать новые токены в зависимости от того, появляются ли они в тексте до сих пор.\nУвеличивает вероятность того, что модель будет говорить о новых темах.',
+      frequency_penalty: 'Frequency penalty',
+      frequency_penaltyTip:
+        'Насколько штрафовать новые токены в зависимости от их существующей частоты в тексте до сих пор.\nУменьшает вероятность того, что модель будет повторять одну и ту же строку дословно.',
+      max_tokens: 'Максимальное количество токенов',
+      max_tokensTip:
+        'Используется для ограничения максимальной длины ответа в токенах. \nБольшие значения могут ограничивать пространство, оставленное для подсказок, журналов чата и знаний. \nРекомендуется установить его ниже двух третей\ngpt-4-1106-preview, gpt-4-vision-preview max token (input 128k output 4k)',
+      maxTokenSettingTip: 'Ваша настройка максимального количества токенов высока, что потенциально ограничивает пространство для подсказок, запросов и данных. Подумайте о том, чтобы установить его ниже 2/3.',
+      setToCurrentModelMaxTokenTip: 'Максимальное количество токенов обновлено до 80% максимального количества токенов текущей модели {{maxToken}}.',
+      stop_sequences: 'Стоп-последовательности',
+      stop_sequencesTip: 'До четырех последовательностей, где API прекратит генерировать дальнейшие токены. Возвращаемый текст не будет содержать стоп-последовательность.',
+      stop_sequencesPlaceholder: 'Введите последовательность и нажмите Tab',
+    },
+    tone: {
+      Creative: 'Творческий',
+      Balanced: 'Сбалансированный',
+      Precise: 'Точный',
+      Custom: 'Пользовательский',
+    },
+    addMoreModel: 'Перейдите в настройки, чтобы добавить больше моделей',
+  },
+  menus: {
+    status: 'бета',
+    explore: 'Исследовать',
+    apps: 'Студия',
+    plugins: 'Плагины',
+    pluginsTips: 'Интегрируйте сторонние плагины или создавайте совместимые с ChatGPT AI-плагины.',
+    datasets: 'Знания',
+    datasetsTips: 'СКОРО: Импортируйте свои собственные текстовые данные или записывайте данные в режиме реального времени через Webhook для улучшения контекста LLM.',
+    newApp: 'Новое приложение',
+    newDataset: 'Создать знания',
+    tools: 'Инструменты',
+  },
+  userProfile: {
+    settings: 'Настройки',
+    emailSupport: 'Поддержка по электронной почте',
+    workspace: 'Рабочее пространство',
+    createWorkspace: 'Создать рабочее пространство',
+    helpCenter: 'Помощь',
+    communityFeedback: 'Обратная связь',
+    roadmap: 'План развития',
+    community: 'Сообщество',
+    about: 'О нас',
+    logout: 'Выйти',
+  },
+  settings: {
+    accountGroup: 'АККАУНТ',
+    workplaceGroup: 'РАБОЧЕЕ ПРОСТРАНСТВО',
+    account: 'Моя учетная запись',
+    members: 'Участники',
+    billing: 'Оплата',
+    integrations: 'Интеграции',
+    language: 'Язык',
+    provider: 'Поставщик модели',
+    dataSource: 'Источник данных',
+    plugin: 'Плагины',
+    apiBasedExtension: 'API расширение',
+  },
+  account: {
+    avatar: 'Аватар',
+    name: 'Имя',
+    email: 'Электронная почта',
+    password: 'Пароль',
+    passwordTip: 'Вы можете установить постоянный пароль, если не хотите использовать временные коды входа',
+    setPassword: 'Установить пароль',
+    resetPassword: 'Сбросить пароль',
+    currentPassword: 'Текущий пароль',
+    newPassword: 'Новый пароль',
+    confirmPassword: 'Подтвердите пароль',
+    notEqual: 'Два пароля различаются.',
+    langGeniusAccount: 'Учетная запись Dify',
+    langGeniusAccountTip: 'Ваша учетная запись Dify и связанные с ней пользовательские данные.',
+    editName: 'Редактировать имя',
+    showAppLength: 'Показать {{length}} приложений',
+    delete: 'Удалить учетную запись',
+    deleteTip: 'Удаление вашей учетной записи приведет к безвозвратному удалению всех ваших данных, и их невозможно будет восстановить.',
+    deleteConfirmTip: 'Для подтверждения, пожалуйста, отправьте следующее с вашего зарегистрированного адреса электронной почты на ',
+  },
+  members: {
+    team: 'Команда',
+    invite: 'Добавить',
+    name: 'ИМЯ',
+    lastActive: 'ПОСЛЕДНЯЯ АКТИВНОСТЬ',
+    role: 'РОЛИ',
+    pending: 'Ожидание...',
+    owner: 'Владелец',
+    admin: 'Администратор',
+    adminTip: 'Может создавать приложения и управлять настройками команды',
+    normal: 'Обычный',
+    normalTip: 'Может только использовать приложения, не может создавать приложения',
+    builder: 'Разработчик',
+    builderTip: 'Может создавать и редактировать собственные приложения',
+    editor: 'Редактор',
+    editorTip: 'Может создавать и редактировать приложения',
+    datasetOperator: 'Администратор знаний',
+    datasetOperatorTip: 'Может управлять только базой знаний',
+    inviteTeamMember: 'Добавить участника команды',
+    inviteTeamMemberTip: 'Они могут получить доступ к данным вашей команды сразу после входа в систему.',
+    email: 'Электронная почта',
+    emailInvalid: 'Неверный формат электронной почты',
+    emailPlaceholder: 'Пожалуйста, введите адреса электронной почты',
+    sendInvite: 'Отправить приглашение',
+    invitedAsRole: 'Приглашен как пользователь с ролью {{role}}',
+    invitationSent: 'Приглашение отправлено',
+    invitationSentTip: 'Приглашение отправлено, и они могут войти в Dify, чтобы получить доступ к данным вашей команды.',
+    invitationLink: 'Ссылка для приглашения',
+    failedInvitationEmails: 'Следующие пользователи не были успешно приглашены',
+    ok: 'ОК',
+    removeFromTeam: 'Удалить из команды',
+    removeFromTeamTip: 'Удалить доступ к команде',
+    setAdmin: 'Назначить администратором',
+    setMember: 'Назначить обычным участником',
+    setBuilder: 'Назначить разработчиком',
+    setEditor: 'Назначить редактором',
+    disInvite: 'Отменить приглашение',
+    deleteMember: 'Удалить участника',
+    you: '(Вы)',
+  },
+  integrations: {
+    connected: 'Подключено',
+    google: 'Google',
+    googleAccount: 'Войти с помощью учетной записи Google',
+    github: 'GitHub',
+    githubAccount: 'Войти с помощью учетной записи GitHub',
+    connect: 'Подключить',
+  },
+  language: {
+    displayLanguage: 'Язык отображения',
+    timezone: 'Часовой пояс',
+  },
+  provider: {
+    apiKey: 'Ключ API',
+    enterYourKey: 'Введите свой ключ API здесь',
+    invalidKey: 'Неверный ключ API OpenAI',
+    validatedError: 'Ошибка валидации: ',
+    validating: 'Проверка ключа...',
+    saveFailed: 'Ошибка сохранения ключа API',
+    apiKeyExceedBill: 'Этот API-ключ не имеет доступной квоты, пожалуйста, прочитайте',
+    addKey: 'Добавить ключ',
+    comingSoon: 'Скоро',
+    editKey: 'Редактировать',
+    invalidApiKey: 'Неверный ключ API',
+    azure: {
+      apiBase: 'Базовый API',
+      apiBasePlaceholder: 'Базовый URL-адрес API вашей конечной точки Azure OpenAI.',
+      apiKey: 'Ключ API',
+      apiKeyPlaceholder: 'Введите свой ключ API здесь',
+      helpTip: 'Узнать о службе Azure OpenAI',
+    },
+    openaiHosted: {
+      openaiHosted: 'Размещенный OpenAI',
+      onTrial: 'ПРОБНАЯ ВЕРСИЯ',
+      exhausted: 'КВОТА ИСЧЕРПАНА',
+      desc: 'Хостинговая служба OpenAI, предоставляемая Dify, позволяет вам использовать такие модели, как GPT-3.5. Прежде чем ваша пробная квота будет исчерпана, вам необходимо настроить других поставщиков моделей.',
+      callTimes: 'Количество вызовов',
+      usedUp: 'Пробная квота исчерпана. Добавьте собственного поставщика модели.',
+      useYourModel: 'В настоящее время используется собственный поставщик модели.',
+      close: 'Закрыть',
+    },
+    anthropicHosted: {
+      anthropicHosted: 'Anthropic Claude',
+      onTrial: 'ПРОБНАЯ ВЕРСИЯ',
+      exhausted: 'КВОТА ИСЧЕРПАНА',
+      desc: 'Мощная модель, которая отлично справляется с широким спектром задач, от сложных диалогов и создания творческого контента до подробных инструкций.',
+      callTimes: 'Количество вызовов',
+      usedUp: 'Пробная квота исчерпана.
Добавьте собственного поставщика модели.', + useYourModel: 'В настоящее время используется собственный поставщик модели.', + close: 'Закрыть', + }, + anthropic: { + using: 'Возможность встраивания использует', + enableTip: 'Чтобы включить модель Anthropic, вам необходимо сначала привязаться к OpenAI или Azure OpenAI Service.', + notEnabled: 'Не включено', + keyFrom: 'Получите свой ключ API от Anthropic', + }, + encrypted: { + front: 'Ваш API-ключ будет зашифрован и сохранен с использованием', + back: ' технологии.', + }, + }, + modelProvider: { + notConfigured: 'Системная модель еще не полностью настроена, и некоторые функции могут быть недоступны.', + systemModelSettings: 'Настройки системной модели', + systemModelSettingsLink: 'Зачем нужно настраивать системную модель?', + selectModel: 'Выберите свою модель', + setupModelFirst: 'Пожалуйста, сначала настройте свою модель', + systemReasoningModel: { + key: 'Модель системного мышления', + tip: 'Установите модель вывода по умолчанию, которая будет использоваться для создания приложений, а также такие функции, как генерация имени диалога и предложение следующего вопроса, также будут использовать модель вывода по умолчанию.', + }, + embeddingModel: { + key: 'Модель встраивания', + tip: 'Установите модель по умолчанию для обработки встраивания документов знаний, как поиск, так и импорт знаний используют эту модель встраивания для обработки векторизации. Переключение приведет к несоответствию векторного измерения между импортированными знаниями и вопросом, что приведет к сбою поиска. Чтобы избежать сбоя поиска, пожалуйста, не переключайте эту модель по своему усмотрению.', + required: 'Модель встраивания обязательна', + }, + speechToTextModel: { + key: 'Модель преобразования речи в текст', + tip: 'Установите модель по умолчанию для ввода речи в текст в разговоре.', + }, + ttsModel: { + key: 'Модель преобразования текста в речь', + tip: 'Установите модель по умолчанию для ввода текста в речь в разговоре.', + }, + rerankModel: { + key: 'Модель повторного ранжирования', + tip: 'Модель повторного ранжирования изменит порядок списка документов-кандидатов на основе семантического соответствия запросу пользователя, улучшая результаты семантического ранжирования', + }, + apiKey: 'API-КЛЮЧ', + quota: 'Квота', + searchModel: 'Поиск модели', + noModelFound: 'Модель не найдена для {{model}}', + models: 'Модели', + showMoreModelProvider: 'Показать больше поставщиков моделей', + selector: { + tip: 'Эта модель была удалена. Пожалуйста, добавьте модель или выберите другую модель.', + emptyTip: 'Нет доступных моделей', + emptySetting: 'Пожалуйста, перейдите в настройки для настройки', + rerankTip: 'Пожалуйста, настройте модель повторного ранжирования', + }, + card: { + quota: 'КВОТА', + onTrial: 'Пробная версия', + paid: 'Платный', + quotaExhausted: 'Квота исчерпана', + callTimes: 'Количество вызовов', + tokens: 'Токены', + buyQuota: 'Купить квоту', + priorityUse: 'Приоритетное использование', + removeKey: 'Удалить API-ключ', + tip: 'Приоритет будет отдаваться платной квоте. Пробная квота будет использоваться после исчерпания платной квоты.', + }, + item: { + deleteDesc: '{{modelName}} используются в качестве моделей системного мышления. Некоторые функции будут недоступны после удаления. 
Пожалуйста, подтвердите.', + freeQuota: 'БЕСПЛАТНАЯ КВОТА', + }, + addApiKey: 'Добавьте свой API-ключ', + invalidApiKey: 'Неверный API-ключ', + encrypted: { + front: 'Ваш API-ключ будет зашифрован и сохранен с использованием', + back: ' технологии.', + }, + freeQuota: { + howToEarn: 'Как заработать', + }, + addMoreModelProvider: 'ДОБАВИТЬ БОЛЬШЕ ПОСТАВЩИКОВ МОДЕЛЕЙ', + addModel: 'Добавить модель', + modelsNum: '{{num}} Моделей', + showModels: 'Показать модели', + showModelsNum: 'Показать {{num}} моделей', + collapse: 'Свернуть', + config: 'Настройка', + modelAndParameters: 'Модель и параметры', + model: 'Модель', + featureSupported: '{{feature}} поддерживается', + callTimes: 'Количество вызовов', + credits: 'Кредиты на сообщения', + buyQuota: 'Купить квоту', + getFreeTokens: 'Получить бесплатные токены', + priorityUsing: 'Приоритетное использование', + deprecated: 'Устаревший', + confirmDelete: 'Подтвердить удаление?', + quotaTip: 'Оставшиеся доступные бесплатные токены', + loadPresets: 'Загрузить предустановки', + parameters: 'ПАРАМЕТРЫ', + loadBalancing: 'Балансировка нагрузки', + loadBalancingDescription: 'Снизьте нагрузку с помощью нескольких наборов учетных данных.', + loadBalancingHeadline: 'Балансировка нагрузки', + configLoadBalancing: 'Настроить балансировку нагрузки', + modelHasBeenDeprecated: 'Эта модель устарела', + providerManaged: 'Управляется поставщиком', + providerManagedDescription: 'Используйте один набор учетных данных, предоставленный поставщиком модели.', + defaultConfig: 'Настройка по умолчанию', + apiKeyStatusNormal: 'Статус APIKey в норме', + apiKeyRateLimit: 'Достигнут предел скорости, доступен через {{seconds}}s', + addConfig: 'Добавить конфигурацию', + editConfig: 'Редактировать конфигурацию', + loadBalancingLeastKeyWarning: 'Для включения балансировки нагрузки необходимо включить не менее 2 ключей.', + loadBalancingInfo: 'По умолчанию балансировка нагрузки использует стратегию Round-robin. 
Если срабатывает ограничение скорости, будет применен 1-минутный период охлаждения.', + upgradeForLoadBalancing: 'Обновите свой тарифный план, чтобы включить балансировку нагрузки.', + }, + dataSource: { + add: 'Добавить источник данных', + connect: 'Подключить', + configure: 'Настроить', + notion: { + title: 'Notion', + description: 'Использование Notion в качестве источника данных для знаний.', + connectedWorkspace: 'Подключенное рабочее пространство', + addWorkspace: 'Добавить рабочее пространство', + connected: 'Подключено', + disconnected: 'Отключено', + changeAuthorizedPages: 'Изменить авторизованные страницы', + pagesAuthorized: 'Авторизованные страницы', + sync: 'Синхронизировать', + remove: 'Удалить', + selector: { + pageSelected: 'Выбранные страницы', + searchPages: 'Поиск страниц...', + noSearchResult: 'Нет результатов поиска', + addPages: 'Добавить страницы', + preview: 'ПРЕДПРОСМОТР', + }, + }, + website: { + title: 'Веб-сайт', + description: 'Импортировать контент с веб-сайтов с помощью веб-краулера.', + with: 'С', + configuredCrawlers: 'Настроенные краулеры', + active: 'Активный', + inactive: 'Неактивный', + }, + }, + plugin: { + serpapi: { + apiKey: 'Ключ API', + apiKeyPlaceholder: 'Введите свой ключ API', + keyFrom: 'Получите свой ключ SerpAPI на странице учетной записи SerpAPI', + }, + }, + apiBasedExtension: { + title: 'API-расширения обеспечивают централизованное управление API, упрощая настройку для удобного использования в приложениях Dify.', + link: 'Узнайте, как разработать собственное API-расширение.', + linkUrl: 'https://docs.dify.ai/features/extension/api_based_extension', + add: 'Добавить API Extension', + selector: { + title: 'API Extension', + placeholder: 'Пожалуйста, выберите API-расширение', + manage: 'Управление API-расширением', + }, + modal: { + title: 'Добавить API-расширение', + editTitle: 'Редактировать API-расширение', + name: { + title: 'Имя', + placeholder: 'Пожалуйста, введите имя', + }, + apiEndpoint: { + title: 'API Endpoint', + placeholder: 'Пожалуйста, введите конечную точку API', + }, + apiKey: { + title: 'API-ключ', + placeholder: 'Пожалуйста, введите API-ключ', + lengthError: 'Длина API-ключа не может быть меньше 5 символов', + }, + }, + type: 'Тип', + }, + about: { + changeLog: 'Журнал изменений', + updateNow: 'Обновить сейчас', + nowAvailable: 'Dify {{version}} теперь доступен.', + latestAvailable: 'Dify {{version}} - последняя доступная версия.', + }, + appMenus: { + overview: 'Мониторинг', + promptEng: 'Оркестрация', + apiAccess: 'Доступ к API', + logAndAnn: 'Журналы и аннотации', + logs: 'Журналы', + }, + environment: { + testing: 'ТЕСТИРОВАНИЕ', + development: 'РАЗРАБОТКА', + }, + appModes: { + completionApp: 'Генератор текста', + chatApp: 'Чат-приложение', + }, + datasetMenus: { + documents: 'Документы', + hitTesting: 'Тестирование поиска', + settings: 'Настройки', + emptyTip: 'Знания не были связаны, пожалуйста, перейдите в приложение или плагин, чтобы завершить связывание.', + viewDoc: 'Просмотреть документацию', + relatedApp: 'связанные приложения', + }, + voiceInput: { + speaking: 'Говорите сейчас...', + converting: 'Преобразование в текст...', + notAllow: 'микрофон не авторизован', + }, + modelName: { + 'gpt-3.5-turbo': 'GPT-3.5-Turbo', + 'gpt-3.5-turbo-16k': 'GPT-3.5-Turbo-16K', + 'gpt-4': 'GPT-4', + 'gpt-4-32k': 'GPT-4-32K', + 'text-davinci-003': 'Text-Davinci-003', + 'text-embedding-ada-002': 'Text-Embedding-Ada-002', + 'whisper-1': 'Whisper-1', + 'claude-instant-1': 'Claude-Instant', + 'claude-2': 'Claude-2', + }, + chat: { 
+ renameConversation: 'Переименовать разговор', + conversationName: 'Название разговора', + conversationNamePlaceholder: 'Пожалуйста, введите название разговора', + conversationNameCanNotEmpty: 'Название разговора обязательно', + citation: { + title: 'ЦИТАТЫ', + linkToDataset: 'Ссылка на знания', + characters: 'Символы:', + hitCount: 'Количество совпадений:', + vectorHash: 'Векторный хэш:', + hitScore: 'Оценка совпадения:', + }, + }, + promptEditor: { + placeholder: 'Напишите здесь свою подсказку, введите \'{\', чтобы вставить переменную, введите \'/\', чтобы вставить блок содержимого подсказки', + context: { + item: { + title: 'Контекст', + desc: 'Вставить шаблон контекста', + }, + modal: { + title: '{{num}} знаний в контексте', + add: 'Добавить контекст ', + footer: 'Вы можете управлять контекстами в разделе «Контекст» ниже.', + }, + }, + history: { + item: { + title: 'История разговоров', + desc: 'Вставить шаблон исторического сообщения', + }, + modal: { + title: 'ПРИМЕР', + user: 'Привет', + assistant: 'Привет! Как я могу вам помочь сегодня?', + edit: 'Редактировать имена ролей разговора', + }, + }, + variable: { + item: { + title: 'Переменные и внешние инструменты', + desc: 'Вставить переменные и внешние инструменты', + }, + outputToolDisabledItem: { + title: 'Переменные', + desc: 'Вставить переменные', + }, + modal: { + add: 'Новая переменная', + addTool: 'Новый инструмент', + }, + }, + query: { + item: { + title: 'Запрос', + desc: 'Вставить шаблон запроса пользователя', + }, + }, + existed: 'Уже существует в подсказке', + }, + imageUploader: { + uploadFromComputer: 'Загрузить с компьютера', + uploadFromComputerReadError: 'Ошибка чтения изображения, повторите попытку.', + uploadFromComputerUploadError: 'Ошибка загрузки изображения, загрузите еще раз.', + uploadFromComputerLimit: 'Загружаемые изображения не могут превышать {{size}} МБ', + pasteImageLink: 'Вставить ссылку на изображение', + pasteImageLinkInputPlaceholder: 'Вставьте ссылку на изображение здесь', + pasteImageLinkInvalid: 'Неверная ссылка на изображение', + imageUpload: 'Загрузка изображения', + }, + tag: { + placeholder: 'Все теги', + addNew: 'Добавить новый тег', + noTag: 'Нет тегов', + noTagYet: 'Еще нет тегов', + addTag: 'Добавить теги', + editTag: 'Редактировать теги', + manageTags: 'Управление тегами', + selectorPlaceholder: 'Введите для поиска или создания', + create: 'Создать', + delete: 'Удалить тег', + deleteTip: 'Тег используется, удалить его?', + created: 'Тег успешно создан', + failed: 'Ошибка создания тега', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/custom.ts b/web/i18n/ru-RU/custom.ts new file mode 100644 index 0000000000..8725c83577 --- /dev/null +++ b/web/i18n/ru-RU/custom.ts @@ -0,0 +1,30 @@ +const translation = { + custom: 'Настройка', + upgradeTip: { + prefix: 'Обновите свой тарифный план, чтобы', + suffix: 'настроить свой бренд.', + }, + webapp: { + title: 'Настроить бренд веб-приложения', + removeBrand: 'Удалить Powered by Dify', + changeLogo: 'Изменить изображение бренда Powered by', + changeLogoTip: 'Формат SVG или PNG с минимальным размером 40x40px', + }, + app: { + title: 'Настроить бренд заголовка приложения', + changeLogoTip: 'Формат SVG или PNG с минимальным размером 80x80px', + }, + upload: 'Загрузить', + uploading: 'Загрузка', + uploadedFail: 'Ошибка загрузки изображения, пожалуйста, загрузите еще раз.', + change: 'Изменить', + apply: 'Применить', + restore: 'Восстановить значения по умолчанию', + customize: { + contactUs: ' свяжитесь с нами 
', + prefix: 'Чтобы настроить логотип бренда в приложении, пожалуйста,', + suffix: 'чтобы перейти на корпоративную версию.', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/dataset-creation.ts b/web/i18n/ru-RU/dataset-creation.ts new file mode 100644 index 0000000000..c4dce774d8 --- /dev/null +++ b/web/i18n/ru-RU/dataset-creation.ts @@ -0,0 +1,161 @@ +const translation = { + steps: { + header: { + creation: 'Создать базу знаний', + update: 'Добавить данные', + }, + one: 'Выберите источник данных', + two: 'Предварительная обработка и очистка текста', + three: 'Выполнить и завершить', + }, + error: { + unavailable: 'Эта база знаний недоступна', + }, + firecrawl: { + configFirecrawl: 'Настроить 🔥Firecrawl', + apiKeyPlaceholder: 'Ключ API с firecrawl.dev', + getApiKeyLinkText: 'Получите свой ключ API с firecrawl.dev', + }, + stepOne: { + filePreview: 'Предварительный просмотр файла', + pagePreview: 'Предварительный просмотр страницы', + dataSourceType: { + file: 'Импортировать из файла', + notion: 'Синхронизировать из Notion', + web: 'Синхронизировать с веб-сайта', + }, + uploader: { + title: 'Загрузить файл', + button: 'Перетащите файл или', + browse: 'Обзор', + tip: 'Поддерживаются {{supportTypes}}. Максимум {{size}} МБ каждый.', + validation: { + typeError: 'Тип файла не поддерживается', + size: 'Файл слишком большой. Максимум {{size}} МБ', + count: 'Несколько файлов не поддерживаются', + filesNumber: 'Вы достигли лимита пакетной загрузки {{filesNumber}} файлов.', + }, + cancel: 'Отмена', + change: 'Изменить', + failed: 'Ошибка загрузки', + }, + notionSyncTitle: 'Notion не подключен', + notionSyncTip: 'Чтобы синхронизировать данные из Notion, сначала необходимо установить соединение с Notion.', + connect: 'Перейти к подключению', + button: 'Далее', + emptyDatasetCreation: 'Я хочу создать пустую базу знаний', + modal: { + title: 'Создать пустую базу знаний', + tip: 'Пустая база знаний не будет содержать документов, и вы можете загружать документы в любое время.', + input: 'Название базы знаний', + placeholder: 'Пожалуйста, введите', + nameNotEmpty: 'Название не может быть пустым', + nameLengthInvalid: 'Название должно быть от 1 до 40 символов', + cancelButton: 'Отмена', + confirmButton: 'Создать', + failed: 'Ошибка создания', + }, + website: { + fireCrawlNotConfigured: 'Firecrawl не настроен', + fireCrawlNotConfiguredDescription: 'Настройте Firecrawl с API-ключом.', + configure: 'Настроить', + run: 'Запустить', + firecrawlTitle: 'Извлечь веб-контент с помощью 🔥Firecrawl', + firecrawlDoc: 'Документация Firecrawl', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + options: 'Опции', + crawlSubPage: 'Сканировать подстраницы', + limit: 'Лимит', + maxDepth: 'Максимальная глубина', + excludePaths: 'Исключить пути', + includeOnlyPaths: 'Включить только пути', + extractOnlyMainContent: 'Извлекать только основной контент (без заголовков, навигации, футеров и т. д.)', + exceptionErrorTitle: 'Произошло исключение при запуске задания Firecrawl:', + unknownError: 'Неизвестная ошибка', + totalPageScraped: 'Всего просканировано страниц:', + selectAll: 'Выбрать все', + resetAll: 'Сбросить все', + scrapTimeInfo: 'Всего просканировано {{total}} страниц за {{time}} секунд', + preview: 'Предварительный просмотр', + maxDepthTooltip: 'Максимальная глубина сканирования относительно введенного URL. 
Глубина 0 сканирует только страницу введенного URL, глубина 1 сканирует URL и все, что находится после введенного URL + один /, и так далее.', + }, + }, + stepTwo: { + segmentation: 'Настройки фрагментации', + auto: 'Автоматически', + autoDescription: 'Автоматически устанавливать правила фрагментации и предварительной обработки. Пользователям, не знакомым с системой, рекомендуется выбрать этот вариант.', + custom: 'Пользовательский', + customDescription: 'Настроить правила фрагментации, длину фрагментов, правила предварительной обработки и т. д.', + separator: 'Идентификатор сегмента', + separatorPlaceholder: 'Например, новая строка (\\\\n) или специальный разделитель (например, "***")', + maxLength: 'Максимальная длина фрагмента', + overlap: 'Перекрытие фрагментов', + overlapTip: 'Установка перекрытия фрагментов может сохранить семантическую связь между ними, улучшая эффект поиска. Рекомендуется установить 10%-25% от максимального размера фрагмента.', + overlapCheck: 'перекрытие фрагментов не должно превышать максимальную длину фрагмента', + rules: 'Правила предварительной обработки текста', + removeExtraSpaces: 'Заменить последовательные пробелы, новые строки и табуляции', + removeUrlEmails: 'Удалить все URL-адреса и адреса электронной почты', + removeStopwords: 'Удалить стоп-слова, такие как "a", "an", "the"', + preview: 'Подтвердить и просмотреть', + reset: 'Сбросить', + indexMode: 'Режим индексации', + qualified: 'Высокое качество', + recommend: 'Рекомендуется', + qualifiedTip: 'Вызов интерфейса встраивания системы по умолчанию для обработки, чтобы обеспечить более высокую точность при запросах пользователей.', + warning: 'Пожалуйста, сначала настройте ключ API поставщика модели.', + click: 'Перейти к настройкам', + economical: 'Экономичный', + economicalTip: 'Используйте автономные векторные движки, индексы ключевых слов и т. д., чтобы снизить точность, не тратя токены', + QATitle: 'Сегментация в формате вопрос-ответ', + QATip: 'Включение этой опции приведет к потреблению большего количества токенов', + QALanguage: 'Сегментировать с помощью', + estimateCost: 'Оценка', + estimateSegment: 'Оценочное количество фрагментов', + segmentCount: 'фрагментов', + calculating: 'Вычисление...', + fileSource: 'Предварительная обработка документов', + notionSource: 'Предварительная обработка страниц', + websiteSource: 'Предварительная обработка веб-сайта', + other: 'и другие ', + fileUnit: ' файлов', + notionUnit: ' страниц', + webpageUnit: ' страниц', + previousStep: 'Предыдущий шаг', + nextStep: 'Сохранить и обработать', + save: 'Сохранить и обработать', + cancel: 'Отмена', + sideTipTitle: 'Зачем нужна фрагментация и предварительная обработка?', + sideTipP1: 'При обработке текстовых данных фрагментация и очистка являются двумя важными этапами предварительной обработки.', + sideTipP2: 'Сегментация разбивает длинный текст на абзацы, чтобы модели могли лучше его понимать. 
Это улучшает качество и релевантность результатов модели.', + sideTipP3: 'Очистка удаляет ненужные символы и форматы, делая знания более чистыми и легкими для анализа.', + sideTipP4: 'Правильная фрагментация и очистка улучшают производительность модели, обеспечивая более точные и ценные результаты.', + previewTitle: 'Предварительный просмотр', + previewTitleButton: 'Предварительный просмотр', + previewButton: 'Переключение в формат вопрос-ответ', + previewSwitchTipStart: 'Текущий предварительный просмотр фрагмента находится в текстовом формате, переключение на предварительный просмотр в формате вопрос-ответ', + previewSwitchTipEnd: ' потребляет дополнительные токены', + characters: 'символов', + indexSettingTip: 'Чтобы изменить метод индексации, пожалуйста, перейдите в ', + retrievalSettingTip: 'Чтобы изменить настройки поиска, пожалуйста, перейдите в ', + datasetSettingLink: 'настройки базы знаний.', + }, + stepThree: { + creationTitle: '🎉 База знаний создана', + creationContent: 'Мы автоматически назвали базу знаний, вы можете изменить ее в любое время', + label: 'Название базы знаний', + additionTitle: '🎉 Документ загружен', + additionP1: 'Документ был загружен в базу знаний', + additionP2: ', вы можете найти его в списке документов базы знаний.', + stop: 'Остановить обработку', + resume: 'Возобновить обработку', + navTo: 'Перейти к документу', + sideTipTitle: 'Что дальше', + sideTipContent: 'После завершения индексации документа база знаний может быть интегрирована в приложение в качестве контекста, вы можете найти настройку контекста на странице оркестрации. Вы также можете создать его как независимый плагин.', + modelTitle: 'Вы уверены, что хотите остановить встраивание?', + modelContent: 'Если вам нужно будет возобновить обработку позже, вы продолжите с того места, где остановились.', + modelButtonConfirm: 'Подтвердить', + modelButtonCancel: 'Отмена', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/dataset-documents.ts b/web/i18n/ru-RU/dataset-documents.ts new file mode 100644 index 0000000000..b1870fb680 --- /dev/null +++ b/web/i18n/ru-RU/dataset-documents.ts @@ -0,0 +1,352 @@ +const translation = { + list: { + title: 'Документы', + desc: 'Здесь отображаются все файлы базы знаний, и вся база знаний может быть связана с цитатами Dify или проиндексирована с помощью чата.', + addFile: 'Добавить файл', + addPages: 'Добавить страницы', + addUrl: 'Добавить URL', + table: { + header: { + fileName: 'НАЗВАНИЕ ФАЙЛА', + words: 'СЛОВА', + hitCount: 'КОЛИЧЕСТВО ОБРАЩЕНИЙ', + uploadTime: 'ВРЕМЯ ЗАГРУЗКИ', + status: 'СТАТУС', + action: 'ДЕЙСТВИЕ', + }, + rename: 'Переименовать', + name: 'Название', + }, + action: { + uploadFile: 'Загрузить новый файл', + settings: 'Настройки сегментации', + addButton: 'Добавить фрагмент', + add: 'Добавить фрагмент', + batchAdd: 'Пакетное добавление', + archive: 'Архивировать', + unarchive: 'Разархивировать', + delete: 'Удалить', + enableWarning: 'Архивный файл не может быть включен', + sync: 'Синхронизировать', + }, + index: { + enable: 'Включить', + disable: 'Отключить', + all: 'Все', + enableTip: 'Файл может быть проиндексирован', + disableTip: 'Файл не может быть проиндексирован', + }, + status: { + queuing: 'В очереди', + indexing: 'Индексация', + paused: 'Приостановлено', + error: 'Ошибка', + available: 'Доступно', + enabled: 'Включено', + disabled: 'Отключено', + archived: 'Архивировано', + }, + empty: { + title: 'Пока нет документов', + upload: { + tip: 'Вы можете загружать файлы, синхронизировать с 
веб-сайта или из веб-приложений, таких как Notion, GitHub и т. д.', + }, + sync: { + tip: 'Dify будет периодически загружать файлы из вашего Notion и завершать обработку.', + }, + }, + delete: { + title: 'Вы уверены, что хотите удалить?', + content: 'Если вам нужно будет возобновить обработку позже, вы продолжите с того места, где остановились', + }, + batchModal: { + title: 'Пакетное добавление фрагментов', + csvUploadTitle: 'Перетащите сюда свой CSV-файл или ', + browse: 'обзор', + tip: 'CSV-файл должен соответствовать следующей структуре:', + question: 'вопрос', + answer: 'ответ', + contentTitle: 'содержимое фрагмента', + content: 'содержимое', + template: 'Скачать шаблон здесь', + cancel: 'Отмена', + run: 'Запустить пакет', + runError: 'Ошибка запуска пакета', + processing: 'В процессе пакетной обработки', + completed: 'Импорт завершен', + error: 'Ошибка импорта', + ok: 'ОК', + }, + }, + metadata: { + title: 'Метаданные', + desc: 'Маркировка метаданных для документов позволяет ИИ своевременно получать к ним доступ и раскрывать источник ссылок для пользователей.', + dateTimeFormat: 'D MMMM YYYY, HH:mm', + docTypeSelectTitle: 'Пожалуйста, выберите тип документа', + docTypeChangeTitle: 'Изменить тип документа', + docTypeSelectWarning: + 'Если тип документа будет изменен, заполненные сейчас метаданные больше не будут сохранены', + firstMetaAction: 'Поехали', + placeholder: { + add: 'Добавить ', + select: 'Выбрать ', + }, + source: { + upload_file: 'Загрузить файл', + notion: 'Синхронизировать из Notion', + github: 'Синхронизировать из Github', + }, + type: { + book: 'Книга', + webPage: 'Веб-страница', + paper: 'Статья', + socialMediaPost: 'Пост в социальных сетях', + personalDocument: 'Личный документ', + businessDocument: 'Деловой документ', + IMChat: 'Чат в мессенджере', + wikipediaEntry: 'Статья в Википедии', + notion: 'Синхронизировать из Notion', + github: 'Синхронизировать из Github', + technicalParameters: 'Технические параметры', + }, + field: { + processRule: { + processDoc: 'Обработка документа', + segmentRule: 'Правило фрагментации', + segmentLength: 'Длина фрагментов', + processClean: 'Очистка текста', + }, + book: { + title: 'Название', + language: 'Язык', + author: 'Автор', + publisher: 'Издатель', + publicationDate: 'Дата публикации', + ISBN: 'ISBN', + category: 'Категория', + }, + webPage: { + title: 'Название', + url: 'URL', + language: 'Язык', + authorPublisher: 'Автор/Издатель', + publishDate: 'Дата публикации', + topicsKeywords: 'Темы/Ключевые слова', + description: 'Описание', + }, + paper: { + title: 'Название', + language: 'Язык', + author: 'Автор', + publishDate: 'Дата публикации', + journalConferenceName: 'Название журнала/конференции', + volumeIssuePage: 'Том/Выпуск/Страница', + DOI: 'DOI', + topicsKeywords: 'Темы/Ключевые слова', + abstract: 'Аннотация', + }, + socialMediaPost: { + platform: 'Платформа', + authorUsername: 'Автор/Имя пользователя', + publishDate: 'Дата публикации', + postURL: 'URL поста', + topicsTags: 'Темы/Теги', + }, + personalDocument: { + title: 'Название', + author: 'Автор', + creationDate: 'Дата создания', + lastModifiedDate: 'Дата последнего изменения', + documentType: 'Тип документа', + tagsCategory: 'Теги/Категория', + }, + businessDocument: { + title: 'Название', + author: 'Автор', + creationDate: 'Дата создания', + lastModifiedDate: 'Дата последнего изменения', + documentType: 'Тип документа', + departmentTeam: 'Отдел/Команда', + }, + IMChat: { + chatPlatform: 'Платформа чата', + chatPartiesGroupName: 'Участники чата/Название группы', 
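A note on the `dateTimeFormat` value in this file ('D MMMM YYYY, HH:mm'): it is a date-formatting pattern rather than display text, which lets each locale control how timestamps render. A minimal sketch of how such a key is typically consumed, assuming a dayjs-plus-i18next setup (the key path, locale wiring, and helper below are illustrative, not taken from this diff):

```ts
import dayjs from 'dayjs'
import 'dayjs/locale/ru'
import i18next from 'i18next'

// Hypothetical helper: render a document's upload time using the
// locale-specific pattern instead of a hard-coded format string.
function formatUploadTime(timestampMs: number): string {
  // Assumed key path; the actual namespace wiring may differ.
  const pattern = i18next.t('datasetDocuments.metadata.dateTimeFormat') // 'D MMMM YYYY, HH:mm' for ru-RU
  return dayjs(timestampMs).locale('ru').format(pattern)
}
```

Keeping the pattern in the translation file also means the hit-testing file below can use a different pattern ('DD.MM.YYYY HH:mm') without any code change.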
+ participants: 'Участники', + startDate: 'Дата начала', + endDate: 'Дата окончания', + topicsKeywords: 'Темы/Ключевые слова', + fileType: 'Тип файла', + }, + wikipediaEntry: { + title: 'Название', + language: 'Язык', + webpageURL: 'URL веб-страницы', + editorContributor: 'Редактор/Автор', + lastEditDate: 'Дата последнего редактирования', + summaryIntroduction: 'Краткое содержание/Введение', + }, + notion: { + title: 'Название', + language: 'Язык', + author: 'Автор', + createdTime: 'Время создания', + lastModifiedTime: 'Время последнего изменения', + url: 'URL', + tag: 'Тег', + description: 'Описание', + }, + github: { + repoName: 'Название репозитория', + repoDesc: 'Описание репозитория', + repoOwner: 'Владелец репозитория', + fileName: 'Название файла', + filePath: 'Путь к файлу', + programmingLang: 'Язык программирования', + url: 'URL', + license: 'Лицензия', + lastCommitTime: 'Время последнего коммита', + lastCommitAuthor: 'Автор последнего коммита', + }, + originInfo: { + originalFilename: 'Исходное имя файла', + originalFileSize: 'Исходный размер файла', + uploadDate: 'Дата загрузки', + lastUpdateDate: 'Дата последнего обновления', + source: 'Источник', + }, + technicalParameters: { + segmentSpecification: 'Спецификация фрагментов', + segmentLength: 'Длина фрагментов', + avgParagraphLength: 'Средняя длина абзаца', + paragraphs: 'Абзацы', + hitCount: 'Количество обращений', + embeddingTime: 'Время встраивания', + embeddedSpend: 'Потрачено на встраивание', + }, + }, + languageMap: { + zh: 'Китайский', + en: 'Английский', + es: 'Испанский', + fr: 'Французский', + de: 'Немецкий', + ja: 'Японский', + ko: 'Корейский', + ru: 'Русский', + ar: 'Арабский', + pt: 'Португальский', + it: 'Итальянский', + nl: 'Голландский', + pl: 'Польский', + sv: 'Шведский', + tr: 'Турецкий', + he: 'Иврит', + hi: 'Хинди', + da: 'Датский', + fi: 'Финский', + no: 'Норвежский', + hu: 'Венгерский', + el: 'Греческий', + cs: 'Чешский', + th: 'Тайский', + id: 'Индонезийский', + }, + categoryMap: { + book: { + fiction: 'Художественная литература', + biography: 'Биография', + history: 'История', + science: 'Наука', + technology: 'Технологии', + education: 'Образование', + philosophy: 'Философия', + religion: 'Религия', + socialSciences: 'Социальные науки', + art: 'Искусство', + travel: 'Путешествия', + health: 'Здоровье', + selfHelp: 'Самопомощь', + businessEconomics: 'Бизнес/Экономика', + cooking: 'Кулинария', + childrenYoungAdults: 'Детская/Подростковая литература', + comicsGraphicNovels: 'Комиксы/Графические романы', + poetry: 'Поэзия', + drama: 'Драматургия', + other: 'Другое', + }, + personalDoc: { + notes: 'Заметки', + blogDraft: 'Черновик блога', + diary: 'Дневник', + researchReport: 'Научный отчет', + bookExcerpt: 'Отрывок из книги', + schedule: 'Расписание', + list: 'Список', + projectOverview: 'Обзор проекта', + photoCollection: 'Коллекция фотографий', + creativeWriting: 'Творческое письмо', + codeSnippet: 'Фрагмент кода', + designDraft: 'Черновик дизайна', + personalResume: 'Личное резюме', + other: 'Другое', + }, + businessDoc: { + meetingMinutes: 'Протокол собрания', + researchReport: 'Научный отчет', + proposal: 'Предложение', + employeeHandbook: 'Справочник сотрудника', + trainingMaterials: 'Учебные материалы', + requirementsDocument: 'Документ с требованиями', + designDocument: 'Проектный документ', + productSpecification: 'Спецификация продукта', + financialReport: 'Финансовый отчет', + marketAnalysis: 'Анализ рынка', + projectPlan: 'План проекта', + teamStructure: 'Структура команды', + 
policiesProcedures: 'Политики и процедуры', + contractsAgreements: 'Договоры и соглашения', + emailCorrespondence: 'Переписка по электронной почте', + other: 'Другое', + }, + }, + }, + embedding: { + processing: 'Расчет эмбеддингов...', + paused: 'Расчет эмбеддингов приостановлен', + completed: 'Встраивание завершено', + error: 'Ошибка расчета эмбеддингов', + docName: 'Предварительная обработка документа', + mode: 'Правило сегментации', + segmentLength: 'Длина фрагментов', + textCleaning: 'Предварительная очистка текста', + segments: 'Абзацы', + highQuality: 'Режим высокого качества', + economy: 'Экономичный режим', + estimate: 'Оценочное потребление', + stop: 'Остановить обработку', + resume: 'Возобновить обработку', + automatic: 'Автоматически', + custom: 'Пользовательский', + previewTip: 'Предварительный просмотр абзацев будет доступен после завершения расчета эмбеддингов', + }, + segment: { + paragraphs: 'Абзацы', + keywords: 'Ключевые слова', + addKeyWord: 'Добавить ключевое слово', + keywordError: 'Максимальная длина ключевого слова - 20', + characters: 'символов', + hitCount: 'Количество обращений', + vectorHash: 'Векторный хэш: ', + questionPlaceholder: 'добавьте вопрос здесь', + questionEmpty: 'Вопрос не может быть пустым', + answerPlaceholder: 'добавьте ответ здесь', + answerEmpty: 'Ответ не может быть пустым', + contentPlaceholder: 'добавьте содержимое здесь', + contentEmpty: 'Содержимое не может быть пустым', + newTextSegment: 'Новый текстовый сегмент', + newQaSegment: 'Новый сегмент вопрос-ответ', + delete: 'Удалить этот фрагмент?', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/dataset-hit-testing.ts b/web/i18n/ru-RU/dataset-hit-testing.ts new file mode 100644 index 0000000000..0d3a14a676 --- /dev/null +++ b/web/i18n/ru-RU/dataset-hit-testing.ts @@ -0,0 +1,28 @@ +const translation = { + title: 'Тестирование поиска', + desc: 'Проверьте эффективность поиска в базе знаний на основе заданного текста запроса.', + dateTimeFormat: 'DD.MM.YYYY HH:mm', + recents: 'Недавние', + table: { + header: { + source: 'Источник', + text: 'Текст', + time: 'Время', + }, + }, + input: { + title: 'Исходный текст', + placeholder: 'Пожалуйста, введите текст, рекомендуется использовать короткое повествовательное предложение.', + countWarning: 'До 200 символов.', + indexWarning: 'Только база знаний высокого качества.', + testing: 'Тестирование', + }, + hit: { + title: 'НАЙДЕННЫЕ АБЗАЦЫ', + emptyTip: 'Результаты тестирования поиска будут отображаться здесь', + }, + noRecentTip: 'Здесь нет результатов недавних запросов', + viewChart: 'Посмотреть ВЕКТОРНУЮ ДИАГРАММУ', +} + +export default translation diff --git a/web/i18n/ru-RU/dataset-settings.ts b/web/i18n/ru-RU/dataset-settings.ts new file mode 100644 index 0000000000..f732562d23 --- /dev/null +++ b/web/i18n/ru-RU/dataset-settings.ts @@ -0,0 +1,35 @@ +const translation = { + title: 'Настройки базы знаний', + desc: 'Здесь вы можете изменить свойства и методы работы базы знаний.', + form: { + name: 'Название базы знаний', + namePlaceholder: 'Пожалуйста, введите название базы знаний', + nameError: 'Название не может быть пустым', + desc: 'Описание базы знаний', + descInfo: 'Пожалуйста, напишите четкое текстовое описание, чтобы обрисовать содержание базы знаний. Это описание будет использоваться в качестве основы для сопоставления при выборе из нескольких баз знаний для вывода.', + descPlaceholder: 'Опишите, что находится в этой базе знаний. Подробное описание позволяет ИИ своевременно получать доступ к содержимому базы знаний. 
Если оставить пустым, Dify будет использовать стратегию поиска по умолчанию.', + descWrite: 'Узнайте, как написать хорошее описание базы знаний.', + permissions: 'Разрешения', + permissionsOnlyMe: 'Только я', + permissionsAllMember: 'Все участники команды', + permissionsInvitedMembers: 'Отдельные участники команды', + me: '(Вы)', + indexMethod: 'Метод индексации', + indexMethodHighQuality: 'Высокое качество', + indexMethodHighQualityTip: 'Вызов модели встраивания для обработки, чтобы обеспечить более высокую точность при запросах пользователей.', + indexMethodEconomy: 'Экономичный', + indexMethodEconomyTip: 'Используйте автономные векторные движки, индексы ключевых слов и т. д., чтобы снизить точность, не тратя токены', + embeddingModel: 'Модель встраивания', + embeddingModelTip: 'Изменить встроенную модель, пожалуйста, перейдите в ', + embeddingModelTipLink: 'Настройки', + retrievalSetting: { + title: 'Настройки поиска', + learnMore: 'Узнать больше', + description: ' о методе поиска.', + longDescription: ' о методе поиска, вы можете изменить это в любое время в настройках базы знаний.', + }, + save: 'Сохранить', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/dataset.ts b/web/i18n/ru-RU/dataset.ts new file mode 100644 index 0000000000..4a7590ad60 --- /dev/null +++ b/web/i18n/ru-RU/dataset.ts @@ -0,0 +1,77 @@ +const translation = { + knowledge: 'База знаний', + documentCount: ' документов', + wordCount: ' тыс. слов', + appCount: ' связанных приложений', + createDataset: 'Создать базу знаний', + createDatasetIntro: 'Импортируйте свои собственные текстовые данные или записывайте данные в режиме реального времени через Webhook для улучшения контекста LLM.', + deleteDatasetConfirmTitle: 'Удалить эту базу знаний?', + deleteDatasetConfirmContent: + 'Удаление базы знаний необратимо. Пользователи больше не смогут получить доступ к вашей базе знаний, и все настройки подсказок и журналы будут безвозвратно удалены.', + datasetUsedByApp: 'База знаний используется некоторыми приложениями. Приложения больше не смогут использовать эту базу знаний, и все настройки подсказок и журналы будут безвозвратно удалены.', + datasetDeleted: 'База знаний удалена', + datasetDeleteFailed: 'Не удалось удалить базу знаний', + didYouKnow: 'Знаете ли вы?', + intro1: 'Базу знаний можно интегрировать в приложение Dify ', + intro2: 'в качестве контекста', + intro3: ',', + intro4: 'или ее ', + intro5: 'можно создать', + intro6: ' как отдельный плагин индекса ChatGPT для публикации', + unavailable: 'Недоступно', + unavailableTip: 'Модель встраивания недоступна, необходимо настроить модель встраивания по умолчанию', + datasets: 'БАЗЫ ЗНАНИЙ', + datasetsApi: 'ДОСТУП К API', + retrieval: { + semantic_search: { + title: 'Векторный поиск', + description: 'Создайте встраивания запросов и найдите фрагмент текста, наиболее похожий на его векторное представление.', + }, + full_text_search: { + title: 'Полнотекстовый поиск', + description: 'Индексируйте все термины в документе, позволяя пользователям искать любой термин и извлекать соответствующий фрагмент текста, содержащий эти термины.', + }, + hybrid_search: { + title: 'Гибридный поиск', + description: 'Выполняйте полнотекстовый поиск и векторный поиск одновременно, переранжируйте, чтобы выбрать наилучшее соответствие запросу пользователя. 
Пользователи могут выбрать установку весов или настройку модели переранжирования.', + recommend: 'Рекомендуется', + }, + invertedIndex: { + title: 'Инвертированный индекс', + description: 'Инвертированный индекс - это структура, используемая для эффективного поиска. Организованный по терминам, каждый термин указывает на документы или веб-страницы, содержащие его.', + }, + change: 'Изменить', + changeRetrievalMethod: 'Изменить метод поиска', + }, + docsFailedNotice: 'документов не удалось проиндексировать', + retry: 'Повторить попытку', + indexingTechnique: { + high_quality: 'HQ', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'ВЕКТОР', + full_text_search: 'ПОЛНЫЙ ТЕКСТ', + hybrid_search: 'ГИБРИД', + invertedIndex: 'ИНВЕРТИРОВАННЫЙ', + }, + mixtureHighQualityAndEconomicTip: 'Для смешивания высококачественных и экономичных баз знаний требуется модель переранжирования.', + inconsistentEmbeddingModelTip: 'Модель переранжирования требуется, если модели встраивания выбранных баз знаний несовместимы.', + retrievalSettings: 'Настройки поиска', + rerankSettings: 'Настройки переранжирования', + weightedScore: { + title: 'Взвешенная оценка', + description: 'Регулируя назначенные веса, эта стратегия переранжирования определяет, следует ли отдавать приоритет семантическому или ключевому соответствию.', + semanticFirst: 'Семантика в первую очередь', + keywordFirst: 'Ключевые слова в первую очередь', + customized: 'Настраиваемый', + semantic: 'Семантика', + keyword: 'Ключевые слова', + }, + nTo1RetrievalLegacy: 'Поиск N-к-1 будет официально прекращен с сентября. Рекомендуется использовать новейший многопутный поиск для получения лучших результатов.', + nTo1RetrievalLegacyLink: 'Узнать больше', + nTo1RetrievalLegacyLinkText: ' Поиск N-к-1 будет официально прекращен в сентябре.', + defaultRetrievalTip: 'По умолчанию используется многоканальная проверка. 
Знания извлекаются из нескольких баз знаний, а затем повторно ранжируются.', +} + +export default translation diff --git a/web/i18n/ru-RU/explore.ts b/web/i18n/ru-RU/explore.ts new file mode 100644 index 0000000000..6c0b41f7d4 --- /dev/null +++ b/web/i18n/ru-RU/explore.ts @@ -0,0 +1,41 @@ +const translation = { + title: 'Обзор', + sidebar: { + discovery: 'Открытия', + chat: 'Чат', + workspace: 'Рабочее пространство', + action: { + pin: 'Закрепить', + unpin: 'Открепить', + rename: 'Переименовать', + delete: 'Удалить', + }, + delete: { + title: 'Удалить приложение', + content: 'Вы уверены, что хотите удалить это приложение?', + }, + }, + apps: { + title: 'Обзор приложений от Dify', + description: 'Используйте эти шаблонные приложения мгновенно или настройте свои собственные приложения на основе шаблонов.', + allCategories: 'Рекомендуемые', + }, + appCard: { + addToWorkspace: 'Добавить в рабочее пространство', + customize: 'Настроить', + }, + appCustomize: { + title: 'Создать приложение из {{name}}', + subTitle: 'Значок и название приложения', + nameRequired: 'Название приложения обязательно', + }, + category: { + Assistant: 'Ассистент', + Writing: 'Написание', + Translate: 'Перевод', + Programming: 'Программирование', + HR: 'HR', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/layout.ts b/web/i18n/ru-RU/layout.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/ru-RU/layout.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/ru-RU/login.ts b/web/i18n/ru-RU/login.ts new file mode 100644 index 0000000000..81918745dd --- /dev/null +++ b/web/i18n/ru-RU/login.ts @@ -0,0 +1,75 @@ +const translation = { + pageTitle: 'Привет, давайте начнем!👋', + welcome: 'Добро пожаловать в Dify, пожалуйста, войдите, чтобы продолжить.', + email: 'Адрес электронной почты', + emailPlaceholder: 'Ваш адрес электронной почты', + password: 'Пароль', + passwordPlaceholder: 'Ваш пароль', + name: 'Имя пользователя', + namePlaceholder: 'Ваше имя пользователя', + forget: 'Забыли пароль?', + signBtn: 'Войти', + sso: 'Продолжить с SSO', + installBtn: 'Настроить', + setAdminAccount: 'Настройка учетной записи администратора', + setAdminAccountDesc: 'Максимальные привилегии для учетной записи администратора, которые можно использовать для создания приложений, управления поставщиками LLM и т. 
д.', + createAndSignIn: 'Создать и войти', + oneMoreStep: 'Еще один шаг', + createSample: 'На основе этой информации мы создадим для вас пример приложения', + invitationCode: 'Пригласительный код', + invitationCodePlaceholder: 'Ваш пригласительный код', + interfaceLanguage: 'Язык интерфейса', + timezone: 'Часовой пояс', + go: 'Перейти к Dify', + sendUsMail: 'Отправьте нам по электронной почте свое представление, и мы обработаем запрос на приглашение.', + acceptPP: 'Я прочитал и принимаю политику конфиденциальности', + reset: 'Пожалуйста, выполните следующую команду, чтобы сбросить пароль', + withGitHub: 'Продолжить с GitHub', + withGoogle: 'Продолжить с Google', + rightTitle: 'Раскройте весь потенциал LLM', + rightDesc: 'Без труда создавайте визуально привлекательные, работоспособные и улучшаемые приложения ИИ.', + tos: 'Условия обслуживания', + pp: 'Политика конфиденциальности', + tosDesc: 'Регистрируясь, вы соглашаетесь с нашими', + goToInit: 'Если вы не инициализировали учетную запись, перейдите на страницу инициализации', + dontHave: 'Нет?', + invalidInvitationCode: 'Неверный пригласительный код', + accountAlreadyInited: 'Учетная запись уже инициализирована', + forgotPassword: 'Забыли пароль?', + resetLinkSent: 'Ссылка для сброса отправлена', + sendResetLink: 'Отправить ссылку для сброса', + backToSignIn: 'Вернуться к входу', + forgotPasswordDesc: 'Пожалуйста, введите свой адрес электронной почты, чтобы сбросить пароль. Мы отправим вам электронное письмо с инструкциями о том, как сбросить пароль.', + checkEmailForResetLink: 'Пожалуйста, проверьте свою электронную почту на наличие ссылки для сброса пароля. Если она не появится в течение нескольких минут, обязательно проверьте папку со спамом.', + passwordChanged: 'Войдите сейчас', + changePassword: 'Изменить пароль', + changePasswordTip: 'Пожалуйста, введите новый пароль для своей учетной записи', + invalidToken: 'Неверный или просроченный токен', + confirmPassword: 'Подтвердите пароль', + confirmPasswordPlaceholder: 'Подтвердите свой новый пароль', + passwordChangedTip: 'Ваш пароль был успешно изменен', + error: { + emailEmpty: 'Адрес электронной почты обязателен', + emailInValid: 'Пожалуйста, введите действительный адрес электронной почты', + nameEmpty: 'Имя обязательно', + passwordEmpty: 'Пароль обязателен', + passwordLengthInValid: 'Пароль должен содержать не менее 8 символов', + passwordInvalid: 'Пароль должен содержать буквы и цифры, а длина должна быть больше 8', + }, + license: { + tip: 'Перед запуском Dify Community Edition ознакомьтесь с лицензией GitHub', + link: 'Лицензия с открытым исходным кодом', + }, + join: 'Присоединиться', + joinTipStart: 'Приглашаем вас присоединиться к', + joinTipEnd: 'команде на Dify', + invalid: 'Ссылка истекла', + explore: 'Изучить Dify', + activatedTipStart: 'Вы присоединились к команде', + activatedTipEnd: '', + activated: 'Войдите сейчас', + adminInitPassword: 'Пароль инициализации администратора', + validate: 'Проверить', +} + +export default translation diff --git a/web/i18n/ru-RU/register.ts b/web/i18n/ru-RU/register.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/ru-RU/register.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/ru-RU/run-log.ts b/web/i18n/ru-RU/run-log.ts new file mode 100644 index 0000000000..2099d6794f --- /dev/null +++ b/web/i18n/ru-RU/run-log.ts @@ -0,0 +1,29 @@ +const translation = { + input: 'ВВОД', + result: 'РЕЗУЛЬТАТ', + detail: 'ДЕТАЛИ', + tracing: 'ТРАССИРОВКА', + resultPanel: { + 
status: 'СТАТУС', + time: 'ПРОШЕДШЕЕ ВРЕМЯ', + tokens: 'ВСЕГО ТОКЕНОВ', + }, + meta: { + title: 'МЕТАДАННЫЕ', + status: 'Статус', + version: 'Версия', + executor: 'Исполнитель', + startTime: 'Время начала', + time: 'Прошедшее время', + tokens: 'Всего токенов', + steps: 'Шаги выполнения', + }, + resultEmpty: { + title: 'Этот запуск выводит только формат JSON,', + tipLeft: 'пожалуйста, перейдите на ', + link: 'панель деталей', + tipRight: ' чтобы просмотреть его.', + }, +} + +export default translation diff --git a/web/i18n/ru-RU/share-app.ts b/web/i18n/ru-RU/share-app.ts new file mode 100644 index 0000000000..f0166b26f1 --- /dev/null +++ b/web/i18n/ru-RU/share-app.ts @@ -0,0 +1,74 @@ +const translation = { + common: { + welcome: '', + appUnavailable: 'Приложение недоступно', + appUnknownError: 'Приложение недоступно', + }, + chat: { + newChat: 'Новый чат', + pinnedTitle: 'Закрепленные', + unpinnedTitle: 'Чаты', + newChatDefaultName: 'Новый разговор', + resetChat: 'Сбросить разговор', + poweredBy: 'Работает на', + prompt: 'Подсказка', + privatePromptConfigTitle: 'Настройки разговора', + publicPromptConfigTitle: 'Начальная подсказка', + configStatusDes: 'Перед началом вы можете изменить настройки разговора', + configDisabled: + 'Для этого сеанса использовались настройки предыдущего сеанса.', + startChat: 'Начать чат', + privacyPolicyLeft: + 'Пожалуйста, ознакомьтесь с ', + privacyPolicyMiddle: + 'политикой конфиденциальности', + privacyPolicyRight: + ', предоставленной разработчиком приложения.', + deleteConversation: { + title: 'Удалить разговор', + content: 'Вы уверены, что хотите удалить этот разговор?', + }, + tryToSolve: 'Попробуйте решить', + temporarySystemIssue: 'Извините, временная проблема с системой.', + }, + generation: { + tabs: { + create: 'Запустить один раз', + batch: 'Запустить пакетно', + saved: 'Сохраненные', + }, + savedNoData: { + title: 'Вы еще не сохранили ни одного результата!', + description: 'Начните генерировать контент, и вы найдете свои сохраненные результаты здесь.', + startCreateContent: 'Начать создавать контент', + }, + title: 'Завершение ИИ', + queryTitle: 'Содержимое запроса', + completionResult: 'Результат завершения', + queryPlaceholder: 'Напишите содержимое вашего запроса...', + run: 'Выполнить', + copy: 'Копировать', + resultTitle: 'Завершение ИИ', + noData: 'ИИ даст вам то, что вы хотите, здесь.', + csvUploadTitle: 'Перетащите сюда свой CSV-файл или ', + browse: 'обзор', + csvStructureTitle: 'CSV-файл должен соответствовать следующей структуре:', + downloadTemplate: 'Скачать шаблон здесь', + field: 'Поле', + batchFailed: { + info: '{{num}} неудачных выполнений', + retry: 'Повторить попытку', + outputPlaceholder: 'Нет выходного содержимого', + }, + errorMsg: { + empty: 'Пожалуйста, введите содержимое в загруженный файл.', + fileStructNotMatch: 'Загруженный CSV-файл не соответствует структуре.', + emptyLine: 'Строка {{rowIndex}} пуста', + invalidLine: 'Строка {{rowIndex}}: значение {{varName}} не может быть пустым', + moreThanMaxLengthLine: 'Строка {{rowIndex}}: значение {{varName}} не может превышать {{maxLength}} символов', + atLeastOne: 'Пожалуйста, введите хотя бы одну строку в загруженный файл.', + }, + }, +} + +export default translation diff --git a/web/i18n/ru-RU/tools.ts b/web/i18n/ru-RU/tools.ts new file mode 100644 index 0000000000..e0dfd571b2 --- /dev/null +++ b/web/i18n/ru-RU/tools.ts @@ -0,0 +1,153 @@ +const translation = { + title: 'Инструменты', + createCustomTool: 'Создать пользовательский инструмент', + customToolTip: 'Узнать больше о 
пользовательских инструментах Dify', + type: { + all: 'Все', + builtIn: 'Встроенные', + custom: 'Пользовательские', + workflow: 'Рабочий процесс', + }, + contribute: { + line1: 'Я заинтересован в', + line2: 'внесении инструментов в Dify.', + viewGuide: 'Посмотреть руководство', + }, + author: 'Автор', + auth: { + unauthorized: 'Авторизовать', + authorized: 'Авторизовано', + setup: 'Настроить авторизацию для использования', + setupModalTitle: 'Настроить авторизацию', + setupModalTitleDescription: 'После настройки учетных данных все участники рабочего пространства смогут использовать этот инструмент при оркестровке приложений.', + }, + includeToolNum: 'Включено {{num}} инструментов', + addTool: 'Добавить инструмент', + addToolModal: { + type: 'тип', + category: 'категория', + add: 'добавить', + added: 'добавлено', + manageInTools: 'Управлять в инструментах', + emptyTitle: 'Нет доступных инструментов рабочего процесса', + emptyTip: 'Перейдите в "Рабочий процесс -> Опубликовать как инструмент"', + }, + createTool: { + title: 'Создать пользовательский инструмент', + editAction: 'Настроить', + editTitle: 'Редактировать пользовательский инструмент', + name: 'Название', + toolNamePlaceHolder: 'Введите название инструмента', + nameForToolCall: 'Название вызова инструмента', + nameForToolCallPlaceHolder: 'Используется для машинного распознавания, например getCurrentWeather, list_pets', + nameForToolCallTip: 'Поддерживаются только цифры, буквы и подчеркивания.', + description: 'Описание', + descriptionPlaceholder: 'Краткое описание назначения инструмента, например, получить температуру для определенного местоположения.', + schema: 'Схема', + schemaPlaceHolder: 'Введите свою схему OpenAPI здесь', + viewSchemaSpec: 'Посмотреть спецификацию OpenAPI-Swagger', + importFromUrl: 'Импортировать из URL', + importFromUrlPlaceHolder: 'https://...', + urlError: 'Пожалуйста, введите действительный URL', + examples: 'Примеры', + exampleOptions: { + json: 'Погода (JSON)', + yaml: 'Зоомагазин (YAML)', + blankTemplate: 'Пустой шаблон', + }, + availableTools: { + title: 'Доступные инструменты', + name: 'Название', + description: 'Описание', + method: 'Метод', + path: 'Путь', + action: 'Действия', + test: 'Тест', + }, + authMethod: { + title: 'Метод авторизации', + type: 'Тип авторизации', + keyTooltip: 'Ключ заголовка HTTP, вы можете оставить его как "Authorization", если не знаете, что это такое, или установить его на пользовательское значение', + types: { + none: 'Нет', + api_key: 'Ключ API', + apiKeyPlaceholder: 'Название заголовка HTTP для ключа API', + apiValuePlaceholder: 'Введите ключ API', + }, + key: 'Ключ', + value: 'Значение', + }, + authHeaderPrefix: { + title: 'Тип авторизации', + types: { + basic: 'Базовый', + bearer: 'Bearer', + custom: 'Пользовательский', + }, + }, + privacyPolicy: 'Политика конфиденциальности', + privacyPolicyPlaceholder: 'Пожалуйста, введите политику конфиденциальности', + toolInput: { + title: 'Входные данные инструмента', + name: 'Название', + required: 'Обязательно', + method: 'Метод', + methodSetting: 'Настройка', + methodSettingTip: 'Пользователь заполняет конфигурацию инструмента', + methodParameter: 'Параметр', + methodParameterTip: 'LLM заполняет во время вывода', + label: 'Теги', + labelPlaceholder: 'Выберите теги (необязательно)', + description: 'Описание', + descriptionPlaceholder: 'Описание значения параметра', + }, + customDisclaimer: 'Пользовательский отказ от ответственности', + customDisclaimerPlaceholder: 'Пожалуйста, введите пользовательский отказ от 
ответственности', + confirmTitle: 'Подтвердить сохранение?', + confirmTip: 'Приложения, использующие этот инструмент, будут затронуты', + deleteToolConfirmTitle: 'Удалить этот инструмент?', + deleteToolConfirmContent: 'Удаление инструмента необратимо. Пользователи больше не смогут получить доступ к вашему инструменту.', + }, + test: { + title: 'Тест', + parametersValue: 'Параметры и значение', + parameters: 'Параметры', + value: 'Значение', + testResult: 'Результаты теста', + testResultPlaceholder: 'Результат теста будет отображаться здесь', + }, + thought: { + using: 'Использование', + used: 'Использовано', + requestTitle: 'Запрос к', + responseTitle: 'Ответ от', + }, + setBuiltInTools: { + info: 'Информация', + setting: 'Настройка', + toolDescription: 'Описание инструмента', + parameters: 'параметры', + string: 'строка', + number: 'число', + required: 'Обязательно', + infoAndSetting: 'Информация и настройки', + }, + noCustomTool: { + title: 'Нет пользовательских инструментов!', + content: 'Добавьте и управляйте своими пользовательскими инструментами здесь для создания приложений ИИ.', + createTool: 'Создать инструмент', + }, + noSearchRes: { + title: 'Извините, результаты не найдены!', + content: 'Мы не смогли найти никаких инструментов, соответствующих вашему поиску.', + reset: 'Сбросить поиск', + }, + builtInPromptTitle: 'Подсказка', + toolRemoved: 'Инструмент удален', + notAuthorized: 'Инструмент не авторизован', + howToGet: 'Как получить', + openInStudio: 'Открыть в Studio', + toolNameUsageTip: 'Название вызова инструмента для рассуждений агента и подсказок', +} + +export default translation diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts new file mode 100644 index 0000000000..9d3ce1235c --- /dev/null +++ b/web/i18n/ru-RU/workflow.ts @@ -0,0 +1,557 @@ +const translation = { + common: { + undo: 'Отменить', + redo: 'Повторить', + editing: 'Редактирование', + autoSaved: 'Автосохранено', + unpublished: 'Не опубликовано', + published: 'Опубликовано', + publish: 'Опубликовать', + update: 'Обновить', + run: 'Запустить', + running: 'Выполняется', + inRunMode: 'В режиме выполнения', + inPreview: 'В режиме предпросмотра', + inPreviewMode: 'В режиме предпросмотра', + preview: 'Предпросмотр', + viewRunHistory: 'Посмотреть историю запусков', + runHistory: 'История запусков', + goBackToEdit: 'Вернуться к редактору', + conversationLog: 'Журнал разговоров', + features: 'Функции', + debugAndPreview: 'Предпросмотр', + restart: 'Перезапустить', + currentDraft: 'Текущий черновик', + currentDraftUnpublished: 'Текущий черновик не опубликован', + latestPublished: 'Последняя опубликованная версия', + publishedAt: 'Опубликовано', + restore: 'Восстановить', + runApp: 'Запустить приложение', + batchRunApp: 'Пакетный запуск приложения', + accessAPIReference: 'Доступ к справочнику API', + embedIntoSite: 'Встроить на сайт', + addTitle: 'Добавить заголовок...', + addDescription: 'Добавить описание...', + noVar: 'Нет переменной', + searchVar: 'Поиск переменной', + variableNamePlaceholder: 'Имя переменной', + setVarValuePlaceholder: 'Установить значение переменной', + needConnectTip: 'Этот шаг ни к чему не подключен', + maxTreeDepth: 'Максимальный предел {{depth}} узлов на ветку', + needEndNode: 'Необходимо добавить блок "Конец"', + needAnswerNode: 'Необходимо добавить блок "Ответ"', + workflowProcess: 'Процесс рабочего процесса', + notRunning: 'Еще не запущено', + previewPlaceholder: 'Введите текст в поле ниже, чтобы начать отладку чат-бота', + effectVarConfirm: { + title: 'Удалить переменную', 
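A note on strings such as `maxTreeDepth` above: the `{{depth}}`-style slots are i18next-flavored interpolation placeholders, so translators can move the number anywhere in the sentence and the caller supplies the value at render time. A minimal sketch, under the assumption that these keys are mounted beneath a `workflow` namespace (an assumption, not shown in this diff):

```ts
import i18next from 'i18next'

// The options object fills the {{depth}} slot at call time.
i18next.t('workflow.common.maxTreeDepth', { depth: 50 })
// -> 'Максимальный предел 50 узлов на ветку'
```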
+ content: 'Переменная используется в других узлах. Вы все еще хотите удалить ее?', + }, + insertVarTip: 'Нажмите клавишу "/" чтобы быстро вставить', + processData: 'Обработка данных', + input: 'Вход', + output: 'Выход', + jinjaEditorPlaceholder: 'Введите "/" или "{" для вставки переменной', + viewOnly: 'Только просмотр', + showRunHistory: 'Показать историю запусков', + enableJinja: 'Включить поддержку шаблонов Jinja', + learnMore: 'Узнать больше', + copy: 'Копировать', + duplicate: 'Дублировать', + addBlock: 'Добавить блок', + pasteHere: 'Вставить сюда', + pointerMode: 'Режим указателя', + handMode: 'Режим руки', + model: 'Модель', + workflowAsTool: 'Рабочий процесс как инструмент', + configureRequired: 'Требуется настройка', + configure: 'Настроить', + manageInTools: 'Управление в инструментах', + workflowAsToolTip: 'После обновления рабочего процесса требуется перенастройка инструмента.', + viewDetailInTracingPanel: 'Посмотреть подробности', + syncingData: 'Синхронизация данных, всего несколько секунд.', + importDSL: 'Импортировать DSL', + importDSLTip: 'Текущий черновик будет перезаписан. Экспортируйте рабочий процесс в качестве резервной копии перед импортом.', + backupCurrentDraft: 'Резервное копирование текущего черновика', + chooseDSL: 'Выберите файл DSL(yml)', + overwriteAndImport: 'Перезаписать и импортировать', + importFailure: 'Ошибка импорта', + importSuccess: 'Импорт успешно завершен', + parallelTip: { + click: { + title: 'Щелчок', + desc: 'чтобы добавить', + }, + drag: { + title: 'Перетащить', + desc: 'для подключения', + }, + limit: 'Параллелизм ограничен {{num}} ветвями.', + depthLimit: 'Ограничение на количество слоев параллельной вложенности: {{num}}', + }, + parallelRun: 'Параллельный запуск', + disconnect: 'Отсоединить', + jumpToNode: 'Перейти к этому узлу', + addParallelNode: 'Добавить параллельный узел', + }, + env: { + envPanelTitle: 'Переменные среды', + envDescription: 'Переменные среды могут использоваться для хранения конфиденциальной информации и учетных данных. Они доступны только для чтения и могут быть отделены от файла DSL во время экспорта.', + envPanelButton: 'Добавить переменную', + modal: { + title: 'Добавить переменную среды', + editTitle: 'Редактировать переменную среды', + type: 'Тип', + name: 'Имя', + namePlaceholder: 'Имя переменной среды', + value: 'Значение', + valuePlaceholder: 'Значение переменной среды', + secretTip: 'Используется для определения конфиденциальной информации или данных, с настройками DSL, настроенными для предотвращения утечки.', + }, + export: { + title: 'Экспортировать секретные переменные среды?', + checkbox: 'Экспортировать секретные значения', + ignore: 'Экспортировать DSL', + export: 'Экспортировать DSL с секретными значениями ', + }, + }, + chatVariable: { + panelTitle: 'Переменные разговора', + panelDescription: 'Переменные разговора используются для хранения интерактивной информации, которую LLM необходимо запомнить, включая историю разговоров, загруженные файлы, пользовательские настройки. Они доступны для чтения и записи. 
', + docLink: 'Посетите нашу документацию, чтобы узнать больше.', + button: 'Добавить переменную', + modal: { + title: 'Добавить переменную разговора', + editTitle: 'Редактировать переменную разговора', + name: 'Имя', + namePlaceholder: 'Имя переменной', + type: 'Тип', + value: 'Значение по умолчанию', + valuePlaceholder: 'Значение по умолчанию, оставьте пустым, чтобы не устанавливать', + description: 'Описание', + descriptionPlaceholder: 'Опишите переменную', + editInJSON: 'Редактировать в JSON', + oneByOne: 'Добавлять по одному', + editInForm: 'Редактировать в форме', + arrayValue: 'Значение', + addArrayValue: 'Добавить значение', + objectKey: 'Ключ', + objectType: 'Тип', + objectValue: 'Значение по умолчанию', + }, + storedContent: 'Сохраненный контент', + updatedAt: 'Обновлено в ', + }, + changeHistory: { + title: 'История изменений', + placeholder: 'Вы еще ничего не изменили', + clearHistory: 'Очистить историю', + hint: 'Подсказка', + hintText: 'Ваши действия по редактированию отслеживаются в истории изменений, которая хранится на вашем устройстве в течение этого сеанса. Эта история будет очищена, когда вы покинете редактор.', + stepBackward_one: '{{count}} шаг назад', + stepBackward_other: '{{count}} шагов назад', + stepForward_one: '{{count}} шаг вперед', + stepForward_other: '{{count}} шагов вперед', + sessionStart: 'Начало сеанса', + currentState: 'Текущее состояние', + nodeTitleChange: 'Изменено название блока', + nodeDescriptionChange: 'Изменено описание блока', + nodeDragStop: 'Блок перемещен', + nodeChange: 'Блок изменен', + nodeConnect: 'Блок подключен', + nodePaste: 'Блок вставлен', + nodeDelete: 'Блок удален', + nodeAdd: 'Блок добавлен', + nodeResize: 'Размер блока изменен', + noteAdd: 'Заметка добавлена', + noteChange: 'Заметка изменена', + noteDelete: 'Заметка удалена', + edgeDelete: 'Блок отключен', + }, + errorMsg: { + fieldRequired: '{{field}} обязательно для заполнения', + authRequired: 'Требуется авторизация', + invalidJson: '{{field}} неверный JSON', + fields: { + variable: 'Имя переменной', + variableValue: 'Значение переменной', + code: 'Код', + model: 'Модель', + rerankModel: 'Модель переранжирования', + }, + invalidVariable: 'Неверная переменная', + }, + singleRun: { + testRun: 'Тестовый запуск ', + startRun: 'Начать запуск', + running: 'Выполняется', + testRunIteration: 'Итерация тестового запуска', + back: 'Назад', + iteration: 'Итерация', + }, + tabs: { + 'searchBlock': 'Поиск блока', + 'blocks': 'Блоки', + 'searchTool': 'Поиск инструмента', + 'tools': 'Инструменты', + 'allTool': 'Все', + 'builtInTool': 'Встроенные', + 'customTool': 'Пользовательские', + 'workflowTool': 'Рабочий процесс', + 'question-understand': 'Понимание вопроса', + 'logic': 'Логика', + 'transform': 'Преобразование', + 'utilities': 'Утилиты', + 'noResult': 'Ничего не найдено', + }, + blocks: { + 'start': 'Начало', + 'end': 'Конец', + 'answer': 'Ответ', + 'llm': 'LLM', + 'knowledge-retrieval': 'Поиск знаний', + 'question-classifier': 'Классификатор вопросов', + 'if-else': 'ЕСЛИ/ИНАЧЕ', + 'code': 'Код', + 'template-transform': 'Шаблон', + 'http-request': 'HTTP-запрос', + 'variable-assigner': 'Агрегатор переменных', + 'variable-aggregator': 'Агрегатор переменных', + 'assigner': 'Назначение переменной', + 'iteration-start': 'Начало итерации', + 'iteration': 'Итерация', + 'parameter-extractor': 'Извлечение параметров', + }, + blocksAbout: { + 'start': 'Определите начальные параметры для запуска рабочего процесса', + 'end': 'Определите конец и тип результата рабочего процесса', + 'answer': 
'Определите содержимое ответа в чате', + 'llm': 'Вызов больших языковых моделей для ответа на вопросы или обработки естественного языка', + 'knowledge-retrieval': 'Позволяет запрашивать текстовый контент, связанный с вопросами пользователей, из базы знаний', + 'question-classifier': 'Определите условия классификации вопросов пользователей, LLM может определить, как будет развиваться разговор на основе описания классификации', + 'if-else': 'Позволяет разделить рабочий процесс на две ветки на основе условий if/else', + 'code': 'Выполните фрагмент кода Python или NodeJS для реализации пользовательской логики', + 'template-transform': 'Преобразование данных в строку с использованием синтаксиса шаблонов Jinja', + 'http-request': 'Разрешить отправку запросов на сервер по протоколу HTTP', + 'variable-assigner': 'Объединение переменных из нескольких ветвей в одну переменную для унифицированной настройки подчиненных узлов.', + 'assigner': 'Узел назначения переменной используется для назначения значений записываемым переменным (например, переменным разговора).', + 'variable-aggregator': 'Объединение переменных из нескольких ветвей в одну переменную для унифицированной настройки подчиненных узлов.', + 'iteration': 'Выполнение нескольких шагов над объектом списка до тех пор, пока не будут выведены все результаты.', + 'parameter-extractor': 'Используйте LLM для извлечения структурированных параметров из естественного языка для вызова инструментов или HTTP-запросов.', + }, + operator: { + zoomIn: 'Увеличить', + zoomOut: 'Уменьшить', + zoomTo50: 'Масштаб 50%', + zoomTo100: 'Масштаб 100%', + zoomToFit: 'По размеру', + }, + panel: { + userInputField: 'Поле ввода пользователя', + changeBlock: 'Изменить блок', + helpLink: 'Ссылка на справку', + about: 'О программе', + createdBy: 'Создано ', + nextStep: 'Следующий шаг', + addNextStep: 'Добавить следующий блок в этот рабочий процесс', + selectNextStep: 'Выбрать следующий блок', + runThisStep: 'Выполнить этот шаг', + checklist: 'Контрольный список', + checklistTip: 'Убедитесь, что все проблемы решены перед публикацией', + checklistResolved: 'Все проблемы решены', + organizeBlocks: 'Организовать блоки', + change: 'Изменить', + optional: '(необязательно)', + }, + nodes: { + common: { + outputVars: 'Выходные переменные', + insertVarTip: 'Вставить переменную', + memory: { + memory: 'Память', + memoryTip: 'Настройки памяти чата', + windowSize: 'Размер окна', + conversationRoleName: 'Имя роли разговора', + user: 'Префикс пользователя', + assistant: 'Префикс помощника', + }, + memories: { + title: 'Воспоминания', + tip: 'Память чата', + builtIn: 'Встроенные', + }, + }, + start: { + required: 'обязательно', + inputField: 'Поле ввода', + builtInVar: 'Встроенные переменные', + outputVars: { + query: 'Ввод пользователя', + memories: { + des: 'История разговоров', + type: 'тип сообщения', + content: 'содержимое сообщения', + }, + files: 'Список файлов', + }, + noVarTip: 'Установите входные данные, которые можно использовать в рабочем процессе', + }, + end: { + outputs: 'Выходы', + output: { + type: 'тип вывода', + variable: 'выходная переменная', + }, + type: { + 'none': 'Нет', + 'plain-text': 'Простой текст', + 'structured': 'Структурированный', + }, + }, + answer: { + answer: 'Ответ', + outputVars: 'Выходные переменные', + }, + llm: { + model: 'модель', + variables: 'переменные', + context: 'контекст', + contextTooltip: 'Вы можете импортировать знания как контекст', + notSetContextInPromptTip: 'Чтобы включить функцию контекста, пожалуйста, заполните переменную контекста 
в PROMPT.', + prompt: 'подсказка', + roleDescription: { + system: 'Дайте высокоуровневые инструкции для разговора', + user: 'Предоставьте инструкции, запросы или любой текстовый ввод для модели', + assistant: 'Ответы модели на основе сообщений пользователя', + }, + addMessage: 'Добавить сообщение', + vision: 'зрение', + files: 'Файлы', + resolution: { + name: 'Разрешение', + high: 'Высокое', + low: 'Низкое', + }, + outputVars: { + output: 'Создать контент', + usage: 'Информация об использовании модели', + }, + singleRun: { + variable: 'Переменная', + }, + sysQueryInUser: 'sys.query в сообщении пользователя обязателен', + }, + knowledgeRetrieval: { + queryVariable: 'Переменная запроса', + knowledge: 'Знания', + outputVars: { + output: 'Извлеченные сегментированные данные', + content: 'Сегментированный контент', + title: 'Сегментированный заголовок', + icon: 'Сегментированный значок', + url: 'Сегментированный URL', + metadata: 'Другие метаданные', + }, + }, + http: { + inputVars: 'Входные переменные', + api: 'API', + apiPlaceholder: 'Введите URL, введите "/" для вставки переменной', + notStartWithHttp: 'API должен начинаться с http:// или https://', + key: 'Ключ', + value: 'Значение', + bulkEdit: 'Массовое редактирование', + keyValueEdit: 'Редактирование ключа-значения', + headers: 'Заголовки', + params: 'Параметры', + body: 'Тело', + outputVars: { + body: 'Содержимое ответа', + statusCode: 'Код состояния ответа', + headers: 'Список заголовков ответа JSON', + files: 'Список файлов', + }, + authorization: { + 'authorization': 'Авторизация', + 'authorizationType': 'Тип авторизации', + 'no-auth': 'Нет', + 'api-key': 'API-ключ', + 'auth-type': 'Тип аутентификации', + 'basic': 'Базовая', + 'bearer': 'Bearer', + 'custom': 'Пользовательская', + 'api-key-title': 'API-ключ', + 'header': 'Заголовок', + }, + insertVarPlaceholder: 'введите "/" для вставки переменной', + timeout: { + title: 'Тайм-аут', + connectLabel: 'Тайм-аут подключения', + connectPlaceholder: 'Введите тайм-аут подключения в секундах', + readLabel: 'Тайм-аут чтения', + readPlaceholder: 'Введите тайм-аут чтения в секундах', + writeLabel: 'Тайм-аут записи', + writePlaceholder: 'Введите тайм-аут записи в секундах', + }, + }, + code: { + inputVars: 'Входные переменные', + outputVars: 'Выходные переменные', + advancedDependencies: 'Расширенные зависимости', + advancedDependenciesTip: 'Добавьте сюда некоторые предварительно загруженные зависимости, которые занимают больше времени для потребления или не являются встроенными по умолчанию', + searchDependencies: 'Поиск зависимостей', + }, + templateTransform: { + inputVars: 'Входные переменные', + code: 'Код', + codeSupportTip: 'Поддерживает только Jinja2', + outputVars: { + output: 'Преобразованный контент', + }, + }, + ifElse: { + if: 'Если', + else: 'Иначе', + elseDescription: 'Используется для определения логики, которая должна быть выполнена, когда условие if не выполняется.', + and: 'и', + or: 'или', + operator: 'Оператор', + notSetVariable: 'Пожалуйста, сначала установите переменную', + comparisonOperator: { + 'contains': 'содержит', + 'not contains': 'не содержит', + 'start with': 'начинается с', + 'end with': 'заканчивается на', + 'is': 'равно', + 'is not': 'не равно', + 'empty': 'пусто', + 'not empty': 'не пусто', + 'null': 'null', + 'not null': 'не null', + 'regex match': 'Совпадение с регулярным выражением', + }, + enterValue: 'Введите значение', + addCondition: 'Добавить условие', + conditionNotSetup: 'Условие НЕ настроено', + selectVariable: 'Выберите переменную...', + }, + 
variableAssigner: { + title: 'Назначить переменные', + outputType: 'Тип вывода', + varNotSet: 'Переменная не установлена', + noVarTip: 'Добавьте переменные, которые нужно назначить', + type: { + string: 'Строка', + number: 'Число', + object: 'Объект', + array: 'Массив', + }, + aggregationGroup: 'Группа агрегации', + aggregationGroupTip: 'Включение этой функции позволяет агрегатору переменных агрегировать несколько наборов переменных.', + addGroup: 'Добавить группу', + outputVars: { + varDescribe: 'Вывод {{groupName}}', + }, + setAssignVariable: 'Установить переменную назначения', + }, + assigner: { + 'assignedVariable': 'Назначенная переменная', + 'writeMode': 'Режим записи', + 'writeModeTip': 'Режим добавления: доступен только для переменных массива.', + 'over-write': 'Перезаписать', + 'append': 'Добавить', + 'plus': 'Плюс', + 'clear': 'Очистить', + 'setVariable': 'Установить переменную', + 'variable': 'Переменная', + }, + tool: { + toAuthorize: 'Авторизовать', + inputVars: 'Входные переменные', + outputVars: { + text: 'контент, сгенерированный инструментом', + files: { + title: 'файлы, сгенерированные инструментом', + type: 'Поддерживаемый тип. Сейчас поддерживаются только изображения', + transfer_method: 'Метод передачи. Значение - remote_url или local_file', + url: 'URL изображения', + upload_file_id: 'Идентификатор загруженного файла', + }, + json: 'json, сгенерированный инструментом', + }, + }, + questionClassifiers: { + model: 'модель', + inputVars: 'Входные переменные', + outputVars: { + className: 'Имя класса', + }, + class: 'Класс', + classNamePlaceholder: 'Введите имя вашего класса', + advancedSetting: 'Расширенные настройки', + topicName: 'Название темы', + topicPlaceholder: 'Введите название вашей темы', + addClass: 'Добавить класс', + instruction: 'Инструкция', + instructionTip: 'Введите дополнительные инструкции, чтобы помочь классификатору вопросов лучше понять, как классифицировать вопросы.', + instructionPlaceholder: 'Введите вашу инструкцию', + }, + parameterExtractor: { + inputVar: 'Входная переменная', + extractParameters: 'Извлечь параметры', + importFromTool: 'Импортировать из инструментов', + addExtractParameter: 'Добавить параметр для извлечения', + addExtractParameterContent: { + name: 'Имя', + namePlaceholder: 'Имя извлекаемого параметра', + type: 'Тип', + typePlaceholder: 'Тип извлекаемого параметра', + description: 'Описание', + descriptionPlaceholder: 'Описание извлекаемого параметра', + required: 'Обязательный', + requiredContent: 'Обязательный используется только в качестве ссылки для вывода модели, а не для обязательной проверки вывода параметра.', + }, + extractParametersNotSet: 'Параметры для извлечения не настроены', + instruction: 'Инструкция', + instructionTip: 'Введите дополнительные инструкции, чтобы помочь извлекателю параметров понять, как извлекать параметры.', + advancedSetting: 'Расширенные настройки', + reasoningMode: 'Режим рассуждения', + reasoningModeTip: 'Вы можете выбрать соответствующий режим рассуждения, основываясь на способности модели реагировать на инструкции для вызова функций или подсказки.', + isSuccess: 'Успешно. 
В случае успеха значение равно 1, в случае сбоя - 0.', + errorReason: 'Причина ошибки', + }, + iteration: { + deleteTitle: 'Удалить узел итерации?', + deleteDesc: 'Удаление узла итерации приведет к удалению всех дочерних узлов', + input: 'Вход', + output: 'Выходные переменные', + iteration_one: '{{count}} Итерация', + iteration_other: '{{count}} Итераций', + currentIteration: 'Текущая итерация', + }, + note: { + addNote: 'Добавить заметку', + editor: { + placeholder: 'Напишите свою заметку...', + small: 'Маленький', + medium: 'Средний', + large: 'Большой', + bold: 'Жирный', + italic: 'Курсив', + strikethrough: 'Зачеркнутый', + link: 'Ссылка', + openLink: 'Открыть', + unlink: 'Удалить ссылку', + enterUrl: 'Введите URL...', + invalidUrl: 'Неверный URL', + bulletList: 'Маркированный список', + showAuthor: 'Показать автора', + }, + }, + }, + tracing: { + stopBy: 'Остановлено {{user}}', + }, +} + +export default translation diff --git a/web/i18n/tr-TR/app-api.ts b/web/i18n/tr-TR/app-api.ts index be6466f001..9a64de546b 100644 --- a/web/i18n/tr-TR/app-api.ts +++ b/web/i18n/tr-TR/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: 'Duraklat', playing: 'Oynatılıyor', loading: 'Yükleniyor', - merMaind: { + merMaid: { rerender: 'Yeniden İşleme', }, never: 'Asla', diff --git a/web/i18n/tr-TR/app-debug.ts b/web/i18n/tr-TR/app-debug.ts index fbf51535fe..f08d221d45 100644 --- a/web/i18n/tr-TR/app-debug.ts +++ b/web/i18n/tr-TR/app-debug.ts @@ -301,7 +301,7 @@ const translation = { historyNoBeEmpty: 'Konuşma geçmişi prompt\'ta ayarlanmalıdır', queryNoBeEmpty: 'Sorgu prompt\'ta ayarlanmalıdır', }, - variableConig: { + variableConfig: { addModalTitle: 'Giriş Alanı Ekle', editModalTitle: 'Giriş Alanı Düzenle', description: 'Değişken ayarı {{varName}}', diff --git a/web/i18n/tr-TR/app-overview.ts b/web/i18n/tr-TR/app-overview.ts index 77a54dc4b3..721bac0000 100644 --- a/web/i18n/tr-TR/app-overview.ts +++ b/web/i18n/tr-TR/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Workflow Adımları', show: 'Göster', hide: 'Gizle', + showDesc: 'WebApp\'te iş akışı ayrıntılarını gösterme veya gizleme', + subTitle: 'İş Akışı Detayları', }, chatColorTheme: 'Sohbet renk teması', chatColorThemeDesc: 'Sohbet botunun renk temasını ayarlayın', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Özel ifşa metnini girin', customDisclaimerTip: 'Özel ifşa metni istemci tarafında görüntülenecek ve uygulama hakkında ek bilgiler sağlayacak', }, + sso: { + title: 'WebApp SSO\'su', + tooltip: 'WebApp SSO\'yu etkinleştirmek için yöneticiyle iletişime geçin', + label: 'SSO Kimlik Doğrulaması', + description: 'Tüm kullanıcıların WebApp\'i kullanmadan önce SSO ile oturum açmaları gerekir', + }, }, embedded: { entry: 'Gömülü', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Token/s', totalMessages: { title: 'Toplam Mesajlar', - explanation: 'Günlük AI etkileşim sayısı; prompt mühendisliği/hata ayıklama hariç.', + explanation: 'Günlük AI etkileşimi sayısı.', + }, + totalConversations: { + title: 'Toplam Konuşmalar', + explanation: 'Günlük AI konuşmaları sayısı; prompt mühendisliği/hata ayıklama hariç.', }, activeUsers: { title: 'Aktif Kullanıcılar', diff --git a/web/i18n/tr-TR/app.ts b/web/i18n/tr-TR/app.ts index fb1ac36762..09cb680f50 100644 --- a/web/i18n/tr-TR/app.ts +++ b/web/i18n/tr-TR/app.ts @@ -122,6 +122,12 @@ const translation = { removeConfirmTitle: '{{key}} yapılandırmasını kaldır?', removeConfirmContent: 'Mevcut yapılandırma kullanımda, kaldırılması İzleme özelliğini kapatacaktır.', }, + 
view: 'Görünüm', + }, + answerIcon: { + descriptionInExplore: 'Keşfet\'te 🤖 yerine WebApp simgesinin kullanılıp kullanılmayacağı', + title: 'WebApp simgesini 🤖 yerine kullan', + description: 'Paylaşılan uygulamada 🤖 yerine WebApp simgesinin kullanılıp kullanılmayacağı', + }, } diff --git a/web/i18n/tr-TR/common.ts index a194ffd769..a41925cd20 100644 --- a/web/i18n/tr-TR/common.ts +++ b/web/i18n/tr-TR/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Parametreler', duplicate: 'Çoğalt', rename: 'Yeniden Adlandır', + audioSourceUnavailable: 'AudioSource kullanılamıyor', }, errorMsg: { fieldRequired: '{{field}} gereklidir', @@ -132,7 +133,8 @@ const translation = { workspace: 'Çalışma Alanı', createWorkspace: 'Çalışma Alanı Oluştur', helpCenter: 'Yardım', - roadmapAndFeedback: 'Geri Bildirim', + communityFeedback: 'Geri Bildirim', + roadmap: 'Yol haritası', community: 'Topluluk', about: 'Hakkında', logout: 'Çıkış Yap', @@ -198,7 +200,7 @@ const translation = { invitationSent: 'Davet gönderildi', invitationSentTip: 'Davet gönderildi, Dify\'ye giriş yaparak takım verilerinize erişebilirler.', invitationLink: 'Davet Linki', - failedinvitationEmails: 'Aşağıdaki kullanıcılar başarıyla davet edilmedi', + failedInvitationEmails: 'Aşağıdaki kullanıcılar başarıyla davet edilmedi', ok: 'Tamam', removeFromTeam: 'Takımdan Kaldır', removeFromTeamTip: 'Takım erişimi kaldırılacak', @@ -206,7 +208,7 @@ const translation = { setMember: 'Normal üye olarak ayarla', setBuilder: 'Oluşturucu olarak ayarla', setEditor: 'Editör olarak ayarla', - disinvite: 'Davetiyeyi iptal et', + disInvite: 'Davetiyeyi iptal et', deleteMember: 'Üyeyi Sil', you: '(Siz)', }, diff --git a/web/i18n/tr-TR/dataset-creation.ts index b35cbc26b2..b26608c39f 100644 --- a/web/i18n/tr-TR/dataset-creation.ts +++ b/web/i18n/tr-TR/dataset-creation.ts @@ -50,7 +50,7 @@ const translation = { input: 'Bilgi adı', placeholder: 'Lütfen girin', nameNotEmpty: 'Ad boş olamaz', - nameLengthInvaild: 'Ad 1 ile 40 karakter arasında olmalıdır', + nameLengthInvalid: 'Ad 1 ile 40 karakter arasında olmalıdır', cancelButton: 'İptal', confirmButton: 'Oluştur', failed: 'Oluşturma başarısız', @@ -109,8 +109,8 @@ const translation = { QATitle: 'Soru ve Yanıt formatında parçalama', QATip: 'Bu seçeneği etkinleştirmek daha fazla token tüketecektir', QALanguage: 'Kullanarak parçalara ayır', - emstimateCost: 'Tahmin', - emstimateSegment: 'Tahmini parçalar', + estimateCost: 'Tahmin', + estimateSegment: 'Tahmini parçalar', segmentCount: 'parçalar', calculating: 'Hesaplanıyor...', fileSource: 'Belgeleri ön işleme', @@ -135,8 +135,8 @@ const translation = { previewSwitchTipStart: 'Geçerli parça önizlemesi metin formatındadır, soru ve yanıt formatına geçiş ek tüketir', previewSwitchTipEnd: 'token', characters: 'karakterler', - indexSettedTip: 'Dizin yöntemini değiştirmek için, lütfen', - retrivalSettedTip: 'Dizin yöntemini değiştirmek için, lütfen', + indexSettingTip: 'Dizin yöntemini değiştirmek için, lütfen', + retrievalSettingTip: 'Dizin yöntemini değiştirmek için, lütfen', datasetSettingLink: 'Bilgi ayarlarına gidin.', }, stepThree: { diff --git a/web/i18n/tr-TR/dataset.ts index 31d483f504..5e55e071c5 100644 --- a/web/i18n/tr-TR/dataset.ts +++ b/web/i18n/tr-TR/dataset.ts @@ -53,6 +53,7 @@ const translation = { semantic_search: 'VEKTÖR', full_text_search: 'TAM METİN', hybrid_search: 'HİBRİT', + invertedIndex: 'TERS', 
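+  // The typo-fix key renames in this PR (emstimateCost to estimateCost, indexSettedTip to indexSettingTip,
+  // donthave to dontHave, and so on) must ship together with their t() call sites: a lookup on a stale key
+  // finds no match and i18next falls back to rendering the key path itself. A sketch with hypothetical call sites:
+  //   t('datasetCreation.stepTwo.estimateCost')   // resolves after the rename
+  //   t('datasetCreation.stepTwo.emstimateCost')  // no longer matches; renders the raw key path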
mixtureHighQualityAndEconomicTip: 'Yüksek kaliteli ve ekonomik bilgi tabanlarının karışımı için Yeniden Sıralama modeli gereklidir.', inconsistentEmbeddingModelTip: 'Seçilen bilgi tabanlarının Yerleştirme modelleri tutarsızsa Yeniden Sıralama modeli gereklidir.', @@ -70,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'Geri alım stratejisinin optimizasyonu ve yükseltilmesi nedeniyle, N-to-1 geri alımı Eylül ayında resmi olarak kullanım dışı kalacaktır. O zamana kadar normal şekilde kullanabilirsiniz.', nTo1RetrievalLegacyLink: 'Daha fazla bilgi edin', nTo1RetrievalLegacyLinkText: 'N-1 geri alma Eylül ayında resmi olarak kullanımdan kaldırılacaktır.', + defaultRetrievalTip: 'Varsayılan olarak çoklu yol alma kullanılır. Bilgi, birden fazla bilgi tabanından alınır ve ardından yeniden sıralanır.', } export default translation diff --git a/web/i18n/tr-TR/login.ts index 617b58be36..8f0a9eff89 100644 --- a/web/i18n/tr-TR/login.ts +++ b/web/i18n/tr-TR/login.ts @@ -32,7 +32,7 @@ const translation = { pp: 'Gizlilik Politikası', tosDesc: 'Kaydolarak, Hizmet Şartlarımızı kabul etmiş olursunuz', goToInit: 'Hesabı başlatmadıysanız, lütfen başlatma sayfasına gidin', - donthave: 'Sahip değil misiniz?', + dontHave: 'Sahip değil misiniz?', invalidInvitationCode: 'Geçersiz davet kodu', accountAlreadyInited: 'Hesap zaten başlatılmış', forgotPassword: 'Şifrenizi mi unuttunuz?', diff --git a/web/i18n/tr-TR/share-app.ts index 4fe58a8b2b..26c6f56fb4 100644 --- a/web/i18n/tr-TR/share-app.ts +++ b/web/i18n/tr-TR/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'Uygulama kullanılamıyor', - appUnkonwError: 'Uygulama kullanılamıyor', + appUnknownError: 'Uygulama kullanılamıyor', }, chat: { newChat: 'Yeni sohbet', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Sohbetler', newChatDefaultName: 'Yeni konuşma', resetChat: 'Konuşmayı sıfırla', - powerBy: 'Tarafından desteklenmektedir', + poweredBy: 'Tarafından desteklenmektedir', prompt: 'Prompt', privatePromptConfigTitle: 'Konuşma ayarları', publicPromptConfigTitle: 'Başlangıç Promptu', diff --git a/web/i18n/tr-TR/workflow.ts index dd5fe17c37..96313d6d6b 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Değişkeni ara', variableNamePlaceholder: 'Değişken adı', setVarValuePlaceholder: 'Değişkeni ayarla', - needConnecttip: 'Bu adım hiçbir şeye bağlı değil', + needConnectTip: 'Bu adım hiçbir şeye bağlı değil', maxTreeDepth: 'Her dal için maksimum {{depth}} düğüm limiti', needEndNode: 'Son blok eklenmelidir', needAnswerNode: 'Yanıt bloğu eklenmelidir', @@ -77,6 +77,22 @@ const translation = { overwriteAndImport: 'Üzerine Yaz ve İçe Aktar', importFailure: 'İçe Aktarma Başarısız', importSuccess: 'İçe Aktarma Başarılı', + parallelTip: { + click: { + desc: 'Eklemek için', + title: 'Tık', + }, + drag: { + title: 'Sürükleme', + desc: 'Bağlanmak için', + }, + depthLimit: 'Paralel iç içe geçme sınırı: {{num}} katman', + limit: 'Paralellik {{num}} dallarıyla sınırlıdır.', + }, + jumpToNode: 'Bu düğüme atla', + addParallelNode: 'Paralel Düğüm Ekle', + disconnect: 'Bağlantıyı Kes', + parallelRun: 'Paralel Çalıştırma', }, env: { envPanelTitle: 'Çevre Değişkenleri', @@ -186,6 +202,7 @@ const translation = { 'transform': 'Dönüştür', 'utilities': 'Yardımcı Araçlar', 'noResult': 'Eşleşen bulunamadı', + 'searchTool': 'Araç ara', }, blocks: { 'start': 'Başlat', @@ -365,6 +382,7 @@ const 
translation = { 'custom': 'Özel', 'api-key-title': 'API Anahtarı', 'header': 'Başlık', + 'auth-type': 'Kimlik Doğrulama Türü', }, insertVarPlaceholder: 'değişkeni eklemek için \'/\' yazın', timeout: { @@ -411,13 +429,13 @@ const translation = { 'not empty': 'boş değil', 'null': 'null', 'not null': 'null değil', + 'regex match': 'normal ifade eşleşmesi', enterValue: 'Değer girin', addCondition: 'Koşul Ekle', conditionNotSetup: 'Koşul AYARLANMADI', selectVariable: 'Değişken seçin...', }, - variableAssigner: { title: 'Değişken ata', outputType: 'Çıktı Türü', diff --git a/web/i18n/uk-UA/app-api.ts index 62385e43e7..6f46e9a9b4 100644 --- a/web/i18n/uk-UA/app-api.ts +++ b/web/i18n/uk-UA/app-api.ts @@ -9,7 +9,7 @@ const translation = { play: 'Відтворити', pause: 'Пауза', playing: 'Відтворення', - merMaind: { + merMaid: { rerender: 'Повторити рендер', }, never: 'Ніколи', @@ -61,6 +61,23 @@ const translation = { pathParams: 'Параметри шляху', query: 'Запит', }, + completionMode: { + messageIDTip: 'Ідентифікатор повідомлення', + streaming: 'Потокове передавання повертається. Реалізація повернення потокового мовлення на основі SSE (Server-Sent Events).', + blocking: 'Тип блокування, очікування завершення виконання та повернення результатів. (Запити можуть бути перервані, якщо процес тривалий)', + title: 'API програми для завершення', + ratingTip: 'Подобається чи не подобається, null – це скасувати', + createCompletionApiTip: 'Створіть повідомлення про завершення, щоб підтримувати режим запитань і відповідей.', + parametersApi: 'Отримання інформації про параметри програми', + queryTips: 'Текстовий контент, що вводиться користувачем.', + createCompletionApi: 'Створити повідомлення про завершення', + messageFeedbackApi: 'Відгук у повідомленні (подобається)', + messageFeedbackApiTip: 'Оцінюйте отримані повідомлення від імені кінцевих користувачів з лайками або дизлайками. Ці дані відображаються на сторінці «Журнали та анотації» та використовуються для доопрацювання майбутньої моделі.', + info: 'Для створення високоякісного тексту, такого як статті, резюме та переклади, використовуйте API повідомлень про завершення з введенням користувачем. Генерація тексту залежить від параметрів моделі та шаблонів підказок, встановлених у Dify Prompt Engineering.', + inputsTips: '(Необов\'язково.) Надайте поля введення користувача у вигляді пар ключ-значення, що відповідають змінним у Prompt Eng. Key — це ім\'я змінної, Value — значення параметра. Якщо вибрано тип поля Вибір, надіслане значення має бути одним із попередньо встановлених варіантів.', + parametersApiTip: 'Отримання налаштованих вхідних параметрів, включаючи імена змінних, імена полів, типи та значення за замовчуванням. 
Зазвичай використовується для відображення цих полів у формі або заповнення значень за замовчуванням після завантаження клієнта.', + }, + loading: 'Завантаження', } export default translation diff --git a/web/i18n/uk-UA/app-debug.ts b/web/i18n/uk-UA/app-debug.ts index 7c0ba45b3c..1fc6981122 100644 --- a/web/i18n/uk-UA/app-debug.ts +++ b/web/i18n/uk-UA/app-debug.ts @@ -259,7 +259,7 @@ const translation = { historyNoBeEmpty: 'Історію розмови необхідно встановити у підказці', // Conversation history must be set in the prompt queryNoBeEmpty: 'Запит має бути встановлений у підказці', // Query must be set in the prompt }, - variableConig: { + variableConfig: { 'addModalTitle': 'Додати Поле Введення', 'editModalTitle': 'Редагувати Поле Введення', 'description': 'Налаштування для змінної {{varName}}', diff --git a/web/i18n/uk-UA/app-overview.ts b/web/i18n/uk-UA/app-overview.ts index 8bd1f0fb39..11c88c5699 100644 --- a/web/i18n/uk-UA/app-overview.ts +++ b/web/i18n/uk-UA/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Кроки робочого процесу', show: 'Показати', hide: 'Приховати', + subTitle: 'Деталі робочого процесу', + showDesc: 'Відображення або приховування відомостей про робочий процес у веб-програмі', }, chatColorTheme: 'Тема кольору чату', chatColorThemeDesc: 'Встановіть тему кольору чат-бота', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Введіть відмову від відповідальності', customDisclaimerTip: 'Відображається на клієнтському боці, щоб визначити відповідальність за використання додатка', }, + sso: { + title: 'Єдиний вхід для WebApp', + description: 'Усі користувачі повинні увійти в систему за допомогою єдиного входу перед використанням WebApp', + tooltip: 'Зверніться до адміністратора, щоб увімкнути єдиний вхід WebApp', + label: 'Автентифікація за допомогою єдиного входу', + }, }, embedded: { entry: 'Вбудоване', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Токени/с', totalMessages: { title: 'Загальна кількість повідомлень', - explanation: 'Щоденна кількість взаємодій з штучним інтелектом; інженерія/налагодження запитів виключено.', + explanation: 'Кількість щоденних взаємодій з ШІ.', + }, + totalConversations: { + title: 'Загальна кількість розмов', + explanation: 'Кількість щоденних розмов з ШІ; інженерія/налагодження промптів виключено.', }, activeUsers: { title: 'Активні користувачі', diff --git a/web/i18n/uk-UA/app.ts b/web/i18n/uk-UA/app.ts index fbe9eea81e..c8fb4ca7d4 100644 --- a/web/i18n/uk-UA/app.ts +++ b/web/i18n/uk-UA/app.ts @@ -122,7 +122,17 @@ const translation = { removeConfirmTitle: 'Видалити налаштування {{key}}?', removeConfirmContent: 'Поточне налаштування використовується, його видалення вимкне функцію Відстеження.', }, + view: 'Вид', }, + answerIcon: { + title: 'Використовуйте піктограму WebApp для заміни 🤖', + description: 'Чи слід використовувати піктограму WebApp для заміни 🤖 у спільній програмі', + descriptionInExplore: 'Чи використовувати піктограму веб-програми для заміни 🤖 в Огляді', + }, + importFromDSLUrl: 'З URL', + importFromDSL: 'Імпорт з DSL', + importFromDSLUrlPlaceholder: 'Вставте посилання на DSL тут', + importFromDSLFile: 'З DSL-файлу', } export default translation diff --git a/web/i18n/uk-UA/billing.ts b/web/i18n/uk-UA/billing.ts index afc434e652..cebdb11521 100644 --- a/web/i18n/uk-UA/billing.ts +++ b/web/i18n/uk-UA/billing.ts @@ -58,6 +58,9 @@ const translation = { ragAPIRequest: 'RAG API запити', agentMode: 'Режим агента', workflow: 'Робочий процес', + bulkUpload: 'Масове завантаження документів', + 
llmLoadingBalancing: 'Балансування навантаження LLM', + llmLoadingBalancingTooltip: 'Додавайте кілька ключів API до моделей, ефективно обходячи обмеження швидкості API.', }, comingSoon: 'Скоро', member: 'Учасник', @@ -72,6 +75,8 @@ const translation = { }, ragAPIRequestTooltip: 'Відноситься до кількості викликів API, що викликають лише можливості обробки бази знань Dify.', receiptInfo: 'Лише власник команди та адміністратор команди можуть підписуватися та переглядати інформацію про виставлення рахунків', + annotationQuota: 'Квота анотацій', + documentsUploadQuota: 'Квота завантаження документів', }, plans: { sandbox: { diff --git a/web/i18n/uk-UA/common.ts b/web/i18n/uk-UA/common.ts index 33324ce0f2..cc70772be3 100644 --- a/web/i18n/uk-UA/common.ts +++ b/web/i18n/uk-UA/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Параметри', duplicate: 'дублікат', rename: 'Перейменувати', + audioSourceUnavailable: 'AudioSource недоступний', }, placeholder: { input: 'Будь ласка, введіть текст', @@ -128,7 +129,8 @@ const translation = { workspace: 'Робочий простір', createWorkspace: 'Створити робочий простір', helpCenter: 'Довідковий центр', - roadmapAndFeedback: 'відгуки', + communityFeedback: 'відгуки', + roadmap: 'Дорожня карта', community: 'Спільнота', about: 'Про нас', logout: 'Вийти', @@ -190,16 +192,21 @@ const translation = { invitationSent: 'Запрошення надіслано', invitationSentTip: 'Запрошення надіслано, і вони можуть увійти в Dify, щоб отримати доступ до даних вашої команди.', invitationLink: 'Посилання на запрошення', - failedinvitationEmails: 'Наступних користувачів не було успішно запрошено', + failedInvitationEmails: 'Наступних користувачів не було успішно запрошено', ok: 'ОК', removeFromTeam: 'Видалити з команди', removeFromTeamTip: 'Буде видалено доступ до команди', setAdmin: 'Призначити адміністратором', setMember: 'Встановити як звичайного члена', setEditor: 'Встановити як Редактор', - disinvite: 'Скасувати запрошення', + disInvite: 'Скасувати запрошення', deleteMember: 'Видалити учасника', you: '(Ви)', + builder: 'Будівник', + datasetOperatorTip: 'Тільки може управляти базою знань', + datasetOperator: 'Адміністратор знань', + setBuilder: 'Встановити як будівельник', + builderTip: 'Може створювати та редагувати власні програми', }, integrations: { connected: 'Підключено', @@ -344,8 +351,25 @@ const translation = { deprecated: 'Застарілий', confirmDelete: 'підтвердити видалення?', quotaTip: 'Залишилося доступних безкоштовних токенів', - loadPresets: 'Завантажити', // If need adjustment, provide more context on 'Load Presets' function + // If need adjustment, provide more context on 'Load Presets' function + loadPresets: 'Завантажити', parameters: 'ПАРАМЕТРИ', + apiKeyStatusNormal: 'Статус APIKey нормальний', + loadBalancing: 'Балансування навантаження', + editConfig: 'Редагувати конфігурацію', + loadBalancingHeadline: 'Балансування навантаження', + apiKey: 'API-КЛЮЧ', + defaultConfig: 'Конфігурація за замовчуванням', + providerManaged: 'Під управлінням провайдера', + loadBalancingDescription: 'Зменшіть тиск за допомогою кількох наборів облікових даних.', + modelHasBeenDeprecated: 'Ця модель вважається застарілою', + addConfig: 'Додати конфігурацію', + configLoadBalancing: 'Балансування навантаження конфігурації', + upgradeForLoadBalancing: 'Оновіть свій план, щоб увімкнути балансування навантаження.', + apiKeyRateLimit: 'Було досягнуто ліміту швидкості, доступного після {{seconds}}', + providerManagedDescription: 'Використовуйте єдиний набір облікових даних, наданий 
постачальником моделі.', + loadBalancingLeastKeyWarning: 'Щоб увімкнути балансування навантаження, має бути ввімкнено щонайменше 2 ключі.', + loadBalancingInfo: 'За замовчуванням для балансування навантаження використовується стратегія кругової системи. Якщо спрацьовує обмеження швидкості, буде застосовано період перезарядки тривалістю 1 хвилина.', }, dataSource: { add: 'Додати джерело даних', @@ -369,6 +393,15 @@ const translation = { preview: 'ПЕРЕДПЕРЕГЛЯД', }, }, + website: { + with: 'З', + active: 'Активний', + inactive: 'Неактивні', + configuredCrawlers: 'Налаштовані веб-сканери', + title: 'Веб-сторінка', + description: 'Імпортуйте вміст із веб-сайтів за допомогою веб-сканера.', + }, + configure: 'Настроїти', }, plugin: { serpapi: { @@ -537,6 +570,10 @@ const translation = { created: 'Тег створено успішно', failed: 'Не вдалося створити тег', }, + errorMsg: { + fieldRequired: '{{field}} є обов\'язковим', + urlError: 'URL-адреса повинна починатися з http:// або https://', + }, } export default translation diff --git a/web/i18n/uk-UA/dataset-creation.ts index 6c0099a771..e4a38f41f4 100644 --- a/web/i18n/uk-UA/dataset-creation.ts +++ b/web/i18n/uk-UA/dataset-creation.ts @@ -45,11 +45,35 @@ const translation = { input: 'Назва Знань', placeholder: 'Введіть, будь ласка', nameNotEmpty: 'Ім’я не може бути порожнім', - nameLengthInvaild: 'Ім’я має бути від 1 до 40 символів', + nameLengthInvalid: 'Ім’я має бути від 1 до 40 символів', cancelButton: 'Скасувати', confirmButton: 'Створити', failed: 'Створення не вдалося', }, + website: { + totalPageScraped: 'Всього просканованих сторінок:', + run: 'Запустити', + configure: 'Настроїти', + limit: 'Межа', + selectAll: 'Вибрати все', + unknownError: 'Невідома помилка', + maxDepth: 'Максимальна глибина', + crawlSubPage: 'Сканування підсторінок', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + preview: 'Попередній перегляд', + fireCrawlNotConfigured: 'Firecrawl не налаштовано', + includeOnlyPaths: 'Включати лише шляхи', + options: 'Параметри', + resetAll: 'Скинути все', + excludePaths: 'Виключити шляхи', + firecrawlDoc: 'Документація Firecrawl', + exceptionErrorTitle: 'Виняток стався під час виконання завдання Firecrawl:', + firecrawlTitle: 'Видобування веб-вмісту за допомогою 🔥Firecrawl', + scrapTimeInfo: 'Зібрано {{total}} сторінок загалом протягом {{time}} с', + fireCrawlNotConfiguredDescription: 'Налаштуйте Firecrawl за допомогою ключа API, щоб використовувати його.', + extractOnlyMainContent: 'Витягуйте лише основний контент (без заголовків, навігаторів, нижніх колонтитулів тощо)', + maxDepthTooltip: 'Максимальна глибина для сканування щодо введеної URL-адреси. 
Глибина 0 сканує лише сторінку введеного URL, глибина 1 сканує URL і все після введеного URL + один /, і так далі.', + }, }, stepTwo: { segmentation: 'Налаштування фрагментації', @@ -80,8 +104,8 @@ const translation = { QATitle: 'Сегментація у форматі "питання та відповідь"', QATip: 'Увімкнення цієї опції споживатиме більше токенів', QALanguage: 'Сегментація з використанням', - emstimateCost: 'Оцінка', - emstimateSegment: 'Орієнтовні фрагменти', + estimateCost: 'Оцінка', + estimateSegment: 'Орієнтовні фрагменти', segmentCount: 'фрагментів', calculating: 'Розраховується...', fileSource: 'Попередня обробка документа', @@ -104,9 +128,11 @@ const translation = { previewSwitchTipStart: 'Поточний попередній перегляд має текстовий формат, зміна способу подання на формат запитань та відповідей ', previewSwitchTipEnd: ' потребує додаткових токенів', characters: 'символів', - indexSettedTip: 'Щоб змінити метод індексування, будь ласка, перейдіть до ', - retrivalSettedTip: 'Щоб змінити метод індексування, будь ласка, перейдіть до ', + indexSettingTip: 'Щоб змінити метод індексування, будь ласка, перейдіть до ', + retrievalSettingTip: 'Щоб змінити метод індексування, будь ласка, перейдіть до ', datasetSettingLink: 'Налаштування знань.', + webpageUnit: 'Сторінок', + websiteSource: 'Попередня обробка веб-сайту', }, stepThree: { creationTitle: '🎉 Знання створено', @@ -125,6 +151,11 @@ const translation = { modelButtonConfirm: 'Підтвердити', modelButtonCancel: 'Скасувати', }, + firecrawl: { + getApiKeyLinkText: 'Отримайте свій API-ключ від firecrawl.dev', + configFirecrawl: 'Налаштування 🔥Firecrawl', + apiKeyPlaceholder: 'Ключ API від firecrawl.dev', + }, } export default translation diff --git a/web/i18n/uk-UA/dataset-documents.ts index 90b686ba08..0b20d534e7 100644 --- a/web/i18n/uk-UA/dataset-documents.ts +++ b/web/i18n/uk-UA/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: 'СТАТУС', action: 'ДІЯ', }, + name: 'Ім\'я', + rename: 'Перейменувати', }, action: { uploadFile: 'Завантажити новий файл', @@ -74,6 +76,7 @@ const translation = { error: 'Помилка імпорту', ok: 'ОК', }, + addUrl: 'Додати URL-адресу', }, metadata: { title: 'Метадані', diff --git a/web/i18n/uk-UA/dataset-settings.ts index 4ea1e24f26..85e80902a7 100644 --- a/web/i18n/uk-UA/dataset-settings.ts +++ b/web/i18n/uk-UA/dataset-settings.ts @@ -27,6 +27,8 @@ const translation = { longDescription: ' про метод вибірки, ви можете змінити це будь-коли в налаштуваннях бази знань.', }, save: 'Зберегти', + me: '(Ви)', + permissionsInvitedMembers: 'Часткові члени команди', }, } diff --git a/web/i18n/uk-UA/dataset.ts index 3bf59ed33b..1bf6786976 100644 --- a/web/i18n/uk-UA/dataset.ts +++ b/web/i18n/uk-UA/dataset.ts @@ -1,6 +1,7 @@ const translation = { knowledge: 'Знання', - documentCount: ' док.', // Скорочення від 'документів' + // Скорочення від 'документів' + documentCount: ' док.', wordCount: ' тис. слів', appCount: ' пов\'язаних додатків', createDataset: 'Створити Знання', @@ -71,6 +72,7 @@ const translation = { nTo1RetrievalLegacy: 'N-до-1 пошук буде офіційно застарілим з вересня. Рекомендується використовувати найновіший багатошляховий пошук для отримання кращих результатів.', nTo1RetrievalLegacyLink: 'Дізнатися більше', nTo1RetrievalLegacyLinkText: 'N-до-1 пошук буде офіційно застарілим у вересні.', + defaultRetrievalTip: 'За замовчуванням використовується отримання кількома шляхами. 
Знання витягуються з кількох баз знань, а потім заново ранжуються.', } export default translation diff --git a/web/i18n/uk-UA/login.ts index 46de22bec2..7acc1920fc 100644 --- a/web/i18n/uk-UA/login.ts +++ b/web/i18n/uk-UA/login.ts @@ -31,7 +31,7 @@ const translation = { pp: 'Політика конфіденційності', tosDesc: 'Реєструючись, ви приймаєте наші', goToInit: 'Якщо ви ще не ініціалізували обліковий запис, перейдіть на сторінку ініціалізації', - donthave: 'Не маєте?', + dontHave: 'Не маєте?', invalidInvitationCode: 'Недійсний код запрошення', accountAlreadyInited: 'Обліковий запис уже ініціалізовано', forgotPassword: 'Забули пароль?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: 'Ім\'я обов\'язкове', passwordEmpty: 'Пароль є обов’язковим', passwordInvalid: 'Пароль повинен містити літери та цифри, а довжина повинна бути більшою за 8', + passwordLengthInValid: 'Пароль повинен бути не менше 8 символів', }, license: { tip: 'Перед запуском Dify Community Edition ознайомтеся з ліцензією з відкритим кодом на GitHub', @@ -68,6 +69,7 @@ const translation = { activated: 'Увійти зараз', adminInitPassword: 'Пароль ініціалізації адміністратора', validate: 'Перевірити', + sso: 'Продовжити з SSO', } export default translation diff --git a/web/i18n/uk-UA/share-app.ts index 9a121aaadc..3465a6e5b9 100644 --- a/web/i18n/uk-UA/share-app.ts +++ b/web/i18n/uk-UA/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'Додаток недоступний', - appUnkonwError: 'Додаток недоступний', + appUnknownError: 'Додаток недоступний', }, chat: { newChat: 'Новий чат', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Чати', newChatDefaultName: 'Нова розмова', resetChat: 'Очистити розмову', - powerBy: 'Забезпечується', + poweredBy: 'Забезпечується', prompt: 'Підказка', privatePromptConfigTitle: 'Налаштування розмови', publicPromptConfigTitle: 'Початкова підказка', @@ -27,7 +27,6 @@ const translation = { tryToSolve: 'Спробувати вирішити', temporarySystemIssue: 'Вибачте, тимчасова системна проблема.', }, - generation: { tabs: { create: 'Запустити один раз', @@ -65,7 +64,6 @@ const translation = { moreThanMaxLengthLine: 'Рядок {{rowIndex}}: значення {{varName}} не може містити більше {{maxLength}} символів', atLeastOne: 'Будь ласка, введіть принаймні один рядок у завантажений файл.', }, - }, } diff --git a/web/i18n/uk-UA/tools.ts index 313332e3a4..309a450afc 100644 --- a/web/i18n/uk-UA/tools.ts +++ b/web/i18n/uk-UA/tools.ts @@ -5,6 +5,7 @@ const translation = { all: 'Усі', builtIn: 'Вбудовані', custom: 'Користувацькі', + workflow: 'Робочий процес', }, contribute: { line1: 'Мені цікаво зробити свій внесок', @@ -67,6 +68,7 @@ const translation = { bearer: 'Bearer', custom: 'Custom', }, + title: 'Тип аутентифікації', }, privacyPolicy: 'Політика конфіденційності', privacyPolicyPlaceholder: 'Введіть політику конфіденційності', @@ -74,8 +76,28 @@ const translation = { customDisclaimerPlaceholder: 'Введіть власні відомості', deleteToolConfirmTitle: 'Видалити цей інструмент?', deleteToolConfirmContent: 'Видалення інструменту є незворотнім. 
Користувачі більше не зможуть отримати доступ до вашого інструменту.', + toolInput: { + label: 'Мітки', + name: 'Ім\'я', + required: 'Необхідний', + method: 'Метод', + title: 'Введення інструменту', + methodSetting: 'Параметр', + description: 'Опис', + methodParameter: 'Параметр', + labelPlaceholder: 'Виберіть теги (необов\'язково)', + descriptionPlaceholder: 'Опис значення параметра', + methodSettingTip: 'Користувач заповнює конфігурацію інструменту', + methodParameterTip: 'LLM заповнюється під час логічного висновку', + }, + description: 'Опис', + nameForToolCall: 'Ім\'я виклику інструменту', + confirmTitle: 'Підтвердьте, щоб зберегти?', + nameForToolCallTip: 'Підтримує лише цифри, літери та підкреслення.', + confirmTip: 'Це вплине на програми, які використовують цей інструмент', + nameForToolCallPlaceHolder: 'Використовується для розпізнавання машин, таких як getCurrentWeather, list_pets', + descriptionPlaceholder: 'Короткий опис призначення інструменту, наприклад, отримання температури для конкретного місця.', }, - test: { title: 'Тест', parametersValue: 'Параметри та значення', @@ -114,6 +136,18 @@ const translation = { toolRemoved: 'Інструмент видалено', notAuthorized: 'Інструмент не авторизовано', howToGet: 'Як отримати', + addToolModal: { + category: 'категорія', + add: 'Додати', + added: 'Додано', + type: 'тип', + manageInTools: 'Керування в інструментах', + emptyTip: 'Перейдіть до розділу "Робочий процес -> Опублікувати як інструмент"', + emptyTitle: 'Немає доступного інструменту для роботи з робочими процесами', + }, + openInStudio: 'Відкрити в Студії', + customToolTip: 'Дізнайтеся більше про користувацькі інструменти Dify', + toolNameUsageTip: 'Ім\'я виклику інструменту для міркувань і підказок агента', } export default translation diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index 066a245770..03471348c8 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Пошук змінної', variableNamePlaceholder: 'Назва змінної', setVarValuePlaceholder: 'Встановити значення змінної', - needConnecttip: 'Цей крок ні до чого не підключений', + needConnectTip: 'Цей крок ні до чого не підключений', maxTreeDepth: 'Максимальний ліміт {{depth}} вузлів на гілку', needEndNode: 'Потрібно додати кінцевий блок', needAnswerNode: 'Потрібно додати блок відповіді', @@ -69,6 +69,30 @@ const translation = { manageInTools: 'Керування в інструментах', workflowAsToolTip: 'Після оновлення робочого потоку необхідна переконфігурація інструменту.', viewDetailInTracingPanel: 'Переглянути деталі', + importSuccess: 'Успіх імпорту', + overwriteAndImport: 'Перезапис та імпорт', + importFailure: 'Помилка імпорту', + importDSL: 'Імпорт DSL', + syncingData: 'Синхронізація даних, всього за кілька секунд.', + chooseDSL: 'Виберіть файл DSL(yml)', + backupCurrentDraft: 'Резервна поточна чернетка', + importDSLTip: 'Поточна чернетка буде перезаписана. 
Експортуйте робочий процес як резервну копію перед імпортом.', + parallelTip: { + click: { + title: 'Натисніть', + desc: 'щоб додати', + }, + drag: { + title: 'Перетягувати', + desc: 'Щоб підключити', + }, + limit: 'Паралелізм обмежується {{num}} гілками.', + depthLimit: 'Обмеження рівня паралельної вкладеності: {{num}} шарів', + }, + disconnect: 'Відключити', + parallelRun: 'Паралельний запуск', + jumpToNode: 'Перейти до цього вузла', + addParallelNode: 'Додати паралельний вузол', }, env: { envPanelTitle: 'Змінні середовища', @@ -178,6 +202,7 @@ const translation = { 'transform': 'Трансформація', 'utilities': 'Утиліти', 'noResult': 'Нічого не знайдено', + 'searchTool': 'Пошук інструменту', }, blocks: { 'start': 'Початок', @@ -403,10 +428,12 @@ const translation = { 'not empty': 'не порожній', 'null': 'є null', 'not null': 'не є null', + 'regex match': 'Збіг регулярного виразу', }, enterValue: 'Введіть значення', addCondition: 'Додати умову', conditionNotSetup: 'Умова НЕ налаштована', + selectVariable: 'Виберіть змінну...', }, variableAssigner: { title: 'Присвоєння змінних', @@ -502,6 +529,25 @@ const translation = { iteration_other: '{{count}} Ітерацій', currentIteration: 'Поточна ітерація', }, + note: { + editor: { + large: 'Великий', + bold: 'Жирний', + openLink: 'Відкрити', + small: 'Малий', + link: 'Посилання', + italic: 'Курсив', + placeholder: 'Напишіть свою замітку...', + strikethrough: 'Закреслені', + medium: 'Середнє', + showAuthor: 'Показати автора', + bulletList: 'Маркований список', + enterUrl: 'Введіть URL-адресу...', + unlink: 'Від\'єднати', + invalidUrl: 'Невірна URL-адреса', + }, + addNote: 'Додати примітку', + }, }, tracing: { stopBy: 'Зупинено користувачем {{user}}', diff --git a/web/i18n/vi-VN/app-api.ts index cb89b98008..2d4ee90e3c 100644 --- a/web/i18n/vi-VN/app-api.ts +++ b/web/i18n/vi-VN/app-api.ts @@ -9,7 +9,7 @@ const translation = { play: 'Chạy', pause: 'Tạm dừng', playing: 'Đang chạy', - merMaind: { + merMaid: { rerender: 'Vẽ lại', }, never: 'Không bao giờ', @@ -77,6 +77,7 @@ const translation = { pathParams: 'Tham số đường dẫn', query: 'Truy vấn', }, + loading: 'Đang tải', } export default translation diff --git a/web/i18n/vi-VN/app-debug.ts index 906b39d10a..4e8a1962fe 100644 --- a/web/i18n/vi-VN/app-debug.ts +++ b/web/i18n/vi-VN/app-debug.ts @@ -259,7 +259,7 @@ const translation = { historyNoBeEmpty: 'Lịch sử cuộc trò chuyện phải được thiết lập trong lời nhắc', queryNoBeEmpty: 'Truy vấn phải được thiết lập trong lời nhắc', }, - variableConig: { + variableConfig: { 'addModalTitle': 'Thêm trường nhập', 'editModalTitle': 'Chỉnh sửa trường nhập', 'description': 'Cài đặt cho biến {{varName}}', diff --git a/web/i18n/vi-VN/app-log.ts index c4df7b512f..30a3988c12 100644 --- a/web/i18n/vi-VN/app-log.ts +++ b/web/i18n/vi-VN/app-log.ts @@ -89,7 +89,9 @@ const translation = { iterations: 'Số lần lặp', iteration: 'Lần lặp', finalProcessing: 'Xử lý cuối cùng', + agentMode: 'Chế độ đại lý', }, + agentLog: 'Nhật ký đại lý', } export default translation diff --git a/web/i18n/vi-VN/app-overview.ts index 55a53d73a2..7cc7428906 100644 --- a/web/i18n/vi-VN/app-overview.ts +++ b/web/i18n/vi-VN/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: 'Các bước quy trình', show: 'Hiển thị', hide: 'Ẩn', + showDesc: 'Hiển thị hoặc ẩn chi tiết dòng công việc trong WebApp', + subTitle: 'Chi tiết quy trình làm việc', }, chatColorTheme: 'Giao diện màu trò chuyện', 
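+  // The names inside '{{...}}' are a contract with the calling code and stay in English even in
+  // localized strings. A sketch, assuming an i18next call site such as:
+  //   t('common.modelProvider.apiKeyRateLimit', { seconds: 60 })
+  // substitution only happens while the string keeps '{{seconds}}'; a translated placeholder such as
+  // '{{giây}}' would render literally (see the corrected vi-VN strings below).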
chatColorThemeDesc: 'Thiết lập giao diện màu của chatbot', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: 'Nhập liên kết tuyên bố từ chối trách nhiệm', customDisclaimerTip: 'Liên kết này sẽ hiển thị ở phía người dùng, cung cấp thông tin về trách nhiệm của ứng dụng', }, + sso: { + title: 'SSO ứng dụng web', + description: 'Tất cả người dùng được yêu cầu đăng nhập bằng SSO trước khi sử dụng WebApp', + tooltip: 'Liên hệ với quản trị viên để bật SSO WebApp', + label: 'Xác thực SSO', + }, }, embedded: { entry: 'Nhúng', @@ -119,7 +127,11 @@ const translation = { tokenPS: 'Token/giây', totalMessages: { title: 'Tổng số tin nhắn', - explanation: 'Số lần tương tác AI hàng ngày; không tính việc tạo lại/lặp lại câu hỏi.', + explanation: 'Số lượng tương tác AI hàng ngày.', + }, + totalConversations: { + title: 'Tổng số cuộc hội thoại', + explanation: 'Số lượng cuộc hội thoại AI hàng ngày; không bao gồm kỹ thuật/gỡ lỗi prompt.', }, activeUsers: { title: 'Người dùng hoạt động', diff --git a/web/i18n/vi-VN/app.ts index 4052506f83..9e84341f63 100644 --- a/web/i18n/vi-VN/app.ts +++ b/web/i18n/vi-VN/app.ts @@ -122,7 +122,17 @@ const translation = { removeConfirmTitle: 'Xóa cấu hình {{key}}?', removeConfirmContent: 'Cấu hình hiện tại đang được sử dụng, việc xóa nó sẽ tắt tính năng Theo dõi.', }, + view: 'Xem', }, + answerIcon: { + description: 'Có nên sử dụng biểu tượng WebApp để thay thế 🤖 trong ứng dụng được chia sẻ hay không', + descriptionInExplore: 'Có nên sử dụng biểu tượng WebApp để thay thế 🤖 trong Khám phá hay không', + title: 'Sử dụng biểu tượng WebApp để thay thế 🤖', + }, + importFromDSLFile: 'Từ tệp DSL', + importFromDSL: 'Nhập từ DSL', + importFromDSLUrlPlaceholder: 'Dán liên kết DSL vào đây', + importFromDSLUrl: 'Từ URL', } export default translation diff --git a/web/i18n/vi-VN/billing.ts index 71abd8a884..595481e3a4 100644 --- a/web/i18n/vi-VN/billing.ts +++ b/web/i18n/vi-VN/billing.ts @@ -60,6 +60,8 @@ const translation = { bulkUpload: 'Tải lên tài liệu hàng loạt', agentMode: 'Chế độ Đại lý', workflow: 'Quy trình làm việc', + llmLoadingBalancing: 'Cân bằng tải LLM', + llmLoadingBalancingTooltip: 'Thêm nhiều khóa API vào mô hình, vượt qua giới hạn tốc độ API một cách hiệu quả.', }, comingSoon: 'Sắp ra mắt', member: 'Thành viên', @@ -74,6 +76,7 @@ const translation = { }, ragAPIRequestTooltip: 'Đề cập đến số lượng cuộc gọi API triệu hồi chỉ khả năng xử lý cơ sở kiến thức của Dify.', receiptInfo: 'Chỉ chủ nhóm và quản trị viên nhóm có thể đăng ký và xem thông tin thanh toán', + annotationQuota: 'Hạn ngạch chú thích', }, plans: { sandbox: { diff --git a/web/i18n/vi-VN/common.ts index 19855d31f0..252fa7e1df 100644 --- a/web/i18n/vi-VN/common.ts +++ b/web/i18n/vi-VN/common.ts @@ -37,6 +37,7 @@ const translation = { params: 'Tham số', duplicate: 'Nhân bản', rename: 'Đổi tên', + audioSourceUnavailable: 'AudioSource không khả dụng', }, placeholder: { input: 'Vui lòng nhập', @@ -128,7 +129,8 @@ const translation = { workspace: 'Không gian làm việc', createWorkspace: 'Tạo Không gian làm việc', helpCenter: 'Trung tâm trợ giúp', - roadmapAndFeedback: 'Phản hồi', + communityFeedback: 'Phản hồi', + roadmap: 'Lộ trình', community: 'Cộng đồng', about: 'Về chúng tôi', logout: 'Đăng xuất', @@ -190,16 +192,21 @@ const translation = { invitationSent: 'Lời mời đã được gửi', invitationSentTip: 'Lời mời đã được gửi, và họ có thể đăng nhập vào Dify để truy cập vào dữ liệu nhóm của bạn.', invitationLink: 'Liên kết Lời 
mời', - failedinvitationEmails: 'Dưới đây là danh sách email không gửi được lời mời', + failedInvitationEmails: 'Dưới đây là danh sách email không gửi được lời mời', ok: 'OK', removeFromTeam: 'Xóa khỏi nhóm', removeFromTeamTip: 'Sẽ xóa quyền truy cập nhóm', setAdmin: 'Đặt làm quản trị viên', setMember: 'Đặt thành viên bình thường', setEditor: 'Đặt làm biên tập viên', - disinvite: 'Hủy lời mời', + disInvite: 'Hủy lời mời', deleteMember: 'Xóa thành viên', you: '(Bạn)', + datasetOperatorTip: 'Chỉ có thể quản lý cơ sở kiến thức', + builderTip: 'Có thể xây dựng và chỉnh sửa ứng dụng của riêng mình', + builder: 'Trình tạo', + datasetOperator: 'Quản trị viên kiến thức', + setBuilder: 'Đặt làm trình tạo', }, integrations: { connected: 'Đã kết nối', @@ -346,6 +353,22 @@ const translation = { quotaTip: 'Số lượng mã thông báo miễn phí còn lại', loadPresets: 'Tải Cài đặt trước', parameters: 'THAM SỐ', + loadBalancingHeadline: 'Cân bằng tải', + loadBalancing: 'Cân bằng tải', + configLoadBalancing: 'Cấu hình cân bằng tải', + defaultConfig: 'Cấu hình mặc định', + modelHasBeenDeprecated: 'Mô hình này đã ngừng được hỗ trợ', + providerManagedDescription: 'Sử dụng bộ thông tin đăng nhập duy nhất do nhà cung cấp mô hình cung cấp.', + apiKeyStatusNormal: 'Trạng thái APIKey bình thường', + editConfig: 'Chỉnh sửa cấu hình', + loadBalancingInfo: 'Theo mặc định, cân bằng tải sử dụng chiến lược Vòng tròn. Nếu giới hạn tốc độ được kích hoạt, thời gian chờ 1 phút sẽ được áp dụng.', + addConfig: 'Thêm cấu hình', + loadBalancingDescription: 'Giảm áp lực với nhiều bộ thông tin xác thực.', + apiKey: 'KHÓA API', + providerManaged: 'Do nhà cung cấp quản lý', + apiKeyRateLimit: 'Đã đạt đến giới hạn tốc độ, có sẵn sau {{seconds}} giây', + upgradeForLoadBalancing: 'Nâng cấp gói của bạn để bật Cân bằng tải.', + loadBalancingLeastKeyWarning: 'Để bật cân bằng tải, ít nhất 2 khóa phải được bật.', }, dataSource: { add: 'Thêm nguồn dữ liệu', @@ -369,6 +392,15 @@ const translation = { preview: 'Xem trước', }, }, + website: { + title: 'Trang mạng', + inactive: 'Không hoạt động', + with: 'Với', + active: 'Hoạt động', + configuredCrawlers: 'Trình thu thập thông tin đã định cấu hình', + description: 'Nhập nội dung từ các trang web bằng trình thu thập dữ liệu web.', + }, + configure: 'Cấu hình', }, plugin: { serpapi: { @@ -537,6 +569,10 @@ const translation = { created: 'Thẻ được tạo thành công', failed: 'Tạo thẻ không thành công', }, + errorMsg: { + fieldRequired: '{{field}} là bắt buộc', + urlError: 'URL phải bắt đầu bằng http:// hoặc https://', + }, } export default translation diff --git a/web/i18n/vi-VN/dataset-creation.ts index 23b210d177..da69020287 100644 --- a/web/i18n/vi-VN/dataset-creation.ts +++ b/web/i18n/vi-VN/dataset-creation.ts @@ -45,11 +45,35 @@ const translation = { input: 'Tên Kiến thức', placeholder: 'Vui lòng nhập', nameNotEmpty: 'Tên không thể để trống', - nameLengthInvaild: 'Tên phải từ 1 đến 40 ký tự', + nameLengthInvalid: 'Tên phải từ 1 đến 40 ký tự', cancelButton: 'Hủy', confirmButton: 'Tạo', failed: 'Tạo thất bại', }, + website: { + fireCrawlNotConfigured: 'Firecrawl không được cấu hình', + limit: 'Giới hạn', + run: 'Chạy', + firecrawlDoc: 'Tài liệu Firecrawl', + fireCrawlNotConfiguredDescription: 'Định cấu hình Firecrawl bằng khóa API để sử dụng.', + configure: 'Cấu hình', + scrapTimeInfo: 'Tổng cộng {{total}} trang được thu thập trong vòng {{time}} giây', + options: 'Tùy chọn', + unknownError: 'Lỗi không xác định', + extractOnlyMainContent: 'Chỉ trích xuất nội dung chính 
+ exceptionErrorTitle: 'Một ngoại lệ xảy ra trong khi chạy tác vụ Firecrawl:', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + selectAll: 'Chọn tất cả', + firecrawlTitle: 'Trích xuất nội dung web bằng 🔥Firecrawl', + totalPageScraped: 'Tổng số trang đã thu thập:', + excludePaths: 'Loại trừ đường dẫn', + includeOnlyPaths: 'Chỉ bao gồm đường dẫn', + maxDepth: 'Độ sâu tối đa', + preview: 'Xem trước', + resetAll: 'Đặt lại tất cả', + crawlSubPage: 'Thu thập dữ liệu các trang phụ', + maxDepthTooltip: 'Độ sâu tối đa cần thu thập so với URL đã nhập. Độ sâu 0 chỉ thu thập trang của URL đã nhập, độ sâu 1 thu thập URL đó cùng mọi trang ngay sau URL đã nhập + một dấu /, v.v.', + }, }, stepTwo: { segmentation: 'Cài đặt phân đoạn', @@ -80,8 +104,8 @@ const translation = { QATitle: 'Phân đoạn theo định dạng Câu hỏi & Trả lời', QATip: 'Bật tùy chọn này sẽ tiêu tốn thêm token', QALanguage: 'Phân đoạn bằng', - emstimateCost: 'Ước tính', - emstimateSegment: 'Số đoạn ước tính', + estimateCost: 'Ước tính', + estimateSegment: 'Số đoạn ước tính', segmentCount: 'đoạn', calculating: 'Đang tính toán...', fileSource: 'Tiền xử lý tài liệu', @@ -104,9 +128,11 @@ const translation = { previewSwitchTipStart: 'Xem trước đoạn hiện tại đang ở định dạng văn bản, chuyển sang xem trước dạng câu hỏi và trả lời sẽ', previewSwitchTipEnd: ' tiêu tốn thêm token', characters: 'ký tự', - indexSettedTip: 'Để thay đổi phương pháp chỉ mục, vui lòng đi tới ', - retrivalSettedTip: 'Để thay đổi phương pháp truy xuất, vui lòng đi tới ', + indexSettingTip: 'Để thay đổi phương pháp chỉ mục, vui lòng đi tới ', + retrievalSettingTip: 'Để thay đổi phương pháp truy xuất, vui lòng đi tới ', datasetSettingLink: 'cài đặt Kiến thức.', + websiteSource: 'Trang web tiền xử lý', + webpageUnit: 'Trang', }, stepThree: { creationTitle: '🎉 Kiến thức đã được tạo', @@ -125,6 +151,11 @@ const translation = { modelButtonConfirm: 'Xác nhận', modelButtonCancel: 'Hủy', }, + firecrawl: { + getApiKeyLinkText: 'Lấy khóa API của bạn từ firecrawl.dev', + configFirecrawl: 'Định cấu hình 🔥Firecrawl', + apiKeyPlaceholder: 'Khóa API từ firecrawl.dev', + }, } export default translation diff --git a/web/i18n/vi-VN/dataset-documents.ts b/web/i18n/vi-VN/dataset-documents.ts index 5df6e40718..16570dff6e 100644 --- a/web/i18n/vi-VN/dataset-documents.ts +++ b/web/i18n/vi-VN/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: 'TRẠNG THÁI', action: 'THAO TÁC', }, + rename: 'Đổi tên', + name: 'Tên', }, action: { uploadFile: 'Tải lên tệp mới', @@ -74,6 +76,7 @@ const translation = { error: 'Lỗi nhập', ok: 'OK', }, + addUrl: 'Thêm URL', }, metadata: { title: 'Siêu dữ liệu', diff --git a/web/i18n/vi-VN/dataset-settings.ts b/web/i18n/vi-VN/dataset-settings.ts index e6feb78278..cc68bea7ae 100644 --- a/web/i18n/vi-VN/dataset-settings.ts +++ b/web/i18n/vi-VN/dataset-settings.ts @@ -27,6 +27,8 @@ const translation = { longDescription: ' về phương pháp truy xuất. Bạn có thể thay đổi điều này bất kỳ lúc nào trong cài đặt Kiến thức.', }, save: 'Lưu', + permissionsInvitedMembers: 'Một số thành viên trong nhóm', + me: '(Bạn)', }, } diff --git a/web/i18n/vi-VN/dataset.ts b/web/i18n/vi-VN/dataset.ts index 81b4597800..a2b9f8d087 100644 --- a/web/i18n/vi-VN/dataset.ts +++ b/web/i18n/vi-VN/dataset.ts @@ -71,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'Truy xuất N-đến-1 sẽ chính thức bị loại bỏ từ tháng 9. 
Khuyến nghị sử dụng truy xuất đa đường dẫn mới nhất để có kết quả tốt hơn.', nTo1RetrievalLegacyLink: 'Tìm hiểu thêm', nTo1RetrievalLegacyLinkText: 'Truy xuất N-đến-1 sẽ chính thức bị loại bỏ vào tháng 9.', + defaultRetrievalTip: 'Truy xuất nhiều đường dẫn được sử dụng theo mặc định. Kiến thức được lấy từ nhiều cơ sở kiến thức và sau đó được xếp hạng lại.', } export default translation diff --git a/web/i18n/vi-VN/login.ts b/web/i18n/vi-VN/login.ts index 8d291c7f33..0ee39ffe2c 100644 --- a/web/i18n/vi-VN/login.ts +++ b/web/i18n/vi-VN/login.ts @@ -31,7 +31,7 @@ const translation = { pp: 'Chính sách bảo mật', tosDesc: 'Bằng cách đăng ký, bạn đồng ý với', goToInit: 'Nếu bạn chưa khởi tạo tài khoản, vui lòng chuyển đến trang khởi tạo', - donthave: 'Chưa có tài khoản?', + dontHave: 'Chưa có tài khoản?', invalidInvitationCode: 'Mã mời không hợp lệ', accountAlreadyInited: 'Tài khoản đã được khởi tạo', forgotPassword: 'Quên mật khẩu?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: 'Vui lòng nhập tên', passwordEmpty: 'Vui lòng nhập mật khẩu', passwordInvalid: 'Mật khẩu phải chứa cả chữ và số, và có độ dài ít nhất 8 ký tự', + passwordLengthInValid: 'Mật khẩu phải có ít nhất 8 ký tự', }, license: { tip: 'Trước khi bắt đầu sử dụng Phiên bản Cộng đồng của Dify, vui lòng đọc', @@ -68,6 +69,7 @@ const translation = { activated: 'Đăng nhập ngay', adminInitPassword: 'Mật khẩu khởi tạo quản trị viên', validate: 'Xác thực', + sso: 'Tiếp tục với SSO', } export default translation diff --git a/web/i18n/vi-VN/share-app.ts b/web/i18n/vi-VN/share-app.ts index d440ad55dc..7078ecc299 100644 --- a/web/i18n/vi-VN/share-app.ts +++ b/web/i18n/vi-VN/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: 'Ứng dụng không khả dụng', - appUnkonwError: 'Ứng dụng gặp lỗi không xác định', + appUnknownError: 'Ứng dụng gặp lỗi không xác định', }, chat: { newChat: 'Cuộc trò chuyện mới', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: 'Trò chuyện', newChatDefaultName: 'Cuộc trò chuyện mới', resetChat: 'Đặt lại cuộc trò chuyện', - powerBy: 'Được cung cấp bởi', + poweredBy: 'Được cung cấp bởi', prompt: 'Lời nhắc', privatePromptConfigTitle: 'Cài đặt cuộc trò chuyện', publicPromptConfigTitle: 'Lời nhắc ban đầu', diff --git a/web/i18n/vi-VN/tools.ts b/web/i18n/vi-VN/tools.ts index 40e16a5fa9..b03a6ccc98 100644 --- a/web/i18n/vi-VN/tools.ts +++ b/web/i18n/vi-VN/tools.ts @@ -5,6 +5,7 @@ const translation = { all: 'Tất cả', builtIn: 'Tích hợp sẵn', custom: 'Tùy chỉnh', + workflow: 'Quy trình làm việc', }, contribute: { line1: 'Tôi quan tâm đến việc ', @@ -75,6 +76,27 @@ const translation = { customDisclaimerPlaceholder: 'Vui lòng nhập tuyên bố từ chối trách nhiệm tùy chỉnh', deleteToolConfirmTitle: 'Xóa công cụ này?', deleteToolConfirmContent: 'Xóa công cụ là không thể hoàn tác. 
Người dùng sẽ không thể truy cập lại công cụ của bạn.', + toolInput: { + label: 'Nhãn', + methodParameter: 'Thông số', + name: 'Tên', + descriptionPlaceholder: 'Mô tả ý nghĩa của tham số', + methodSetting: 'Cài đặt', + title: 'Đầu vào công cụ', + methodSettingTip: 'Người dùng điền vào cấu hình công cụ', + required: 'Bắt buộc', + method: 'Phương pháp', + methodParameterTip: 'Do LLM điền trong quá trình suy luận', + description: 'Mô tả', + labelPlaceholder: 'Chọn thẻ (tùy chọn)', + }, + nameForToolCallTip: 'Chỉ hỗ trợ số, chữ cái và dấu gạch dưới.', + nameForToolCall: 'Tên gọi công cụ', + nameForToolCallPlaceHolder: 'Được sử dụng để nhận dạng máy, chẳng hạn như getCurrentWeather, list_pets', + descriptionPlaceholder: 'Mô tả ngắn gọn về mục đích của công cụ, ví dụ: lấy nhiệt độ cho một vị trí cụ thể.', + description: 'Mô tả', + confirmTitle: 'Xác nhận lưu?', + confirmTip: 'Các ứng dụng sử dụng công cụ này sẽ bị ảnh hưởng', }, test: { title: 'Kiểm tra', @@ -114,6 +136,18 @@ const translation = { toolRemoved: 'Công cụ đã bị xóa', notAuthorized: 'Công cụ chưa được xác thực', howToGet: 'Cách nhận', + addToolModal: { + category: 'Danh mục', + manageInTools: 'Quản lý trong Công cụ', + type: 'Loại', + add: 'Thêm', + added: 'Đã thêm', + emptyTip: 'Đi tới "Quy trình làm việc -> Xuất bản dưới dạng công cụ"', + emptyTitle: 'Không có sẵn công cụ quy trình làm việc', + }, + toolNameUsageTip: 'Tên gọi công cụ dùng cho suy luận và nhắc lệnh của agent', + customToolTip: 'Tìm hiểu thêm về các công cụ tùy chỉnh Dify', + openInStudio: 'Mở trong Studio', } export default translation diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 3cb12f23ac..5be19ab7fd 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -36,7 +36,7 @@ const translation = { searchVar: 'Tìm kiếm biến', variableNamePlaceholder: 'Tên biến', setVarValuePlaceholder: 'Đặt giá trị biến', - needConnecttip: 'Bước này không được kết nối với bất kỳ điều gì', + needConnectTip: 'Bước này không được kết nối với bất kỳ điều gì', maxTreeDepth: 'Giới hạn tối đa {{depth}} nút trên mỗi nhánh', needEndNode: 'Phải thêm khối Kết thúc', needAnswerNode: 'Phải thêm khối Trả lời', @@ -69,6 +69,30 @@ const translation = { manageInTools: 'Quản lý trong công cụ', workflowAsToolTip: 'Cần cấu hình lại công cụ sau khi cập nhật quy trình làm việc.', viewDetailInTracingPanel: 'Xem chi tiết', + importSuccess: 'Nhập thành công', + backupCurrentDraft: 'Sao lưu dự thảo hiện tại', + chooseDSL: 'Chọn tệp DSL(yml)', + importDSLTip: 'Dự thảo hiện tại sẽ bị ghi đè. Xuất quy trình làm việc dưới dạng bản sao lưu trước khi nhập.',
+ importFailure: 'Nhập không thành công', + overwriteAndImport: 'Ghi đè và nhập', + importDSL: 'Nhập DSL', + syncingData: 'Đang đồng bộ dữ liệu, chỉ mất vài giây.', + parallelTip: { + click: { + title: 'Bấm', + desc: 'để thêm', + }, + drag: { + title: 'Kéo', + desc: 'để kết nối', + }, + limit: 'Song song được giới hạn trong {{num}} nhánh.', + depthLimit: 'Giới hạn {{num}} lớp lồng song song', + }, + parallelRun: 'Chạy song song', + disconnect: 'Ngắt kết nối', + jumpToNode: 'Chuyển đến nút này', + addParallelNode: 'Thêm nút song song', }, env: { envPanelTitle: 'Biến Môi Trường', @@ -177,7 +201,8 @@ const translation = { 'logic': 'Logic', 'transform': 'Chuyển đổi', 'utilities': 'Tiện ích', - 'noResult': 'Không tìm thấy kết quả phù hợp', + 'noResult': 'Không tìm thấy kết quả phù hợp', + 'searchTool': 'Công cụ tìm kiếm', }, blocks: { 'start': 'Bắt đầu', @@ -403,10 +428,12 @@ const translation = { 'not empty': 'không trống', 'null': 'là null', 'not null': 'không là null', + 'regex match': 'Khớp Regex', }, enterValue: 'Nhập giá trị', addCondition: 'Thêm điều kiện', conditionNotSetup: 'Điều kiện chưa được thiết lập', + selectVariable: 'Chọn biến...', }, variableAssigner: { title: 'Gán biến', @@ -502,6 +529,25 @@ const translation = { iteration_other: '{{count}} Lặp', currentIteration: 'Lặp hiện tại', }, + note: { + editor: { + openLink: 'Mở', + italic: 'Nghiêng', + link: 'Liên kết', + medium: 'Vừa', + small: 'Nhỏ', + placeholder: 'Viết ghi chú của bạn...', + large: 'Lớn', + showAuthor: 'Hiển thị tác giả', + bulletList: 'Danh sách dấu đầu dòng', + bold: 'Đậm', + unlink: 'Hủy liên kết', + invalidUrl: 'URL không hợp lệ', + strikethrough: 'Gạch ngang', + enterUrl: 'Nhập URL...', + }, + addNote: 'Thêm ghi chú', + }, }, tracing: { stopBy: 'Dừng bởi {{user}}', diff --git a/web/i18n/zh-Hans/app-api.ts b/web/i18n/zh-Hans/app-api.ts index f8f6ab7083..6b9048b66e 100644 --- a/web/i18n/zh-Hans/app-api.ts +++ b/web/i18n/zh-Hans/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: '暂停', playing: '播放中', loading: '加载中', - merMaind: { + merMaid: { rerender: '重新渲染', }, never: '从未', diff --git a/web/i18n/zh-Hans/app-debug.ts b/web/i18n/zh-Hans/app-debug.ts index febf80d786..62ef300f4d 100644 --- a/web/i18n/zh-Hans/app-debug.ts +++ b/web/i18n/zh-Hans/app-debug.ts @@ -298,7 +298,7 @@ const translation = { historyNoBeEmpty: '提示词中必须设置对话历史', queryNoBeEmpty: '提示词中必须设置查询内容', }, - variableConig: { + variableConfig: { 'addModalTitle': '添加变量', 'editModalTitle': '编辑变量', 'description': '设置变量 {{varName}}', diff --git a/web/i18n/zh-Hans/app.ts b/web/i18n/zh-Hans/app.ts index e12ed1b35d..ee316200fa 100644 --- a/web/i18n/zh-Hans/app.ts +++ b/web/i18n/zh-Hans/app.ts @@ -76,6 +76,11 @@ const translation = { emoji: '表情符号', image: '图片', }, + answerIcon: { + title: '使用 WebApp 图标替换 🤖', + description: '是否使用 WebApp 图标替换分享的应用界面中的 🤖', + descriptionInExplore: '是否使用 WebApp 图标替换 Explore 界面中的 🤖', + }, switch: '迁移为工作流编排', switchTipStart: '将为您创建一个使用工作流编排的新应用。新应用将', switchTip: '不能够', diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index 5333d18763..52ab7d6f02 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -37,6 +37,7 @@ const translation = { params: '参数设置', duplicate: '复制', rename: '重命名', + audioSourceUnavailable: '音源不可用', }, errorMsg: { fieldRequired: '{{field}} 为必填项', @@ -132,7 +133,8 @@ const translation = { workspace: '工作空间', createWorkspace: '创建工作空间', helpCenter: '帮助文档', - roadmapAndFeedback: '用户反馈', + communityFeedback: '用户反馈',
+ roadmap: '路线图', community: '社区', about: '关于', logout: '登出', @@ -196,16 +198,19 @@ const translation = { invitationSent: '邀请已发送', invitationSentTip: '邀请已发送,对方登录 Dify 后即可访问你的团队数据。', invitationLink: '邀请链接', - failedinvitationEmails: '邀请以下邮箱失败', + failedInvitationEmails: '邀请以下邮箱失败', ok: '好的', removeFromTeam: '移除团队', removeFromTeamTip: '将取消团队访问', setAdmin: '设为管理员', setMember: '设为普通成员', setEditor: '设为编辑', - disinvite: '取消邀请', + disInvite: '取消邀请', deleteMember: '删除成员', you: '(你)', + builderTip: '可以构建和编辑自己的应用程序', + setBuilder: '设为构建器', + builder: '构建器', }, integrations: { connected: '登录方式', @@ -367,6 +372,7 @@ const translation = { loadBalancingLeastKeyWarning: '至少启用 2 个 Key 以使用负载均衡', loadBalancingInfo: '默认情况下,负载平衡使用 Round-robin 策略。如果触发速率限制,将应用 1 分钟的冷却时间', upgradeForLoadBalancing: '升级以解锁负载均衡功能', + apiKey: 'API 密钥', }, dataSource: { add: '添加数据源', diff --git a/web/i18n/zh-Hans/dataset-creation.ts b/web/i18n/zh-Hans/dataset-creation.ts index 257f409abd..47a15921f7 100644 --- a/web/i18n/zh-Hans/dataset-creation.ts +++ b/web/i18n/zh-Hans/dataset-creation.ts @@ -50,7 +50,7 @@ const translation = { input: '知识库名称', placeholder: '请输入知识库名称', nameNotEmpty: '名称不能为空', - nameLengthInvaild: '名称长度不能超过 40 个字符', + nameLengthInvalid: '名称长度不能超过 40 个字符', cancelButton: '取消', confirmButton: '创建', failed: '创建失败', @@ -109,8 +109,8 @@ const translation = { QATitle: '采用 Q&A 分段模式', QATip: '开启后将会消耗额外的 token', QALanguage: '分段使用', - emstimateCost: '执行嵌入预估消耗', - emstimateSegment: '预估分段数', + estimateCost: '执行嵌入预估消耗', + estimateSegment: '预估分段数', segmentCount: '段', calculating: '计算中...', fileSource: '预处理文档', @@ -135,8 +135,8 @@ const translation = { previewSwitchTipStart: '当前分段预览是文本模式,切换到 Q&A 模式将会', previewSwitchTipEnd: '消耗额外的 token', characters: '字符', - indexSettedTip: '要更改索引方法,请转到', - retrivalSettedTip: '要更改检索方法,请转到', + indexSettingTip: '要更改索引方法和 embedding 模型,请转到', + retrievalSettingTip: '要更改检索方法,请转到', datasetSettingLink: '知识库设置。', }, stepThree: { diff --git a/web/i18n/zh-Hans/dataset.ts b/web/i18n/zh-Hans/dataset.ts index f76be97818..013830af6f 100644 --- a/web/i18n/zh-Hans/dataset.ts +++ b/web/i18n/zh-Hans/dataset.ts @@ -55,6 +55,7 @@ const translation = { hybrid_search: '混合检索', invertedIndex: '倒排索引', }, + defaultRetrievalTip: '默认情况下使用多路召回。从多个知识库中检索知识,然后重新排序。', mixtureHighQualityAndEconomicTip: '混合使用高质量和经济型知识库需要配置 Rerank 模型。', inconsistentEmbeddingModelTip: '当所选知识库配置的 Embedding 模型不一致时,需要配置 Rerank 模型。', retrievalSettings: '召回设置', diff --git a/web/i18n/zh-Hans/login.ts b/web/i18n/zh-Hans/login.ts index 5ac9b9fcb4..f0a6ab76a3 100644 --- a/web/i18n/zh-Hans/login.ts +++ b/web/i18n/zh-Hans/login.ts @@ -31,7 +31,7 @@ const translation = { pp: '隐私政策', tosDesc: '使用即代表你并同意我们的', goToInit: '如果您还没有初始化账户,请前往初始化页面', - donthave: '还没有邀请码?', + dontHave: '还没有邀请码?', invalidInvitationCode: '无效的邀请码', accountAlreadyInited: '账户已经初始化', forgotPassword: '忘记密码?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: '用户名不能为空', passwordEmpty: '密码不能为空', passwordInvalid: '密码必须包含字母和数字,且长度不小于8位', + passwordLengthInValid: '密码必须至少为 8 个字符', }, license: { tip: '启动 Dify 社区版之前, 请阅读 GitHub 上的', @@ -68,6 +69,7 @@ const translation = { activated: '现在登录', adminInitPassword: '管理员初始化密码', validate: '验证', + sso: '使用 SSO 继续', } export default translation diff --git a/web/i18n/zh-Hans/share-app.ts b/web/i18n/zh-Hans/share-app.ts index bb8e1574fd..968381bb37 100644 --- a/web/i18n/zh-Hans/share-app.ts +++ b/web/i18n/zh-Hans/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: '应用不可用', - 
appUnkonwError: '应用不可用', + appUnknownError: '应用不可用', }, chat: { newChat: '新对话', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: '对话列表', newChatDefaultName: '新的对话', resetChat: '重置对话', - powerBy: 'Powered by', + poweredBy: 'Powered by', prompt: '提示词', privatePromptConfigTitle: '对话设置', publicPromptConfigTitle: '对话前提示词', @@ -32,7 +32,6 @@ const translation = { create: '运行一次', batch: '批量运行', saved: '已保存', - }, savedNoData: { title: '您还没有保存结果!', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index f57eb40bb0..39311649f4 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -36,7 +36,7 @@ const translation = { variableNamePlaceholder: '变量名', searchVar: '搜索变量', setVarValuePlaceholder: '设置变量值', - needConnecttip: '此节点尚未连接到其他节点', + needConnectTip: '此节点尚未连接到其他节点', maxTreeDepth: '每个分支最大限制 {{depth}} 个节点', needEndNode: '必须添加结束节点', needAnswerNode: '必须添加直接回复节点', @@ -77,10 +77,26 @@ const translation = { overwriteAndImport: '覆盖并导入', importFailure: '导入失败', importSuccess: '导入成功', + parallelRun: '并行运行', + parallelTip: { + click: { + title: '点击', + desc: '添加节点', + }, + drag: { + title: '拖拽', + desc: '连接节点', + }, + limit: '并行分支限制为 {{num}} 个', + depthLimit: '并行嵌套层数限制 {{num}} 层', + }, + disconnect: '断开连接', + jumpToNode: '跳转到节点', + addParallelNode: '添加并行节点', }, env: { envPanelTitle: '环境变量', - envDescription: '环境变量是一种存储敏感信息的方法,如 API 密钥、数据库密码等。它们被存储在工作流程中,而不是代码中,以便在不同环墋中共享。', + envDescription: '环境变量是一种存储敏感信息的方法,如 API 密钥、数据库密码等。它们被存储在工作流程中,而不是代码中,以便在不同环境中共享。', envPanelButton: '添加环境变量', modal: { title: '添加环境变量', diff --git a/web/i18n/zh-Hant/app-api.ts b/web/i18n/zh-Hant/app-api.ts index 18ae6dcfff..63a9778378 100644 --- a/web/i18n/zh-Hant/app-api.ts +++ b/web/i18n/zh-Hant/app-api.ts @@ -10,7 +10,7 @@ const translation = { pause: '暫停', playing: '播放中', loading: '載入中', - merMaind: { + merMaid: { rerender: '重新渲染', }, never: '從未', diff --git a/web/i18n/zh-Hant/app-debug.ts b/web/i18n/zh-Hant/app-debug.ts index ca4dfbb0cf..ad29195cb5 100644 --- a/web/i18n/zh-Hant/app-debug.ts +++ b/web/i18n/zh-Hant/app-debug.ts @@ -244,7 +244,7 @@ const translation = { historyNoBeEmpty: '提示詞中必須設定對話歷史', queryNoBeEmpty: '提示詞中必須設定查詢內容', }, - variableConig: { + variableConfig: { 'addModalTitle': '新增變數', 'editModalTitle': '編輯變數', 'description': '設定變數 {{varName}}', diff --git a/web/i18n/zh-Hant/app-overview.ts b/web/i18n/zh-Hant/app-overview.ts index ecf2b5c3b1..7ad7fc86d9 100644 --- a/web/i18n/zh-Hant/app-overview.ts +++ b/web/i18n/zh-Hant/app-overview.ts @@ -48,6 +48,8 @@ const translation = { title: '工作流程步驟', show: '展示', hide: '隱藏', + subTitle: '工作流詳細資訊', + showDesc: '在 WebApp 中顯示或隱藏工作流詳細資訊', }, chatColorTheme: '聊天顏色主題', chatColorThemeDesc: '設定聊天機器人的顏色主題', @@ -64,6 +66,12 @@ const translation = { customDisclaimerPlaceholder: '請輸入免責聲明', customDisclaimerTip: '客製化的免責聲明文字將在客戶端顯示,提供有關應用程式的額外資訊。', }, + sso: { + description: '所有使用者在使用 WebApp 之前都需要使用 SSO 登錄', + title: 'WebApp SSO', + tooltip: '聯繫管理員以啟用 WebApp SSO', + label: 'SSO 身份驗證', + }, }, embedded: { entry: '嵌入', @@ -123,7 +131,11 @@ const translation = { }, activeUsers: { title: '活躍使用者數', - explanation: '與 AI 有效互動,即有一問一答以上的唯一使用者數。提示詞編排和除錯的會話不計入。', + explanation: '每日AI互動次數。', + }, + totalConversations: { + title: '總對話數', + explanation: '每日AI對話次數;不包括提示工程/調試。', }, tokenUsage: { title: '費用消耗', diff --git a/web/i18n/zh-Hant/app.ts b/web/i18n/zh-Hant/app.ts index ff162c5b61..5d52bb102b 100644 --- a/web/i18n/zh-Hant/app.ts +++ b/web/i18n/zh-Hant/app.ts @@ -123,6 +123,15 @@ const translation = { removeConfirmContent: '當前配置正在使用中,移除它將關閉追蹤功能。', }, 
}, + answerIcon: { + descriptionInExplore: '是否使用 WebApp 圖示在 Explore 中取代 🤖', + title: '使用 WebApp 圖示取代 🤖', + description: '是否在共享應用程式中使用 WebApp 圖示進行取代 🤖', + }, + importFromDSLUrl: '從 URL', + importFromDSL: '從 DSL 導入', + importFromDSLFile: '從 DSL 檔', + importFromDSLUrlPlaceholder: '在此處粘貼 DSL 連結', } export default translation diff --git a/web/i18n/zh-Hant/billing.ts b/web/i18n/zh-Hant/billing.ts index dd6b69e334..f318b6fa66 100644 --- a/web/i18n/zh-Hant/billing.ts +++ b/web/i18n/zh-Hant/billing.ts @@ -60,6 +60,8 @@ const translation = { bulkUpload: '批次上傳文件', agentMode: '代理模式', workflow: '工作流', + llmLoadingBalancing: 'LLM 負載均衡', + llmLoadingBalancingTooltip: '向模型添加多個 API 金鑰,從而有效地繞過 API 速率限制。', }, comingSoon: '即將推出', member: '成員', @@ -74,6 +76,7 @@ const translation = { }, ragAPIRequestTooltip: '指單獨呼叫 Dify 知識庫資料處理能力的 API。', receiptInfo: '只有團隊所有者和團隊管理員才能訂閱和檢視賬單資訊', + annotationQuota: '註釋配額', }, plans: { sandbox: { diff --git a/web/i18n/zh-Hant/common.ts b/web/i18n/zh-Hant/common.ts index 02296907d9..c1f3ed2b2b 100644 --- a/web/i18n/zh-Hant/common.ts +++ b/web/i18n/zh-Hant/common.ts @@ -37,6 +37,7 @@ const translation = { params: '引數設定', duplicate: '複製', rename: '重新命名', + audioSourceUnavailable: '音訊來源不可用', }, placeholder: { input: '請輸入', @@ -128,7 +129,8 @@ const translation = { workspace: '工作空間', createWorkspace: '建立工作空間', helpCenter: '幫助文件', - roadmapAndFeedback: '使用者反饋', + communityFeedback: '使用者反饋', + roadmap: '路線圖', community: '社群', about: '關於', logout: '登出', @@ -190,16 +192,21 @@ const translation = { invitationSent: '邀請已傳送', invitationSentTip: '邀請已傳送,對方登入 Dify 後即可訪問你的團隊資料。', invitationLink: '邀請連結', - failedinvitationEmails: '邀請以下郵箱失敗', + failedInvitationEmails: '邀請以下郵箱失敗', ok: '好的', removeFromTeam: '移除團隊', removeFromTeamTip: '將取消團隊訪問', setAdmin: '設為管理員', setMember: '設為普通成員', setEditor: '設為編輯', - disinvite: '取消邀請', + disInvite: '取消邀請', deleteMember: '刪除成員', you: '(你)', + setBuilder: '設為建構者', + datasetOperator: '知識管理員', + builder: '建構者', + builderTip: '可以構建和編輯自己的應用程式', + datasetOperatorTip: '只能管理知識庫', }, integrations: { connected: '登入方式', @@ -346,6 +353,22 @@ const translation = { quotaTip: '剩餘免費額度', loadPresets: '載入預設', parameters: '引數', + loadBalancingHeadline: '負載均衡', + apiKeyStatusNormal: 'APIKey 狀態正常', + defaultConfig: '預設配置', + configLoadBalancing: '配置負載均衡', + loadBalancingDescription: '使用多組憑證減輕壓力。', + addConfig: '添加配置', + upgradeForLoadBalancing: '升級您的計劃以啟用負載均衡。', + apiKey: 'API 金鑰', + loadBalancing: '負載均衡', + providerManagedDescription: '使用模型提供程式提供的單組憑證。', + modelHasBeenDeprecated: '此模型已棄用', + apiKeyRateLimit: '已達到速率限制,在 {{seconds}} 秒後可用', + providerManaged: '提供者管理', + editConfig: '編輯配置', + loadBalancingInfo: '預設情況下,負載均衡使用 Round-robin 策略。如果觸發了速率限制,將應用 1 分鐘的冷卻時間。', + loadBalancingLeastKeyWarning: '要啟用負載均衡,必須至少啟用 2 個金鑰。', }, dataSource: { add: '新增資料來源', @@ -369,6 +392,15 @@ const translation = { preview: '預覽', }, }, + website: { + active: '啟用', + title: '網站', + with: '與', + inactive: '未啟用', + configuredCrawlers: '配置的爬網程式', + description: '使用 Web 爬蟲從網站導入內容。', + }, + configure: '配置', }, plugin: { serpapi: { @@ -537,6 +569,10 @@ const translation = { created: '標籤建立成功', failed: '標籤建立失敗', }, + errorMsg: { + fieldRequired: '{{field}} 為必填項', + urlError: 'URL 應以 http:// 或 https:// 開頭', + }, } export default translation diff --git a/web/i18n/zh-Hant/dataset-creation.ts b/web/i18n/zh-Hant/dataset-creation.ts index 849e1578da..fd810d41c1 100644 --- a/web/i18n/zh-Hant/dataset-creation.ts +++ b/web/i18n/zh-Hant/dataset-creation.ts @@ -45,11 +45,35 @@ const translation = { 
input: '知識庫名稱', placeholder: '請輸入知識庫名稱', nameNotEmpty: '名稱不能為空', - nameLengthInvaild: '名稱長度不能超過 40 個字元', + nameLengthInvalid: '名稱長度不能超過 40 個字元', cancelButton: '取消', confirmButton: '建立', failed: '建立失敗', }, + website: { + maxDepth: '最大深度', + selectAll: '全選', + exceptionErrorTitle: '運行 Firecrawl 作業時發生異常:', + run: '執行', + extractOnlyMainContent: '僅提取主要內容(無頁眉、導航、頁腳等)', + fireCrawlNotConfiguredDescription: '使用 API 金鑰配置 Firecrawl 以使用它。', + limit: '限制', + crawlSubPage: '抓取子頁面', + firecrawlDocLink: 'https://docs.dify.ai/guides/knowledge-base/sync-from-website', + preview: '預覽', + configure: '配置', + excludePaths: '排除路徑', + options: '選項', + firecrawlDoc: 'Firecrawl 文件', + totalPageScraped: '抓取的總頁數:', + firecrawlTitle: '使用 🔥Firecrawl 提取 Web 內容', + includeOnlyPaths: '僅包含路徑', + resetAll: '全部重置', + scrapTimeInfo: '在 {{time}} 秒內總共抓取了 {{total}} 個頁面', + unknownError: '未知錯誤', + fireCrawlNotConfigured: '未配置 Firecrawl', + maxDepthTooltip: '相對於輸入的 URL 的最大爬取深度。深度 0 只抓取輸入的 URL 頁面,深度 1 抓取該 URL 以及輸入的 URL + 一個 / 之後的所有內容,依此類推。', + }, }, stepTwo: { segmentation: '分段設定', @@ -80,8 +104,8 @@ const translation = { QATitle: '採用 Q&A 分段模式', QATip: '開啟後將會消耗額外的 token', QALanguage: '分段使用', - emstimateCost: '執行嵌入預估消耗', - emstimateSegment: '預估分段數', + estimateCost: '執行嵌入預估消耗', + estimateSegment: '預估分段數', segmentCount: '段', calculating: '計算中...', fileSource: '預處理文件', @@ -104,9 +128,11 @@ const translation = { previewSwitchTipStart: '當前分段預覽是文字模式,切換到 Q&A 模式將會', previewSwitchTipEnd: '消耗額外的 token', characters: '字元', - indexSettedTip: '要更改索引方法,請轉到', - retrivalSettedTip: '要更改檢索方法,請轉到', + indexSettingTip: '要更改索引方法,請轉到', + retrievalSettingTip: '要更改檢索方法,請轉到', datasetSettingLink: '知識庫設定。', + websiteSource: '預處理網站', + webpageUnit: '頁面', }, stepThree: { creationTitle: '🎉 知識庫已建立', @@ -125,6 +151,11 @@ const translation = { modelButtonConfirm: '確認停止', modelButtonCancel: '取消', }, + firecrawl: { + configFirecrawl: '配置 🔥Firecrawl', + apiKeyPlaceholder: '來自 firecrawl.dev 的 API 金鑰', + getApiKeyLinkText: '從 firecrawl.dev 獲取 API 金鑰', + }, } export default translation diff --git a/web/i18n/zh-Hant/dataset-documents.ts b/web/i18n/zh-Hant/dataset-documents.ts index ccc0fcf764..b4e6b44181 100644 --- a/web/i18n/zh-Hant/dataset-documents.ts +++ b/web/i18n/zh-Hant/dataset-documents.ts @@ -13,6 +13,8 @@ const translation = { status: '狀態', action: '操作', }, + name: '名稱', + rename: '重新命名', }, action: { uploadFile: '上傳新檔案', @@ -74,6 +76,7 @@ const translation = { error: '匯入出錯', ok: '確定', }, + addUrl: '添加 URL', }, metadata: { title: '元資料', diff --git a/web/i18n/zh-Hant/dataset-settings.ts b/web/i18n/zh-Hant/dataset-settings.ts index d18c3fdd9b..f34d1d4acc 100644 --- a/web/i18n/zh-Hant/dataset-settings.ts +++ b/web/i18n/zh-Hant/dataset-settings.ts @@ -27,6 +27,8 @@ const translation = { longDescription: '關於檢索方法,您可以隨時在知識庫設定中更改此設定。', }, save: '儲存', + permissionsInvitedMembers: '部分團隊成員', + me: '(您)', }, } diff --git a/web/i18n/zh-Hant/dataset.ts b/web/i18n/zh-Hant/dataset.ts index 1e011bc987..1888a28631 100644 --- a/web/i18n/zh-Hant/dataset.ts +++ b/web/i18n/zh-Hant/dataset.ts @@ -71,6 +71,7 @@ const translation = { nTo1RetrievalLegacy: 'N對1檢索將從9月起正式棄用。建議使用最新的多路徑檢索以獲得更好的結果。', nTo1RetrievalLegacyLink: '了解更多', nTo1RetrievalLegacyLinkText: 'N對1檢索將於9月正式棄用。', + defaultRetrievalTip: '預設情況下,使用多路徑檢索。從多個知識庫中檢索知識,然後重新排名。', } export default translation diff --git a/web/i18n/zh-Hant/login.ts b/web/i18n/zh-Hant/login.ts index cce869f38a..649f618158 100644 --- a/web/i18n/zh-Hant/login.ts +++ b/web/i18n/zh-Hant/login.ts @@ -31,7 +31,7 @@ const translation = { pp: '隱私政策', tosDesc: '使用即代表你並同意我們的',
goToInit: '如果您還沒有初始化賬戶,請前往初始化頁面', - donthave: '還沒有邀請碼?', + dontHave: '還沒有邀請碼?', invalidInvitationCode: '無效的邀請碼', accountAlreadyInited: '賬戶已經初始化', forgotPassword: '忘記密碼?', @@ -53,6 +53,7 @@ const translation = { nameEmpty: '使用者名稱不能為空', passwordEmpty: '密碼不能為空', passwordInvalid: '密碼必須包含字母和數字,且長度不小於8位', + passwordLengthInValid: '密碼必須至少為 8 個字元', }, license: { tip: '啟動 Dify 社群版之前, 請閱讀 GitHub 上的', @@ -68,6 +69,7 @@ const translation = { activated: '現在登入', adminInitPassword: '管理員初始化密碼', validate: '驗證', + sso: '繼續使用 SSO', } export default translation diff --git a/web/i18n/zh-Hant/share-app.ts b/web/i18n/zh-Hant/share-app.ts index e91cbaf121..ea5f206985 100644 --- a/web/i18n/zh-Hant/share-app.ts +++ b/web/i18n/zh-Hant/share-app.ts @@ -2,7 +2,7 @@ const translation = { common: { welcome: '', appUnavailable: '應用不可用', - appUnkonwError: '應用不可用', + appUnknownError: '應用不可用', }, chat: { newChat: '新對話', @@ -10,7 +10,7 @@ const translation = { unpinnedTitle: '對話列表', newChatDefaultName: '新的對話', resetChat: '重置對話', - powerBy: 'Powered by', + poweredBy: 'Powered by', prompt: '提示詞', privatePromptConfigTitle: '對話設定', publicPromptConfigTitle: '對話前提示詞', @@ -32,7 +32,6 @@ const translation = { create: '執行一次', batch: '批次執行', saved: '已儲存', - }, savedNoData: { title: '您還沒有儲存結果!', diff --git a/web/i18n/zh-Hant/tools.ts b/web/i18n/zh-Hant/tools.ts index 58ba9f5c81..d45980c017 100644 --- a/web/i18n/zh-Hant/tools.ts +++ b/web/i18n/zh-Hant/tools.ts @@ -5,6 +5,7 @@ const translation = { all: '全部', builtIn: '內建', custom: '自定義', + workflow: '工作流', }, contribute: { line1: '我有興趣為 ', @@ -75,6 +76,27 @@ const translation = { customDisclaimerPlaceholder: '請輸入自定義免責聲明', deleteToolConfirmTitle: '刪除這個工具?', deleteToolConfirmContent: '刪除工具是不可逆的。用戶將無法再訪問您的工具。', + toolInput: { + labelPlaceholder: '選擇標籤(可選)', + label: '標籤', + required: '必填', + methodSettingTip: '用戶填寫工具配置', + name: '名稱', + description: '描述', + methodParameterTip: '推理期間 LLM 填充', + method: '方法', + title: '工具輸入', + methodSetting: '設置', + methodParameter: '參數', + descriptionPlaceholder: '參數含義的描述', + }, + description: '描述', + nameForToolCall: '工具調用名稱', + confirmTitle: '確認儲存?', + descriptionPlaceholder: '工具用途的簡要描述,例如,獲取特定位置的溫度。', + nameForToolCallTip: '僅支援數位、字母和下劃線。', + confirmTip: '使用此工具的應用程式將受到影響', + nameForToolCallPlaceHolder: '用於機器識別,例如 getCurrentWeather、list_pets', }, test: { title: '測試', @@ -114,6 +136,18 @@ const translation = { toolRemoved: '工具已被移除', notAuthorized: '工具未授權', howToGet: '如何獲取', + addToolModal: { + add: '添加', + type: '類型', + added: '已添加', + manageInTools: '在工具中管理', + category: '類別', + emptyTitle: '沒有可用的工作流程工具', + emptyTip: '轉到“工作流 -> 發佈為工具”', + }, + customToolTip: '瞭解有關 Dify 自訂工具的更多資訊', + toolNameUsageTip: '用於代理推理和提示的工具調用名稱', + openInStudio: '在 Studio 中打開', } export default translation diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index 3a456858fe..eef3ffaebd 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -36,7 +36,7 @@ const translation = { variableNamePlaceholder: '變量名', searchVar: '搜索變量', setVarValuePlaceholder: '設置變量值', - needConnecttip: '此節點尚未連接到其他節點', + needConnectTip: '此節點尚未連接到其他節點', maxTreeDepth: '每個分支最大限制 {{depth}} 個節點', needEndNode: '必須添加結束節點', needAnswerNode: '必須添加直接回覆節點', @@ -69,6 +69,30 @@ const translation = { manageInTools: '訪問工具頁', workflowAsToolTip: '工作流更新後需要重新配置工具參數', viewDetailInTracingPanel: '查看詳細信息', + importDSL: '導入 DSL', + backupCurrentDraft: '備份當前草稿', + overwriteAndImport: '覆蓋和導入', + importSuccess: '導入成功', + chooseDSL: '選擇 DSL(yml) 檔', + syncingData: '同步數據,只需幾秒鐘。',
+ importDSLTip: '當前草稿將被覆蓋。在導入之前將工作流匯出為備份。', + importFailure: '匯入失敗', + parallelTip: { + click: { + title: '點擊', + desc: '添加', + }, + drag: { + title: '拖動', + desc: '連接', + }, + limit: '並行度僅限於 {{num}} 個分支。', + depthLimit: '並行嵌套層數限制為 {{num}} 層', + }, + parallelRun: '並行運行', + disconnect: '斷開', + jumpToNode: '跳轉到此節點', + addParallelNode: '添加並行節點', }, env: { envPanelTitle: '環境變數', @@ -142,6 +166,7 @@ const translation = { noteAdd: '註釋已添加', noteChange: '註釋已更改', edgeDelete: '區塊已斷開連接', + noteDelete: '註釋已刪除', }, errorMsg: { fieldRequired: '{{field}} 不能為空', @@ -177,6 +202,7 @@ const translation = { 'transform': '轉換', 'utilities': '工具', 'noResult': '未找到匹配項', + 'searchTool': '搜索工具', }, blocks: { 'start': '開始', @@ -402,10 +428,12 @@ const translation = { 'not empty': '不為空', 'null': '空', 'not null': '不為空', + 'regex match': '正則表達式匹配', }, enterValue: '輸入值', addCondition: '添加條件', conditionNotSetup: '條件未設置', + selectVariable: '選擇變數...', }, variableAssigner: { title: '變量賦值', @@ -501,6 +529,25 @@ const translation = { iteration_other: '{{count}}個迭代', currentIteration: '當前迭代', }, + note: { + editor: { + link: '連結', + openLink: '打開', + medium: '中等', + small: '小', + invalidUrl: 'URL 無效', + italic: '斜體', + bulletList: '項目符號清單', + large: '大', + unlink: '取消連結', + enterUrl: '輸入網址...', + bold: '粗體', + showAuthor: '顯示作者', + strikethrough: '刪除線', + placeholder: '寫下您的筆記...', + }, + addNote: '添加註釋', + }, }, tracing: { stopBy: '由{{user}}終止', diff --git a/web/models/app.ts b/web/models/app.ts index 82efcd5fa0..e550b82ab6 100644 --- a/web/models/app.ts +++ b/web/models/app.ts @@ -4,7 +4,7 @@ import type { App, AppSSO, AppTemplate, SiteConfig } from '@/types/app' /* export type App = { id: string name: string - decription: string + description: string mode: AppMode enable_site: boolean enable_api: boolean @@ -103,15 +103,15 @@ export type AppTokenCostsResponse = { export type UpdateAppModelConfigResponse = { result: string } -export type ApikeyItemResponse = { +export type ApiKeyItemResponse = { id: string token: string last_used_at: string created_at: string } -export type ApikeysListResponse = { - data: ApikeyItemResponse[] +export type ApiKeysListResponse = { + data: ApiKeyItemResponse[] } export type CreateApiKeyResponse = { diff --git a/web/models/datasets.ts b/web/models/datasets.ts index 0ae7831245..23d1fe6136 100644 --- a/web/models/datasets.ts +++ b/web/models/datasets.ts @@ -227,6 +227,8 @@ export type DocumentReq = { export type CreateDocumentReq = DocumentReq & { data_source: DataSource retrieval_model: RetrievalConfig + embedding_model: string + embedding_model_provider: string } export type IndexingEstimateParams = DocumentReq & Partial & { @@ -437,7 +439,7 @@ export type RelatedAppResponse = { total: number } -export type SegmentUpdator = { +export type SegmentUpdater = { content: string answer?: string keywords?: string[] diff --git a/web/models/debug.ts b/web/models/debug.ts index 2b2af80065..565798e598 100644 --- a/web/models/debug.ts +++ b/web/models/debug.ts @@ -215,7 +215,7 @@ export type LogSessionListResponse = { query: string // user's query question message: string // prompt send to LLM answer: string - creat_at: string + created_at: string }[] total: number page: number @@ -224,7 +224,7 @@ // log session detail and debug export type LogSessionDetailResponse = { id: string - cnversation_id: string + conversation_id: string model_provider: string query: string inputs: Record[] diff --git a/web/models/explore.ts b/web/models/explore.ts index 
78dd2e8675..ad60d99c6f 100644 --- a/web/models/explore.ts +++ b/web/models/explore.ts @@ -8,6 +8,7 @@ export type AppBasicInfo = { icon_url: string name: string description: string + use_icon_as_answer_icon: boolean } export type AppCategory = 'Writing' | 'Translate' | 'HR' | 'Programming' | 'Assistant' diff --git a/web/models/log.ts b/web/models/log.ts index 6f8ebb1a78..8da1c4cf4e 100644 --- a/web/models/log.ts +++ b/web/models/log.ts @@ -6,7 +6,7 @@ import type { } from '@/app/components/workflow/types' import type { Metadata } from '@/app/components/base/chat/chat/type' -// Log type contains key:string conversation_id:string created_at:string quesiton:string answer:string +// Log type contains key:string conversation_id:string created_at:string question:string answer:string export type Conversation = { id: string key: string diff --git a/web/models/share.ts b/web/models/share.ts index 127f3d0a51..3521365e82 100644 --- a/web/models/share.ts +++ b/web/models/share.ts @@ -25,6 +25,7 @@ export type SiteInfo = { privacy_policy?: string custom_disclaimer?: string show_workflow_steps?: boolean + use_icon_as_answer_icon?: boolean } export type AppMeta = { diff --git a/web/package.json b/web/package.json index e896f58886..374286f8f7 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.7.2", + "version": "0.8.0", "private": true, "engines": { "node": ">=18.17.0" @@ -15,7 +15,8 @@ "prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky install ./web/.husky", "gen-icons": "node ./app/components/base/icons/script.js", "uglify-embed": "node ./bin/uglify-embed", - "check-i18n": "node ./i18n/script.js", + "check-i18n": "node ./i18n/check-i18n.js", + "auto-gen-i18n": "node ./i18n/auto-gen-i18n.js", "test": "jest", "test:watch": "jest --watch" }, @@ -86,6 +87,7 @@ "reactflow": "^11.11.3", "recordrtc": "^5.6.2", "rehype-katex": "^6.0.2", + "rehype-raw": "^7.0.0", "remark-breaks": "^3.0.2", "remark-gfm": "^3.0.1", "remark-math": "^5.1.1", @@ -126,6 +128,7 @@ "@types/sortablejs": "^1.15.1", "@types/uuid": "^9.0.8", "autoprefixer": "^10.4.14", + "bing-translate-api": "^4.0.2", "code-inspector-plugin": "^0.13.0", "cross-env": "^7.0.3", "eslint": "^8.36.0", @@ -134,6 +137,7 @@ "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0", "lint-staged": "^13.2.2", + "magicast": "^0.3.4", "postcss": "^8.4.31", "sass": "^1.61.0", "tailwindcss": "^3.4.4", diff --git a/web/public/embed.js b/web/public/embed.js index 14420f0c8c..8ed7a67dc8 100644 --- a/web/public/embed.js +++ b/web/public/embed.js @@ -73,7 +73,7 @@ box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 5rem; right: 1rem; width: 24rem; max-width: calc(100vw - 2rem); height: 40rem; max-height: calc(100vh - 6rem); border-radius: 0.75rem; display: flex; z-index: 2147483647; - overflow: hidden; left: unset; background-color: #F3F4F6; + overflow: hidden; left: unset; background-color: #F3F4F6;user-select: none; `; document.body.appendChild(iframe); @@ -255,6 +255,9 @@ if (!document.getElementById(buttonId)) { createButton(); } + + createIframe(); + document.getElementById(iframeId).style.display = 'none'; } // Add esc Exit keyboard event triggered diff --git a/web/public/embed.min.js b/web/public/embed.min.js index ec721a204d..0e023cb5d1 100644 --- a/web/public/embed.min.js +++ b/web/public/embed.min.js @@ -1 +1,31 @@ -!function(){const 
e="difyChatbotConfig",t="dify-chatbot-bubble-button",n="dify-chatbot-bubble-window",o=window[e],i={open:'\n \n ',close:'\n \n '};async function d(){if(!o||!o.token)return void console.error(`${e} is empty or token is not provided`);const d=new URLSearchParams(await async function(){const e=o?.inputs||{},t={};return await Promise.all(Object.entries(e).map((async([e,n])=>{t[e]=await async function(e){const t=(new TextEncoder).encode(e),n=new Response(new Blob([t]).stream().pipeThrough(new CompressionStream("gzip"))).arrayBuffer(),o=new Uint8Array(await n);return btoa(String.fromCharCode(...o))}(n)}))),t}()),s=`${o.baseUrl||`https://${o.isDev?"dev.":""}udify.app`}/chatbot/${o.token}?${d}`;function c(){const e=document.getElementById(n),o=document.getElementById(t);if(e&&o){const t=o.getBoundingClientRect(),n=window.innerHeight-t.bottom,i=window.innerWidth-t.right,d=t.left;e.style.bottom=`${n+t.height+5+e.clientHeight>window.innerHeight?n-e.clientHeight-5:n+t.height+5}px`,e.style.right=`${i+e.clientWidth>window.innerWidth?window.innerWidth-d-e.clientWidth:i}px`}}s.length>2048&&console.error("The URL is too long, please reduce the number of inputs to prevent the bot from failing to load"),document.getElementById(t)||function(){const e=document.createElement("div");Object.entries(o.containerProps||{}).forEach((([t,n])=>{"className"===t?e.classList.add(...n.split(" ")):"style"===t?"object"==typeof n?Object.assign(e.style,n):e.style.cssText=n:"function"==typeof n?e.addEventListener(t.replace(/^on/,"").toLowerCase(),n):e[t]=n})),e.id=t;const d=document.createElement("style");document.head.appendChild(d),d.sheet.insertRule(`\n #${e.id} {\n position: fixed;\n bottom: var(--${e.id}-bottom, 1rem);\n right: var(--${e.id}-right, 1rem);\n left: var(--${e.id}-left, unset);\n top: var(--${e.id}-top, unset);\n width: var(--${e.id}-width, 50px);\n height: var(--${e.id}-height, 50px);\n border-radius: var(--${e.id}-border-radius, 25px);\n background-color: var(--${e.id}-bg-color, #155EEF);\n box-shadow: var(--${e.id}-box-shadow, rgba(0, 0, 0, 0.2) 0px 4px 8px 0px);\n cursor: pointer;\n z-index: 2147483647;\n transition: all 0.2s ease-in-out 0s;\n }\n `),d.sheet.insertRule(`\n #${e.id}:hover {\n transform: var(--${e.id}-hover-transform, scale(1.1));\n }\n `);const l=document.createElement("div");l.style.cssText="display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; z-index: 2147483647;",l.innerHTML=i.open,e.appendChild(l),document.body.appendChild(e),e.addEventListener("click",(function(){const e=document.getElementById(n);if(!e)return function(){const e=document.createElement("iframe");e.allow="fullscreen;microphone",e.title="dify chatbot bubble window",e.id=n,e.src=s,e.style.cssText="\n border: none; position: fixed; flex-direction: column; justify-content: space-between;\n box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px;\n bottom: 5rem; right: 1rem; width: 24rem; max-width: calc(100vw - 2rem); height: 40rem;\n max-height: calc(100vh - 6rem); border-radius: 0.75rem; display: flex; z-index: 2147483647;\n overflow: hidden; left: unset; background-color: #F3F4F6;\n ",document.body.appendChild(e)}(),c(),this.title="Exit (ESC)",l.innerHTML=i.close,void 
document.addEventListener("keydown",r);e.style.display="none"===e.style.display?"block":"none",l.innerHTML="none"===e.style.display?i.open:i.close,"none"===e.style.display?document.removeEventListener("keydown",r):document.addEventListener("keydown",r),c()})),o.draggable&&function(e,o){let d,r,s=!1;function c(t){s=!0,d=t.clientX-e.offsetLeft,r=t.clientY-e.offsetTop}function l(c){if(!s)return;e.style.transition="none",e.style.cursor="grabbing";const l=document.getElementById(n);l&&(l.style.display="none",e.querySelector("div").innerHTML=i.open);const a=c.clientX-d,h=window.innerHeight-c.clientY-r,p=e.getBoundingClientRect(),u=window.innerWidth-p.width,m=window.innerHeight-p.height;"x"!==o&&"both"!==o||e.style.setProperty(`--${t}-left`,`${Math.max(0,Math.min(a,u))}px`),"y"!==o&&"both"!==o||e.style.setProperty(`--${t}-bottom`,`${Math.max(0,Math.min(h,m))}px`)}function a(){s=!1,e.style.transition="",e.style.cursor="pointer"}e.addEventListener("mousedown",c),document.addEventListener("mousemove",l),document.addEventListener("mouseup",a)}(e,o.dragAxis||"both")}()}function r(e){if("Escape"===e.key){const e=document.getElementById(n),o=document.getElementById(t);e&&"none"!==e.style.display&&(e.style.display="none",o.querySelector("div").innerHTML=i.open)}}document.addEventListener("keydown",r),o?.dynamicScript?d():document.body.onload=d}(); +(()=>{let t="difyChatbotConfig",a="dify-chatbot-bubble-button",c="dify-chatbot-bubble-window",h=window[t],p={open:` + + `,close:` + + `};async function e(){if(h&&h.token){var e=new URLSearchParams(await(async()=>{var e=h?.inputs||{};let n={};return await Promise.all(Object.entries(e).map(async([e,t])=>{n[e]=(e=t,e=(new TextEncoder).encode(e),e=new Response(new Blob([e]).stream().pipeThrough(new CompressionStream("gzip"))).arrayBuffer(),e=new Uint8Array(await e),await btoa(String.fromCharCode(...e)))})),n})());let t=`${h.baseUrl||`https://${h.isDev?"dev.":""}udify.app`}/chatbot/${h.token}?`+e;function o(){var e=document.createElement("iframe");e.allow="fullscreen;microphone",e.title="dify chatbot bubble window",e.id=c,e.src=t,e.style.cssText=` + border: none; position: fixed; flex-direction: column; justify-content: space-between; + box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; + bottom: 5rem; right: 1rem; width: 24rem; max-width: calc(100vw - 2rem); height: 40rem; + max-height: calc(100vh - 6rem); border-radius: 0.75rem; display: flex; z-index: 2147483647; + overflow: hidden; left: unset; background-color: #F3F4F6;user-select: none; + `,document.body.appendChild(e)}function i(){var e,t,n,o=document.getElementById(c),i=document.getElementById(a);o&&i&&(i=i.getBoundingClientRect(),e=window.innerHeight-i.bottom,t=window.innerWidth-i.right,n=i.left,o.style.bottom=`${e+i.height+5+o.clientHeight>window.innerHeight?e-o.clientHeight-5:e+i.height+5}px`,o.style.right=`${t+o.clientWidth>window.innerWidth?window.innerWidth-n-o.clientWidth:t}px`)}function n(){let n=document.createElement("div");Object.entries(h.containerProps||{}).forEach(([e,t])=>{"className"===e?n.classList.add(...t.split(" ")):"style"===e?"object"==typeof t?Object.assign(n.style,t):n.style.cssText=t:"function"==typeof t?n.addEventListener(e.replace(/^on/,"").toLowerCase(),t):n[e]=t}),n.id=a;var e=document.createElement("style");document.head.appendChild(e),e.sheet.insertRule(` + #${n.id} { + position: fixed; + bottom: var(--${n.id}-bottom, 1rem); + right: var(--${n.id}-right, 1rem); + left: var(--${n.id}-left, unset); + top: var(--${n.id}-top, unset); 
+ width: var(--${n.id}-width, 50px); + height: var(--${n.id}-height, 50px); + border-radius: var(--${n.id}-border-radius, 25px); + background-color: var(--${n.id}-bg-color, #155EEF); + box-shadow: var(--${n.id}-box-shadow, rgba(0, 0, 0, 0.2) 0px 4px 8px 0px); + cursor: pointer; + z-index: 2147483647; + transition: all 0.2s ease-in-out 0s; + } + `),e.sheet.insertRule(` + #${n.id}:hover { + transform: var(--${n.id}-hover-transform, scale(1.1)); + } + `);let t=document.createElement("div");if(t.style.cssText="display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; z-index: 2147483647;",t.innerHTML=p.open,n.appendChild(t),document.body.appendChild(n),n.addEventListener("click",function(){var e=document.getElementById(c);e?(e.style.display="none"===e.style.display?"block":"none",t.innerHTML="none"===e.style.display?p.open:p.close,"none"===e.style.display?document.removeEventListener("keydown",d):document.addEventListener("keydown",d),i()):(o(),i(),this.title="Exit (ESC)",t.innerHTML=p.close,document.addEventListener("keydown",d))}),h.draggable){var s=n;var l=h.dragAxis||"both";let i=!1,d,r;s.addEventListener("mousedown",function(e){i=!0,d=e.clientX-s.offsetLeft,r=e.clientY-s.offsetTop}),document.addEventListener("mousemove",function(e){var t,n,o;i&&(s.style.transition="none",s.style.cursor="grabbing",(t=document.getElementById(c))&&(t.style.display="none",s.querySelector("div").innerHTML=p.open),t=e.clientX-d,e=window.innerHeight-e.clientY-r,o=s.getBoundingClientRect(),n=window.innerWidth-o.width,o=window.innerHeight-o.height,"x"!==l&&"both"!==l||s.style.setProperty(`--${a}-left`,Math.max(0,Math.min(t,n))+"px"),"y"!==l&&"both"!==l||s.style.setProperty(`--${a}-bottom`,Math.max(0,Math.min(e,o))+"px"))}),document.addEventListener("mouseup",function(){i=!1,s.style.transition="",s.style.cursor="pointer"})}} diff --git a/web/service/apps.ts b/web/service/apps.ts --- a/web/service/apps.ts +++ b/web/service/apps.ts @@ ... @@ return post('apps', { body: { name, icon_type, icon, icon_background, mode, description, model_config: config } }) } -export const updateAppInfo: Fetcher = ({ appID, name, icon_type, icon, icon_background, description }) => { - return put(`apps/${appID}`, { body: { name, icon_type, icon, icon_background, description } }) +export const updateAppInfo: Fetcher = ({ appID, name, icon_type, icon, icon_background, description, use_icon_as_answer_icon }) => { + return put(`apps/${appID}`, { body: { name, icon_type, icon, icon_background, description, use_icon_as_answer_icon } }) } export const copyApp: Fetcher = ({ appID, name, icon_type, icon, icon_background, mode, description }) => { @@ -110,8 +110,8 @@ export const fetchAppListNoMock: Fetcher(url, params) } -export const fetchApiKeysList: Fetcher<ApikeysListResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => { - return get<ApikeysListResponse>(url, params) +export const fetchApiKeysList: Fetcher<ApiKeysListResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => { + return get<ApiKeysListResponse>(url, params) } export const delApikey: Fetcher<CommonResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => { diff --git a/web/service/base.ts b/web/service/base.ts index bda83f1c8e..83389d8be8 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -4,10 +4,12 @@ import type { AnnotationReply, MessageEnd, MessageReplace, ThoughtItem } from '@/app/components/base/chat/chat/type' import type { VisionFile } from '@/types/app' import type { IterationFinishedResponse, - IterationNextedResponse, + IterationNextResponse, IterationStartedResponse, NodeFinishedResponse, NodeStartedResponse, + ParallelBranchFinishedResponse, + ParallelBranchStartedResponse, TextChunkResponse, TextReplaceResponse, WorkflowFinishedResponse, @@ -57,8 +59,10 @@ export type IOnWorkflowFinished = (workflowFinished: WorkflowFinishedResponse) => void
export type IOnNodeStarted = (nodeStarted: NodeStartedResponse) => void export type IOnNodeFinished = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationStarted = (workflowStarted: IterationStartedResponse) => void -export type IOnIterationNexted = (workflowStarted: IterationNextedResponse) => void +export type IOnIterationNext = (workflowStarted: IterationNextResponse) => void export type IOnIterationFinished = (workflowFinished: IterationFinishedResponse) => void +export type IOnParallelBranchStarted = (parallelBranchStarted: ParallelBranchStartedResponse) => void +export type IOnParallelBranchFinished = (parallelBranchFinished: ParallelBranchFinishedResponse) => void export type IOnTextChunk = (textChunk: TextChunkResponse) => void export type IOnTTSChunk = (messageId: string, audioStr: string, audioType?: string) => void export type IOnTTSEnd = (messageId: string, audioStr: string, audioType?: string) => void @@ -84,8 +88,10 @@ export type IOtherOptions = { onNodeStarted?: IOnNodeStarted onNodeFinished?: IOnNodeFinished onIterationStart?: IOnIterationStarted - onIterationNext?: IOnIterationNexted + onIterationNext?: IOnIterationNext onIterationFinish?: IOnIterationFinished + onParallelBranchStarted?: IOnParallelBranchStarted + onParallelBranchFinished?: IOnParallelBranchFinished onTextChunk?: IOnTextChunk onTTSChunk?: IOnTTSChunk onTTSEnd?: IOnTTSEnd @@ -137,8 +143,10 @@ const handleStream = ( onNodeStarted?: IOnNodeStarted, onNodeFinished?: IOnNodeFinished, onIterationStart?: IOnIterationStarted, - onIterationNext?: IOnIterationNexted, + onIterationNext?: IOnIterationNext, onIterationFinish?: IOnIterationFinished, + onParallelBranchStarted?: IOnParallelBranchStarted, + onParallelBranchFinished?: IOnParallelBranchFinished, onTextChunk?: IOnTextChunk, onTTSChunk?: IOnTTSChunk, onTTSEnd?: IOnTTSEnd, @@ -187,7 +195,7 @@ const handleStream = ( return } if (bufferObj.event === 'message' || bufferObj.event === 'agent_message') { - // can not use format here. Because message is splited. + // cannot use format here, because the message is split.
onData(unicodeToChar(bufferObj.answer), isFirstMessage, { conversationId: bufferObj.conversation_id, taskId: bufferObj.task_id, @@ -223,11 +231,17 @@ const handleStream = ( onIterationStart?.(bufferObj as IterationStartedResponse) } else if (bufferObj.event === 'iteration_next') { - onIterationNext?.(bufferObj as IterationNextedResponse) + onIterationNext?.(bufferObj as IterationNextResponse) } else if (bufferObj.event === 'iteration_completed') { onIterationFinish?.(bufferObj as IterationFinishedResponse) } + else if (bufferObj.event === 'parallel_branch_started') { + onParallelBranchStarted?.(bufferObj as ParallelBranchStartedResponse) + } + else if (bufferObj.event === 'parallel_branch_finished') { + onParallelBranchFinished?.(bufferObj as ParallelBranchFinishedResponse) + } else if (bufferObj.event === 'text_chunk') { onTextChunk?.(bufferObj as TextChunkResponse) } @@ -488,6 +502,8 @@ export const ssePost = ( onIterationStart, onIterationNext, onIterationFinish, + onParallelBranchStarted, + onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, @@ -544,7 +560,7 @@ export const ssePost = ( return } onData?.(str, isFirstMessage, moreInfo) - }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) + }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) }).catch((e) => { if (e.toString() !== 'AbortError: The user aborted a request.' 
&& !e.toString().errorMessage.includes('TypeError: Cannot assign to read only property')) Toast.notify({ type: 'error', message: e }) diff --git a/web/service/datasets.ts b/web/service/datasets.ts index c861a73c37..4ca269a7d6 100644 --- a/web/service/datasets.ts +++ b/web/service/datasets.ts @@ -18,14 +18,14 @@ import type { ProcessRuleResponse, RelatedAppResponse, SegmentDetailModel, - SegmentUpdator, + SegmentUpdater, SegmentsQuery, SegmentsResponse, createDocumentResponse, } from '@/models/datasets' import type { CommonResponse, DataSourceNotionWorkspace } from '@/models/common' import type { - ApikeysListResponse, + ApiKeysListResponse, CreateApiKeyResponse, } from '@/models/app' import type { RetrievalConfig } from '@/types/app' @@ -184,11 +184,11 @@ export const disableSegment: Fetcher<CommonResponse, { datasetId: string; segmentId: string }> = ({ datasetId, segmentId }) => { return patch<CommonResponse>(`/datasets/${datasetId}/segments/${segmentId}/disable`) } -export const updateSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, { datasetId: string; documentId: string; segmentId: string; body: SegmentUpdator }> = ({ datasetId, documentId, segmentId, body }) => { +export const updateSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, { datasetId: string; documentId: string; segmentId: string; body: SegmentUpdater }> = ({ datasetId, documentId, segmentId, body }) => { return patch<{ data: SegmentDetailModel; doc_form: string }>(`/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`, { body }) } -export const addSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, { datasetId: string; documentId: string; body: SegmentUpdator }> = ({ datasetId, documentId, body }) => { +export const addSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, { datasetId: string; documentId: string; body: SegmentUpdater }> = ({ datasetId, documentId, body }) => { return post<{ data: SegmentDetailModel; doc_form: string }>(`/datasets/${datasetId}/documents/${documentId}/segment`, { body }) } @@ -221,8 +221,8 @@ export const fetchNotionPagePreview: Fetcher<{ content: string }, { workspaceID: string; pageID: string; pageType: string }> = ({ workspaceID, pageID, pageType }) => { return get<{ content: string }>(`notion/workspaces/${workspaceID}/pages/${pageID}/${pageType}/preview`) } -export const fetchApiKeysList: Fetcher<ApikeysListResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => { - return get<ApikeysListResponse>(url, params) +export const fetchApiKeysList: Fetcher<ApiKeysListResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => { + return get<ApiKeysListResponse>(url, params) } export const delApikey: Fetcher<CommonResponse, { url: string; params: Record<string, any> }> = ({ url, params }) => { diff --git a/web/service/debug.ts b/web/service/debug.ts index 8e90fe565f..38068cad6e 100644 --- a/web/service/debug.ts +++ b/web/service/debug.ts @@ -96,7 +96,7 @@ export const fetchPromptTemplate = ({ }) } -export const fetchTextGenerationMessge = ({ +export const fetchTextGenerationMessage = ({ appId, messageId, }: { appId: string; messageId: string }) => { diff --git a/web/service/share.ts b/web/service/share.ts index f5a695f6c3..0e46e30d01 100644 --- a/web/service/share.ts +++ b/web/service/share.ts @@ -1,9 +1,9 @@ -import type { IOnCompleted, IOnData, IOnError, IOnFile, IOnIterationFinished, IOnIterationNexted, IOnIterationStarted, IOnMessageEnd, IOnMessageReplace, IOnNodeFinished, IOnNodeStarted, IOnTTSChunk, IOnTTSEnd, IOnTextChunk, IOnTextReplace, IOnThought, IOnWorkflowFinished, IOnWorkflowStarted } from './base' +import type { IOnCompleted, IOnData, IOnError, IOnFile, IOnIterationFinished, IOnIterationNext, IOnIterationStarted, IOnMessageEnd, IOnMessageReplace, IOnNodeFinished, IOnNodeStarted, IOnTTSChunk, IOnTTSEnd, IOnTextChunk, IOnTextReplace, IOnThought, IOnWorkflowFinished, IOnWorkflowStarted } from './base' 
import { del as consoleDel, get as consoleGet, patch as consolePatch, post as consolePost, delPublic as del, getPublic as get, patchPublic as patch, postPublic as post, ssePost, } from './base' -import type { Feedbacktype } from '@/app/components/base/chat/chat/type' +import type { FeedbackType } from '@/app/components/base/chat/chat/type' import type { AppConversationData, AppData, @@ -86,7 +86,7 @@ export const sendWorkflowMessage = async ( onNodeFinished: IOnNodeFinished onWorkflowFinished: IOnWorkflowFinished onIterationStart: IOnIterationStarted - onIterationNext: IOnIterationNexted + onIterationNext: IOnIterationNext onIterationFinish: IOnIterationFinished onTextChunk: IOnTextChunk onTextReplace: IOnTextReplace @@ -180,7 +180,7 @@ export const fetchAppMeta = async (isInstalledApp: boolean, installedAppId = '') return (getAction('get', isInstalledApp))(getUrl('meta', isInstalledApp, installedAppId)) as Promise } -export const updateFeedback = async ({ url, body }: { url: string; body: Feedbacktype }, isInstalledApp: boolean, installedAppId = '') => { +export const updateFeedback = async ({ url, body }: { url: string; body: FeedbackType }, isInstalledApp: boolean, installedAppId = '') => { return (getAction('post', isInstalledApp))(getUrl(url, isInstalledApp, installedAppId), { body }) } diff --git a/web/service/workflow.ts b/web/service/workflow.ts index 93ab0006d4..431beef96c 100644 --- a/web/service/workflow.ts +++ b/web/service/workflow.ts @@ -26,7 +26,7 @@ export const fetchWorkflowRunHistory: Fetcher(url) } -export const fetcChatRunHistory: Fetcher = (url) => { +export const fetchChatRunHistory: Fetcher = (url) => { return get(url) } diff --git a/web/themes/dark.css b/web/themes/dark.css index b94124aad2..8d77329b5a 100644 --- a/web/themes/dark.css +++ b/web/themes/dark.css @@ -147,13 +147,13 @@ html[data-theme="dark"] { --color-components-main-nav-nav-user-border: #FFFFFF0D; - --color-components-silder-knob: #F4F4F5; - --color-components-silder-knob-hover: #FEFEFE; - --color-components-silder-knob-disabled: #FFFFFF33; - --color-components-silder-range: #296DFF; - --color-components-silder-track: #FFFFFF33; - --color-components-silder-knob-border-hover: #1018284D; - --color-components-silder-knob-border: #10182833; + --color-components-slider-knob: #F4F4F5; + --color-components-slider-knob-hover: #FEFEFE; + --color-components-slider-knob-disabled: #FFFFFF33; + --color-components-slider-range: #296DFF; + --color-components-slider-track: #FFFFFF33; + --color-components-slider-knob-border-hover: #1018284D; + --color-components-slider-knob-border: #10182833; --color-components-segmented-control-item-active-bg: #FFFFFF14; --color-components-segmented-control-item-active-border: #C8CEDA14; @@ -268,7 +268,7 @@ html[data-theme="dark"] { --color-background-body: #1D1D20; --color-background-default-subtle: #222225; - --color-background-neurtral-subtle: #1D1D20; + --color-background-neutral-subtle: #1D1D20; --color-background-sidenav-bg: #27272AEB; --color-background-default: #222225; --color-background-soft: #18181B40; @@ -316,6 +316,7 @@ html[data-theme="dark"] { --color-workflow-block-border: #FFFFFF14; --color-workflow-block-parma-bg: #FFFFFF0D; --color-workflow-block-bg: #27272B; + --color-workflow-block-border-highlight: #C8CEDA33; --color-workflow-canvas-workflow-dot-color: #8585AD26; --color-workflow-canvas-workflow-bg: #1D1D20; @@ -324,8 +325,8 @@ html[data-theme="dark"] { --color-workflow-link-line-normal: #676F83; --color-workflow-link-line-handle: #296DFF; - 
diff --git a/web/themes/dark.css b/web/themes/dark.css
index b94124aad2..8d77329b5a 100644
--- a/web/themes/dark.css
+++ b/web/themes/dark.css
@@ -147,13 +147,13 @@ html[data-theme="dark"] {
 
   --color-components-main-nav-nav-user-border: #FFFFFF0D;
 
-  --color-components-silder-knob: #F4F4F5;
-  --color-components-silder-knob-hover: #FEFEFE;
-  --color-components-silder-knob-disabled: #FFFFFF33;
-  --color-components-silder-range: #296DFF;
-  --color-components-silder-track: #FFFFFF33;
-  --color-components-silder-knob-border-hover: #1018284D;
-  --color-components-silder-knob-border: #10182833;
+  --color-components-slider-knob: #F4F4F5;
+  --color-components-slider-knob-hover: #FEFEFE;
+  --color-components-slider-knob-disabled: #FFFFFF33;
+  --color-components-slider-range: #296DFF;
+  --color-components-slider-track: #FFFFFF33;
+  --color-components-slider-knob-border-hover: #1018284D;
+  --color-components-slider-knob-border: #10182833;
 
   --color-components-segmented-control-item-active-bg: #FFFFFF14;
   --color-components-segmented-control-item-active-border: #C8CEDA14;
@@ -268,7 +268,7 @@ html[data-theme="dark"] {
   --color-background-body: #1D1D20;
   --color-background-default-subtle: #222225;
-  --color-background-neurtral-subtle: #1D1D20;
+  --color-background-neutral-subtle: #1D1D20;
   --color-background-sidenav-bg: #27272AEB;
   --color-background-default: #222225;
   --color-background-soft: #18181B40;
@@ -316,6 +316,7 @@ html[data-theme="dark"] {
   --color-workflow-block-border: #FFFFFF14;
   --color-workflow-block-parma-bg: #FFFFFF0D;
   --color-workflow-block-bg: #27272B;
+  --color-workflow-block-border-highlight: #C8CEDA33;
 
   --color-workflow-canvas-workflow-dot-color: #8585AD26;
   --color-workflow-canvas-workflow-bg: #1D1D20;
@@ -324,8 +325,8 @@ html[data-theme="dark"] {
   --color-workflow-link-line-normal: #676F83;
   --color-workflow-link-line-handle: #296DFF;
 
-  --color-workflow-minmap-bg: #27272B;
-  --color-workflow-minmap-block: #C8CEDA14;
+  --color-workflow-minimap-bg: #27272B;
+  --color-workflow-minimap-block: #C8CEDA14;
 
   --color-workflow-display-success-bg: #17B26A33;
   --color-workflow-display-success-border-1: #17B26AE5;
@@ -371,8 +372,8 @@ html[data-theme="dark"] {
   --color-divider-deep: #C8CEDA33;
   --color-divider-burn: #18181BF2;
   --color-divider-intense: #C8CEDA66;
-  --color-divider-soild: #3A3A40;
-  --color-divider-soild-alt: #747481;
+  --color-divider-solid: #3A3A40;
+  --color-divider-solid-alt: #747481;
 
   --color-state-base-hover: #C8CEDA14;
   --color-state-base-active: #C8CEDA33;
@@ -383,24 +384,24 @@ html[data-theme="dark"] {
   --color-state-accent-hover: #155AEF24;
   --color-state-accent-active: #155AEF24;
   --color-state-accent-hover-alt: #155AEF40;
-  --color-state-accent-soild: #5289FF;
+  --color-state-accent-solid: #5289FF;
   --color-state-accent-active-alt: #155AEF33;
 
   --color-state-destructive-hover: #F0443824;
   --color-state-destructive-hover-alt: #F0443840;
   --color-state-destructive-active: #F044384D;
-  --color-state-destructive-soild: #F97066;
+  --color-state-destructive-solid: #F97066;
   --color-state-destructive-border: #F97066;
 
   --color-state-success-hover: #17B26A24;
   --color-state-success-hover-alt: #17B26A40;
   --color-state-success-active: #17B26A4D;
-  --color-state-success-soild: #47CD89;
+  --color-state-success-solid: #47CD89;
 
   --color-state-warning-hover: #F7900924;
   --color-state-warning-hover-alt: #F7900940;
   --color-state-warning-active: #F790094D;
-  --color-state-warning-soild: #F79009;
+  --color-state-warning-solid: #F79009;
 
   --color-effects-highlight: #C8CEDA14;
   --color-effects-highlight-lightmode-off: #C8CEDA14;
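Renames like silder -> slider and soild -> solid are sweeping but purely mechanical, and they must stay in sync across dark.css, light.css, and the Tailwind alias map below. A change like this is usually scripted rather than hand-edited; a minimal Node/TypeScript sketch of such a codemod — the file list is illustrative and this is not necessarily how the authors produced the diff:

import { readFileSync, writeFileSync } from 'node:fs'

// Typo -> fix pairs taken from the theme hunks in this patch.
const renames: [RegExp, string][] = [
  [/silder/g, 'slider'],
  [/soild/g, 'solid'],
  [/neurtral/g, 'neutral'],
  [/minmap/g, 'minimap'],
]

// Illustrative file list; a real run would glob web/themes/** and web/**/*.ts*.
for (const file of ['web/themes/dark.css', 'web/themes/light.css', 'web/themes/tailwind-theme-var-define.ts']) {
  let text = readFileSync(file, 'utf8')
  for (const [from, to] of renames)
    text = text.replace(from, to)
  writeFileSync(file, text)
}
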
diff --git a/web/themes/light.css b/web/themes/light.css
index 80a0fa36f5..89303c250e 100644
--- a/web/themes/light.css
+++ b/web/themes/light.css
@@ -147,13 +147,13 @@ html[data-theme="light"] {
 
   --color-components-main-nav-nav-user-border: #FFFFFF;
 
-  --color-components-silder-knob: #FFFFFF;
-  --color-components-silder-knob-hover: #FFFFFF;
-  --color-components-silder-knob-disabled: #FFFFFFF2;
-  --color-components-silder-range: #296DFF;
-  --color-components-silder-track: #E9EBF0;
-  --color-components-silder-knob-border-hover: #10182833;
-  --color-components-silder-knob-border: #10182824;
+  --color-components-slider-knob: #FFFFFF;
+  --color-components-slider-knob-hover: #FFFFFF;
+  --color-components-slider-knob-disabled: #FFFFFFF2;
+  --color-components-slider-range: #296DFF;
+  --color-components-slider-track: #E9EBF0;
+  --color-components-slider-knob-border-hover: #10182833;
+  --color-components-slider-knob-border: #10182824;
 
   --color-components-segmented-control-item-active-bg: #FFFFFF;
   --color-components-segmented-control-item-active-border: #FFFFFF;
@@ -268,7 +268,7 @@ html[data-theme="light"] {
   --color-background-body: #F2F4F7;
   --color-background-default-subtle: #FCFCFD;
-  --color-background-neurtral-subtle: #F9FAFB;
+  --color-background-neutral-subtle: #F9FAFB;
   --color-background-sidenav-bg: #FFFFFFCC;
   --color-background-default: #FFFFFF;
   --color-background-soft: #F9FAFB;
@@ -316,6 +316,7 @@ html[data-theme="light"] {
   --color-workflow-block-border: #FFFFFF;
   --color-workflow-block-parma-bg: #F2F4F7;
   --color-workflow-block-bg: #FCFCFD;
+  --color-workflow-block-border-highlight: #155AEF24;
 
   --color-workflow-canvas-workflow-dot-color: #8585AD26;
   --color-workflow-canvas-workflow-bg: #F2F4F7;
@@ -324,8 +325,8 @@ html[data-theme="light"] {
   --color-workflow-link-line-normal: #D0D5DC;
   --color-workflow-link-line-handle: #296DFF;
 
-  --color-workflow-minmap-bg: #E9EBF0;
-  --color-workflow-minmap-block: #C8CEDA4D;
+  --color-workflow-minimap-bg: #E9EBF0;
+  --color-workflow-minimap-block: #C8CEDA4D;
 
   --color-workflow-display-success-bg: #ECFDF3;
   --color-workflow-display-success-border-1: #17B26ACC;
@@ -371,8 +372,8 @@ html[data-theme="light"] {
   --color-divider-deep: #10182824;
   --color-divider-burn: #1018280A;
   --color-divider-intense: #1018284D;
-  --color-divider-soild: #D0D5DC;
-  --color-divider-soild-alt: #98A2B2;
+  --color-divider-solid: #D0D5DC;
+  --color-divider-solid-alt: #98A2B2;
 
   --color-state-base-hover: #C8CEDA33;
   --color-state-base-active: #C8CEDA66;
@@ -383,24 +384,24 @@ html[data-theme="light"] {
   --color-state-accent-hover: #EFF4FF;
   --color-state-accent-active: #155AEF14;
   --color-state-accent-hover-alt: #D1E0FF;
-  --color-state-accent-soild: #296DFF;
+  --color-state-accent-solid: #296DFF;
   --color-state-accent-active-alt: #155AEF24;
 
   --color-state-destructive-hover: #FEF3F2;
   --color-state-destructive-hover-alt: #FEE4E2;
   --color-state-destructive-active: #FECDCA;
-  --color-state-destructive-soild: #F04438;
+  --color-state-destructive-solid: #F04438;
   --color-state-destructive-border: #FDA29B;
 
   --color-state-success-hover: #ECFDF3;
   --color-state-success-hover-alt: #DCFAE6;
   --color-state-success-active: #ABEFC6;
-  --color-state-success-soild: #17B26A;
+  --color-state-success-solid: #17B26A;
 
   --color-state-warning-hover: #FFFAEB;
   --color-state-warning-hover-alt: #FEF0C7;
   --color-state-warning-active: #FEDF89;
-  --color-state-warning-soild: #F79009;
+  --color-state-warning-solid: #F79009;
 
   --color-effects-highlight: #FFFFFF;
   --color-effects-highlight-lightmode-off: #FFFFFF00;
diff --git a/web/themes/tailwind-theme-var-define.ts b/web/themes/tailwind-theme-var-define.ts
index caeb01b5fa..643c96d1a1 100644
--- a/web/themes/tailwind-theme-var-define.ts
+++ b/web/themes/tailwind-theme-var-define.ts
@@ -147,13 +147,13 @@ const vars = {
 
   'components-main-nav-nav-user-border': 'var(--color-components-main-nav-nav-user-border)',
 
-  'components-silder-knob': 'var(--color-components-silder-knob)',
-  'components-silder-knob-hover': 'var(--color-components-silder-knob-hover)',
-  'components-silder-knob-disabled': 'var(--color-components-silder-knob-disabled)',
-  'components-silder-range': 'var(--color-components-silder-range)',
-  'components-silder-track': 'var(--color-components-silder-track)',
-  'components-silder-knob-border-hover': 'var(--color-components-silder-knob-border-hover)',
-  'components-silder-knob-border': 'var(--color-components-silder-knob-border)',
+  'components-slider-knob': 'var(--color-components-slider-knob)',
+  'components-slider-knob-hover': 'var(--color-components-slider-knob-hover)',
+  'components-slider-knob-disabled': 'var(--color-components-slider-knob-disabled)',
+  'components-slider-range': 'var(--color-components-slider-range)',
+  'components-slider-track': 'var(--color-components-slider-track)',
+  'components-slider-knob-border-hover': 'var(--color-components-slider-knob-border-hover)',
+  'components-slider-knob-border': 'var(--color-components-slider-knob-border)',
 
   'components-segmented-control-item-active-bg': 'var(--color-components-segmented-control-item-active-bg)',
   'components-segmented-control-item-active-border': 'var(--color-components-segmented-control-item-active-border)',
@@ -268,7 +268,7 @@ const vars = {
   'background-body': 'var(--color-background-body)',
   'background-default-subtle': 'var(--color-background-default-subtle)',
-  'background-neurtral-subtle': 'var(--color-background-neurtral-subtle)',
+  'background-neutral-subtle': 'var(--color-background-neutral-subtle)',
   'background-sidenav-bg': 'var(--color-background-sidenav-bg)',
   'background-default': 'var(--color-background-default)',
   'background-soft': 'var(--color-background-soft)',
@@ -316,6 +316,7 @@ const vars = {
   'workflow-block-border': 'var(--color-workflow-block-border)',
   'workflow-block-parma-bg': 'var(--color-workflow-block-parma-bg)',
   'workflow-block-bg': 'var(--color-workflow-block-bg)',
+  'workflow-block-border-highlight': 'var(--color-workflow-block-border-highlight)',
 
   'workflow-canvas-workflow-dot-color': 'var(--color-workflow-canvas-workflow-dot-color)',
   'workflow-canvas-workflow-bg': 'var(--color-workflow-canvas-workflow-bg)',
@@ -324,8 +325,8 @@ const vars = {
   'workflow-link-line-normal': 'var(--color-workflow-link-line-normal)',
   'workflow-link-line-handle': 'var(--color-workflow-link-line-handle)',
 
-  'workflow-minmap-bg': 'var(--color-workflow-minmap-bg)',
-  'workflow-minmap-block': 'var(--color-workflow-minmap-block)',
+  'workflow-minimap-bg': 'var(--color-workflow-minimap-bg)',
+  'workflow-minimap-block': 'var(--color-workflow-minimap-block)',
 
   'workflow-display-success-bg': 'var(--color-workflow-display-success-bg)',
   'workflow-display-success-border-1': 'var(--color-workflow-display-success-border-1)',
@@ -371,8 +372,8 @@ const vars = {
   'divider-deep': 'var(--color-divider-deep)',
   'divider-burn': 'var(--color-divider-burn)',
   'divider-intense': 'var(--color-divider-intense)',
-  'divider-soild': 'var(--color-divider-soild)',
-  'divider-soild-alt': 'var(--color-divider-soild-alt)',
+  'divider-solid': 'var(--color-divider-solid)',
+  'divider-solid-alt': 'var(--color-divider-solid-alt)',
 
   'state-base-hover': 'var(--color-state-base-hover)',
   'state-base-active': 'var(--color-state-base-active)',
@@ -383,24 +384,24 @@ const vars = {
   'state-accent-hover': 'var(--color-state-accent-hover)',
   'state-accent-active': 'var(--color-state-accent-active)',
   'state-accent-hover-alt': 'var(--color-state-accent-hover-alt)',
-  'state-accent-soild': 'var(--color-state-accent-soild)',
+  'state-accent-solid': 'var(--color-state-accent-solid)',
   'state-accent-active-alt': 'var(--color-state-accent-active-alt)',
 
   'state-destructive-hover': 'var(--color-state-destructive-hover)',
   'state-destructive-hover-alt': 'var(--color-state-destructive-hover-alt)',
   'state-destructive-active': 'var(--color-state-destructive-active)',
-  'state-destructive-soild': 'var(--color-state-destructive-soild)',
+  'state-destructive-solid': 'var(--color-state-destructive-solid)',
   'state-destructive-border': 'var(--color-state-destructive-border)',
 
   'state-success-hover': 'var(--color-state-success-hover)',
   'state-success-hover-alt': 'var(--color-state-success-hover-alt)',
   'state-success-active': 'var(--color-state-success-active)',
-  'state-success-soild': 'var(--color-state-success-soild)',
+  'state-success-solid': 'var(--color-state-success-solid)',
 
   'state-warning-hover': 'var(--color-state-warning-hover)',
   'state-warning-hover-alt': 'var(--color-state-warning-hover-alt)',
   'state-warning-active': 'var(--color-state-warning-active)',
-  'state-warning-soild': 'var(--color-state-warning-soild)',
+  'state-warning-solid': 'var(--color-state-warning-solid)',
 
   'effects-highlight': 'var(--color-effects-highlight)',
   'effects-highlight-lightmode-off': 'var(--color-effects-highlight-lightmode-off)',
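tailwind-theme-var-define.ts is the bridge between the CSS custom properties above and Tailwind utility classes, which is why every rename appears three times in this patch. A sketch of the usual wiring — the repo's actual tailwind.config is not part of this diff, so treat the config shape as an assumption:

// The alias map becomes Tailwind color tokens, so the rename turns
// bg-components-silder-track into bg-components-slider-track at call sites.
import vars from './themes/tailwind-theme-var-define'

const config = {
  theme: {
    extend: {
      colors: vars, // e.g. 'components-slider-track' -> 'var(--color-components-slider-track)'
    },
  },
}

export default config
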
diff --git a/web/types/app.ts b/web/types/app.ts
index fb8a407dd2..cb05bc3878 100644
--- a/web/types/app.ts
+++ b/web/types/app.ts
@@ -297,6 +297,7 @@ export type SiteConfig = {
   icon_url: string | null
 
   show_workflow_steps: boolean
+  use_icon_as_answer_icon: boolean
 }
 
 export type AppIconType = 'image' | 'emoji'
@@ -323,6 +324,8 @@ export type App = {
   icon_background: string | null
   /** Icon URL, only available when icon_type is 'image' */
   icon_url: string | null
+  /** Whether to use app icon as answer icon */
+  use_icon_as_answer_icon: boolean
   /** Mode */
   mode: AppMode
diff --git a/web/types/workflow.ts b/web/types/workflow.ts
index f7991bc4e0..dbf2b3e587 100644
--- a/web/types/workflow.ts
+++ b/web/types/workflow.ts
@@ -26,9 +26,14 @@ export type NodeTracing = {
     currency: string
     iteration_id?: string
     iteration_index?: number
+    parallel_id?: string
+    parallel_start_node_id?: string
+    parent_parallel_id?: string
+    parent_parallel_start_node_id?: string
   }
   metadata: {
     iterator_length: number
+    iterator_index: number
   }
   created_at: number
   created_by: {
@@ -40,6 +45,10 @@ export type NodeTracing = {
   extras?: any
   expand?: boolean // for UI
   details?: NodeTracing[][] // iteration detail
+  parallel_id?: string
+  parallel_start_node_id?: string
+  parent_parallel_id?: string
+  parent_parallel_start_node_id?: string
 }
 
 export type FetchWorkflowDraftResponse = {
@@ -109,6 +118,7 @@ export type NodeStartedResponse = {
   data: {
     id: string
     node_id: string
+    iteration_id?: string
     node_type: string
     index: number
     predecessor_node_id?: string
@@ -125,6 +135,7 @@ export type NodeFinishedResponse = {
   data: {
     id: string
     node_id: string
+    iteration_id?: string
     node_type: string
     index: number
     predecessor_node_id?: string
@@ -138,6 +149,10 @@ export type NodeFinishedResponse = {
       total_tokens: number
       total_price: number
       currency: string
+      parallel_id?: string
+      parallel_start_node_id?: string
+      iteration_index?: number
+      iteration_id?: string
     }
     created_at: number
   }
@@ -152,13 +167,15 @@ export type IterationStartedResponse = {
     node_id: string
     metadata: {
      iterator_length: number
+      iteration_id: string
+      iteration_index: number
     }
     created_at: number
     extras?: any
   }
 }
 
-export type IterationNextedResponse = {
+export type IterationNextResponse = {
   task_id: string
   workflow_run_id: string
   event: string
@@ -169,6 +186,9 @@ export type IterationNextResponse = {
   data: {
     output: any
     extras?: any
     created_at: number
+    execution_metadata: {
+      parallel_id?: string
+    }
   }
 }
@@ -184,6 +204,39 @@ export type IterationFinishedResponse = {
     status: string
     created_at: number
     error: string
+    execution_metadata: {
+      parallel_id?: string
+    }
+  }
+}
+
+export type ParallelBranchStartedResponse = {
+  task_id: string
+  workflow_run_id: string
+  event: string
+  data: {
+    parallel_id: string
+    parallel_start_node_id: string
+    parent_parallel_id: string
+    parent_parallel_start_node_id: string
+    iteration_id?: string
+    created_at: number
+  }
+}
+
+export type ParallelBranchFinishedResponse = {
+  task_id: string
+  workflow_run_id: string
+  event: string
+  data: {
+    parallel_id: string
+    parallel_start_node_id: string
+    parent_parallel_id: string
+    parent_parallel_start_node_id: string
+    iteration_id?: string
+    status: string
+    created_at: number
+    error: string
   }
 }
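The two new ParallelBranch* payloads mirror the iteration events, and the parallel_id / parallel_start_node_id pair they carry is the same one threaded through NodeTracing and the node events above, so the UI can attribute each traced node to a branch. A minimal sketch of a consumer that tracks running branches — the status value 'failed' is an assumption, since the diff does not enumerate statuses:

import type {
  ParallelBranchFinishedResponse,
  ParallelBranchStartedResponse,
} from '@/types/workflow'

// Track which branches of a parallel split are still running, keyed by the
// id pair that every branch event carries.
const running = new Map<string, ParallelBranchStartedResponse['data']>()

function onParallelBranchStarted({ data }: ParallelBranchStartedResponse) {
  running.set(`${data.parallel_id}:${data.parallel_start_node_id}`, data)
}

function onParallelBranchFinished({ data }: ParallelBranchFinishedResponse) {
  running.delete(`${data.parallel_id}:${data.parallel_start_node_id}`)
  if (data.status === 'failed') // hypothetical status value
    console.warn(`branch ${data.parallel_start_node_id} failed: ${data.error}`)
}
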
diff --git a/web/utils/var.ts b/web/utils/var.ts
index 436ebfc70b..236c9debac 100644
--- a/web/utils/var.ts
+++ b/web/utils/var.ts
@@ -1,4 +1,4 @@
-import { MAX_VAR_KEY_LENGHT, VAR_ITEM_TEMPLATE, VAR_ITEM_TEMPLATE_IN_WORKFLOW, getMaxVarNameLength } from '@/config'
+import { MAX_VAR_KEY_LENGTH, VAR_ITEM_TEMPLATE, VAR_ITEM_TEMPLATE_IN_WORKFLOW, getMaxVarNameLength } from '@/config'
 import { CONTEXT_PLACEHOLDER_TEXT, HISTORY_PLACEHOLDER_TEXT, PRE_PROMPT_PLACEHOLDER_TEXT, QUERY_PLACEHOLDER_TEXT } from '@/app/components/base/prompt-editor/constants'
 import { InputVarType } from '@/app/components/workflow/types'
@@ -47,7 +47,7 @@ export const checkKey = (key: string, canBeEmpty?: boolean) => {
   if (canBeEmpty && key === '')
     return true
 
-  if (key.length > MAX_VAR_KEY_LENGHT)
+  if (key.length > MAX_VAR_KEY_LENGTH)
     return 'tooLong'
 
   if (otherAllowedRegex.test(key)) {
@@ -86,7 +86,7 @@ export const getVars = (value: string) => {
     return ![CONTEXT_PLACEHOLDER_TEXT, HISTORY_PLACEHOLDER_TEXT, QUERY_PLACEHOLDER_TEXT, PRE_PROMPT_PLACEHOLDER_TEXT].includes(item)
   }).map((item) => {
     return item.replace('{{', '').replace('}}', '')
-  }).filter(key => key.length <= MAX_VAR_KEY_LENGHT) || []
+  }).filter(key => key.length <= MAX_VAR_KEY_LENGTH) || []
   const keyObj: Record<string, boolean> = {} // remove duplicate keys
   const res: string[] = []
diff --git a/web/yarn.lock b/web/yarn.lock
index d50aa33f3e..3c020c9664 100644
--- a/web/yarn.lock
+++ b/web/yarn.lock
@@ -260,6 +260,13 @@
   resolved "https://registry.npmjs.org/@babel/parser/-/parser-7.24.4.tgz"
   integrity sha512-zTvEBcghmeBma9QIGunWevvBAp4/Qu9Bdq+2k0Ot4fVMD6v3dsC9WOcRSKk7tRRyBM/53yKMJko9xOatGQAwSg==
 
+"@babel/parser@^7.25.4":
+  version "7.25.6"
+  resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.25.6.tgz#85660c5ef388cbbf6e3d2a694ee97a38f18afe2f"
+  integrity sha512-trGdfBdbD0l1ZPmcJ83eNxB9rbEax4ALFTF7fN386TMYbeCQbyme5cOEXQhbGXKebwGaB/J52w1mrklMcbgy6Q==
+  dependencies:
+    "@babel/types" "^7.25.6"
+
 "@babel/plugin-syntax-async-generators@^7.8.4":
   version "7.8.4"
   resolved "https://registry.npmjs.org/@babel/plugin-syntax-async-generators/-/plugin-syntax-async-generators-7.8.4.tgz#a983fb1aeb2ec3f6ed042a210f640e90e786fe0d"
@@ -406,6 +413,15 @@
     "@babel/helper-validator-identifier" "^7.24.7"
     to-fast-properties "^2.0.0"
 
+"@babel/types@^7.25.4", "@babel/types@^7.25.6":
+  version "7.25.6"
+  resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.25.6.tgz#893942ddb858f32ae7a004ec9d3a76b3463ef8e6"
+  integrity sha512-/l42B1qxpG6RdfYf343Uw1vmDjeNhneUXtzhojE7pDgfpEypmRhI6j1kr17XCVv4Cgl9HdAiQY2x0GwKm7rWCw==
+  dependencies:
+    "@babel/helper-string-parser" "^7.24.8"
+    "@babel/helper-validator-identifier" "^7.24.7"
+    to-fast-properties "^2.0.0"
+
 "@bcoe/v8-coverage@^0.2.3":
   version "0.2.3"
   resolved "https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39"
@@ -1454,6 +1470,11 @@
   resolved "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz#6667fac16c436b5434a387a34dedb013198f6e6e"
   integrity sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==
 
+"@sindresorhus/is@^4.0.0":
+  version "4.6.0"
+  resolved "https://registry.yarnpkg.com/@sindresorhus/is/-/is-4.6.0.tgz#3c7c9c46e678feefe7a2e5bb609d3dbd665ffb3f"
+  integrity sha512-t09vSN3MdfsyCHoFcTRCH/iUtG7OJ0CsjzB8cjAmKc/va/kIgeDI/TxsigdncE/4be734m0cvIYwNaV4i2XqAw==
+
 "@sinonjs/commons@^3.0.0":
   version "3.0.1"
   resolved "https://registry.npmjs.org/@sinonjs/commons/-/commons-3.0.1.tgz#1029357e44ca901a615585f6d27738dbc89084cd"
@@ -1481,6 +1502,13 @@
     "@swc/counter" "^0.1.3"
     tslib "^2.4.0"
 
+"@szmarczak/http-timer@^4.0.5":
+  version "4.0.6"
+  resolved "https://registry.yarnpkg.com/@szmarczak/http-timer/-/http-timer-4.0.6.tgz#b4a914bb62e7c272d4e5989fe4440f812ab1d807"
+  integrity sha512-4BAffykYOgO+5nzBWYwE3W90sBgLJoUPRWWcL8wlyiM8IB8ipJz3UMJ9KXQd1RKQXpKp8Tutn80HZtWsu2u76w==
+  dependencies:
+    defer-to-connect "^2.0.0"
+
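Before the lockfile continues: the web/utils/var.ts hunk above fixes the long-standing MAX_VAR_KEY_LENGHT misspelling in both the length guard and the variable extraction filter. A small sketch of the observable contract, using only what the hunk shows — return codes other than true and 'tooLong' exist in the full file but are not visible in this diff:

import { checkKey } from '@/utils/var'

console.log(checkKey('', true)) // true: empty is explicitly allowed
console.log(checkKey('x'.repeat(1000))) // 'tooLong' once the key exceeds MAX_VAR_KEY_LENGTH
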
"@tailwindcss/line-clamp@^0.4.4": version "0.4.4" resolved "https://registry.npmjs.org/@tailwindcss/line-clamp/-/line-clamp-0.4.4.tgz" @@ -1601,6 +1629,16 @@ dependencies: "@babel/types" "^7.20.7" +"@types/cacheable-request@^6.0.1": + version "6.0.3" + resolved "https://registry.yarnpkg.com/@types/cacheable-request/-/cacheable-request-6.0.3.tgz#a430b3260466ca7b5ca5bfd735693b36e7a9d183" + integrity sha512-IQ3EbTzGxIigb1I3qPZc1rWJnH0BmSKv5QYTalEwweFvyBDLSAe24zP0le/hyi7ecGfZVlIVAg4BZqb8WBwKqw== + dependencies: + "@types/http-cache-semantics" "*" + "@types/keyv" "^3.1.4" + "@types/node" "*" + "@types/responselike" "^1.0.0" + "@types/crypto-js@^4.1.1": version "4.1.1" resolved "https://registry.npmjs.org/@types/crypto-js/-/crypto-js-4.1.1.tgz" @@ -1859,6 +1897,18 @@ dependencies: "@types/unist" "*" +"@types/hast@^3.0.0": + version "3.0.4" + resolved "https://registry.yarnpkg.com/@types/hast/-/hast-3.0.4.tgz#1d6b39993b82cea6ad783945b0508c25903e15aa" + integrity sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ== + dependencies: + "@types/unist" "*" + +"@types/http-cache-semantics@*": + version "4.0.4" + resolved "https://registry.yarnpkg.com/@types/http-cache-semantics/-/http-cache-semantics-4.0.4.tgz#b979ebad3919799c979b17c72621c0bc0a31c6c4" + integrity sha512-1m0bIFVc7eJWyve9S0RnuRgcQqF/Xd5QsUZAZeQFr1Q3/p9JWoQQEqmVy+DPTNpGXwhgIetAoYF8JSc33q29QA== + "@types/istanbul-lib-coverage@*", "@types/istanbul-lib-coverage@^2.0.0", "@types/istanbul-lib-coverage@^2.0.1": version "2.0.6" resolved "https://registry.npmjs.org/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.6.tgz#7739c232a1fee9b4d3ce8985f314c0c6d33549d7" @@ -1925,6 +1975,13 @@ resolved "https://registry.npmjs.org/@types/katex/-/katex-0.16.0.tgz" integrity sha512-hz+S3nV6Mym5xPbT9fnO8dDhBFQguMYpY0Ipxv06JMi1ORgnEM4M1ymWDUhUNer3ElLmT583opRo4RzxKmh9jw== +"@types/keyv@^3.1.4": + version "3.1.4" + resolved "https://registry.yarnpkg.com/@types/keyv/-/keyv-3.1.4.tgz#3ccdb1c6751b0c7e52300bcdacd5bcbf8faa75b6" + integrity sha512-BQ5aZNSCpj7D6K2ksrRCTmKRLEpnPvWDiLPfoGyhZ++8YtiK9d/3DBKPJgry359X/P1PfruyYwvnvwFjuEiEIg== + dependencies: + "@types/node" "*" + "@types/lodash-es@^4.17.7": version "4.17.7" resolved "https://registry.npmjs.org/@types/lodash-es/-/lodash-es-4.17.7.tgz" @@ -1944,6 +2001,13 @@ dependencies: "@types/unist" "*" +"@types/mdast@^4.0.0": + version "4.0.4" + resolved "https://registry.yarnpkg.com/@types/mdast/-/mdast-4.0.4.tgz#7ccf72edd2f1aa7dd3437e180c64373585804dd6" + integrity sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA== + dependencies: + "@types/unist" "*" + "@types/mdx@^2.0.0": version "2.0.5" resolved "https://registry.npmjs.org/@types/mdx/-/mdx-2.0.5.tgz" @@ -2035,6 +2099,13 @@ resolved "https://registry.npmjs.org/@types/recordrtc/-/recordrtc-5.6.11.tgz" integrity sha512-X4XD5nltz0cjmyzsPNegQReOPF+C5ARTfSPAPhqnKV7SsfRta/M4FBJ5AtSInCaEveL71FLLSVQE9mg8Uuo++w== +"@types/responselike@^1.0.0": + version "1.0.3" + resolved "https://registry.yarnpkg.com/@types/responselike/-/responselike-1.0.3.tgz#cc29706f0a397cfe6df89debfe4bf5cea159db50" + integrity sha512-H/+L+UkTV33uf49PH5pCAUBVPNj2nDBXTN+qS1dOwyyg24l3CcicicCA7ca+HMvJBZcFgl5r8e+RR6elsb4Lyw== + dependencies: + "@types/node" "*" + "@types/semver@^7.3.12": version "7.5.0" resolved "https://registry.npmjs.org/@types/semver/-/semver-7.5.0.tgz" @@ -2060,6 +2131,11 @@ resolved "https://registry.npmjs.org/@types/unist/-/unist-2.0.6.tgz" integrity 
sha512-PBjIUxZHOuj0R15/xuwJYjFi+KZdNFrehocChv4g5hu6aFroHue8m0lBP0POdK2nKzbw0cgV1mws8+V/JAcEkQ== +"@types/unist@^3.0.0": + version "3.0.2" + resolved "https://registry.yarnpkg.com/@types/unist/-/unist-3.0.2.tgz#6dd61e43ef60b34086287f83683a5c1b2dc53d20" + integrity sha512-dqId9J8K/vGi5Zr7oo212BGii5m3q5Hxlkwy3WpYuKPklmBEvsbMYYyLxAQpSffdLl/gdW0XUpKWFvYmyoWCoQ== + "@types/uuid@^9.0.8": version "9.0.8" resolved "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz" @@ -2161,6 +2237,11 @@ "@typescript-eslint/types" "5.59.9" eslint-visitor-keys "^3.3.0" +"@ungap/structured-clone@^1.0.0": + version "1.2.0" + resolved "https://registry.yarnpkg.com/@ungap/structured-clone/-/structured-clone-1.2.0.tgz#756641adb587851b5ccb3e095daf27ae581c8406" + integrity sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ== + "@vue/compiler-core@3.4.25": version "3.4.25" resolved "https://registry.npmjs.org/@vue/compiler-core/-/compiler-core-3.4.25.tgz" @@ -2578,6 +2659,13 @@ binary-extensions@^2.0.0: resolved "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz" integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== +bing-translate-api@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/bing-translate-api/-/bing-translate-api-4.0.2.tgz#52807a128e883bf074b4174c5e674ffca60685e7" + integrity sha512-JJ8XUehnxzOhHU91oy86xEtp8OOMjVEjCZJX042fKxoO19NNvxJ5omeCcxQNFoPbDqVpBJwqiGVquL0oPdQm1Q== + dependencies: + got "^11.8.6" + boolbase@^1.0.0: version "1.0.0" resolved "https://registry.npmjs.org/boolbase/-/boolbase-1.0.0.tgz" @@ -2670,6 +2758,24 @@ busboy@1.6.0: dependencies: streamsearch "^1.1.0" +cacheable-lookup@^5.0.3: + version "5.0.4" + resolved "https://registry.yarnpkg.com/cacheable-lookup/-/cacheable-lookup-5.0.4.tgz#5a6b865b2c44357be3d5ebc2a467b032719a7005" + integrity sha512-2/kNscPhpcxrOigMZzbiWF7dz8ilhb/nIHU3EyZiXWXpeq/au8qJ8VhdftMkty3n7Gj6HIGalQG8oiBNB3AJgA== + +cacheable-request@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cacheable-request/-/cacheable-request-7.0.4.tgz#7a33ebf08613178b403635be7b899d3e69bbe817" + integrity sha512-v+p6ongsrp0yTGbJXjgxPow2+DL93DASP4kXCDKb8/bwRtt9OEF3whggkkDkGNzgcWy2XaF4a8nZglC7uElscg== + dependencies: + clone-response "^1.0.2" + get-stream "^5.1.0" + http-cache-semantics "^4.0.0" + keyv "^4.0.0" + lowercase-keys "^2.0.0" + normalize-url "^6.0.1" + responselike "^2.0.0" + call-bind@^1.0.0, call-bind@^1.0.2, call-bind@^1.0.4, call-bind@^1.0.5: version "1.0.5" resolved "https://registry.npmjs.org/call-bind/-/call-bind-1.0.5.tgz" @@ -2893,6 +2999,13 @@ cliui@^8.0.1: strip-ansi "^6.0.1" wrap-ansi "^7.0.0" +clone-response@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/clone-response/-/clone-response-1.0.3.tgz#af2032aa47816399cf5f0a1d0db902f517abb8c3" + integrity sha512-ROoL94jJH2dUVML2Y/5PEDNaSHgeOdSDicUyS7izcF63G6sTc/FTjLub4b8Il9S8S0beOfYt0TaA5qvFK+w0wA== + dependencies: + mimic-response "^1.0.0" + clsx@2.0.0: version "2.0.0" resolved "https://registry.npmjs.org/clsx/-/clsx-2.0.0.tgz" @@ -3464,6 +3577,13 @@ decode-named-character-reference@^1.0.0: dependencies: character-entities "^2.0.0" +decompress-response@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/decompress-response/-/decompress-response-6.0.0.tgz#ca387612ddb7e104bd16d85aab00d5ecf09c66fc" + integrity sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ== + dependencies: + mimic-response "^3.1.0" 
+ dedent@^1.0.0: version "1.5.3" resolved "https://registry.npmjs.org/dedent/-/dedent-1.5.3.tgz#99aee19eb9bae55a67327717b6e848d0bf777e5a" @@ -3521,6 +3641,11 @@ default-browser@^4.0.0: execa "^7.1.1" titleize "^3.0.0" +defer-to-connect@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/defer-to-connect/-/defer-to-connect-2.0.1.tgz#8016bdb4143e4632b77a3449c6236277de520587" + integrity sha512-4tvttepXG1VaYGrRibk5EwJd1t4udunSOVMdLSAL6mId1ix438oPwPZMALY41FCijukO1L0twNcGsdzS7dHgDg== + define-data-property@^1.0.1, define-data-property@^1.1.1: version "1.1.1" resolved "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.1.tgz" @@ -3571,6 +3696,13 @@ detect-newline@^3.0.0: resolved "https://registry.npmjs.org/detect-newline/-/detect-newline-3.1.0.tgz#576f5dfc63ae1a192ff192d8ad3af6308991b651" integrity sha512-TLz+x/vEXm/Y7P7wn1EJFNLxYpUD4TgMosxY6fAVJUnJMbupHBOncxyWUG9OpTaH9EBD7uFI5LfEgmMOc54DsA== +devlop@^1.0.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/devlop/-/devlop-1.1.0.tgz#4db7c2ca4dc6e0e834c30be70c94bbc976dc7018" + integrity sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA== + dependencies: + dequal "^2.0.0" + didyoumean@^1.2.2: version "1.2.2" resolved "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz" @@ -3720,6 +3852,13 @@ emoji-regex@^9.2.2: resolved "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz" integrity sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg== +end-of-stream@^1.1.0: + version "1.4.4" + resolved "https://registry.yarnpkg.com/end-of-stream/-/end-of-stream-1.4.4.tgz#5ae64a5f45057baf3626ec14da0ca5e4b2431eb0" + integrity sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q== + dependencies: + once "^1.4.0" + enhanced-resolve@^5.12.0: version "5.16.1" resolved "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.16.1.tgz" @@ -4557,6 +4696,13 @@ get-package-type@^0.1.0: resolved "https://registry.npmjs.org/get-package-type/-/get-package-type-0.1.0.tgz#8de2d803cff44df3bc6c456e6668b36c3926e11a" integrity sha512-pjzuKtY64GYfWizNAJ0fr9VqttZkNiK2iS430LtIHzjBEr6bX8Am2zm4sW4Ro5wjWW5cAlRL1qAMTcXbjNAO2Q== +get-stream@^5.1.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/get-stream/-/get-stream-5.2.0.tgz#4966a1795ee5ace65e706c4b7beb71257d6e22d3" + integrity sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA== + dependencies: + pump "^3.0.0" + get-stream@^6.0.0, get-stream@^6.0.1: version "6.0.1" resolved "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz" @@ -4675,6 +4821,23 @@ gopd@^1.0.1: dependencies: get-intrinsic "^1.1.3" +got@^11.8.6: + version "11.8.6" + resolved "https://registry.yarnpkg.com/got/-/got-11.8.6.tgz#276e827ead8772eddbcfc97170590b841823233a" + integrity sha512-6tfZ91bOr7bOXnK7PRDCGBLa1H4U080YHNaAQ2KsMGlLEzRbk44nsZF2E1IeRc3vtJHPVbKCYgdFbaGO2ljd8g== + dependencies: + "@sindresorhus/is" "^4.0.0" + "@szmarczak/http-timer" "^4.0.5" + "@types/cacheable-request" "^6.0.1" + "@types/responselike" "^1.0.0" + cacheable-lookup "^5.0.3" + cacheable-request "^7.0.2" + decompress-response "^6.0.0" + http2-wrapper "^1.0.0-beta.5.2" + lowercase-keys "^2.0.0" + p-cancelable "^2.0.0" + responselike "^2.0.0" + graceful-fs@^4.2.11, graceful-fs@^4.2.4, graceful-fs@^4.2.9: version "4.2.11" resolved "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz" @@ -4780,6 +4943,20 @@ 
hast-util-from-parse5@^7.0.0: vfile-location "^4.0.0" web-namespaces "^2.0.0" +hast-util-from-parse5@^8.0.0: + version "8.0.1" + resolved "https://registry.yarnpkg.com/hast-util-from-parse5/-/hast-util-from-parse5-8.0.1.tgz#654a5676a41211e14ee80d1b1758c399a0327651" + integrity sha512-Er/Iixbc7IEa7r/XLtuG52zoqn/b3Xng/w6aZQ0xGVxzhw5xUFxcRqdPzP6yFi/4HBYRaifaI5fQ1RH8n0ZeOQ== + dependencies: + "@types/hast" "^3.0.0" + "@types/unist" "^3.0.0" + devlop "^1.0.0" + hastscript "^8.0.0" + property-information "^6.0.0" + vfile "^6.0.0" + vfile-location "^5.0.0" + web-namespaces "^2.0.0" + hast-util-is-element@^2.0.0: version "2.1.3" resolved "https://registry.npmjs.org/hast-util-is-element/-/hast-util-is-element-2.1.3.tgz" @@ -4800,6 +4977,32 @@ hast-util-parse-selector@^3.0.0: dependencies: "@types/hast" "^2.0.0" +hast-util-parse-selector@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/hast-util-parse-selector/-/hast-util-parse-selector-4.0.0.tgz#352879fa86e25616036037dd8931fb5f34cb4a27" + integrity sha512-wkQCkSYoOGCRKERFWcxMVMOcYE2K1AaNLU8DXS9arxnLOUEWbOXKXiJUNzEpqZ3JOKpnha3jkFrumEjVliDe7A== + dependencies: + "@types/hast" "^3.0.0" + +hast-util-raw@^9.0.0: + version "9.0.4" + resolved "https://registry.yarnpkg.com/hast-util-raw/-/hast-util-raw-9.0.4.tgz#2da03e37c46eb1a6f1391f02f9b84ae65818f7ed" + integrity sha512-LHE65TD2YiNsHD3YuXcKPHXPLuYh/gjp12mOfU8jxSrm1f/yJpsb0F/KKljS6U9LJoP0Ux+tCe8iJ2AsPzTdgA== + dependencies: + "@types/hast" "^3.0.0" + "@types/unist" "^3.0.0" + "@ungap/structured-clone" "^1.0.0" + hast-util-from-parse5 "^8.0.0" + hast-util-to-parse5 "^8.0.0" + html-void-elements "^3.0.0" + mdast-util-to-hast "^13.0.0" + parse5 "^7.0.0" + unist-util-position "^5.0.0" + unist-util-visit "^5.0.0" + vfile "^6.0.0" + web-namespaces "^2.0.0" + zwitch "^2.0.0" + hast-util-to-estree@^2.0.0: version "2.3.3" resolved "https://registry.npmjs.org/hast-util-to-estree/-/hast-util-to-estree-2.3.3.tgz" @@ -4821,6 +5024,19 @@ hast-util-to-estree@^2.0.0: unist-util-position "^4.0.0" zwitch "^2.0.0" +hast-util-to-parse5@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/hast-util-to-parse5/-/hast-util-to-parse5-8.0.0.tgz#477cd42d278d4f036bc2ea58586130f6f39ee6ed" + integrity sha512-3KKrV5ZVI8if87DVSi1vDeByYrkGzg4mEfeu4alwgmmIeARiBLKCZS2uw5Gb6nU9x9Yufyj3iudm6i7nl52PFw== + dependencies: + "@types/hast" "^3.0.0" + comma-separated-tokens "^2.0.0" + devlop "^1.0.0" + property-information "^6.0.0" + space-separated-tokens "^2.0.0" + web-namespaces "^2.0.0" + zwitch "^2.0.0" + hast-util-to-text@^3.1.0: version "3.1.2" resolved "https://registry.npmjs.org/hast-util-to-text/-/hast-util-to-text-3.1.2.tgz" @@ -4858,6 +5074,17 @@ hastscript@^7.0.0: property-information "^6.0.0" space-separated-tokens "^2.0.0" +hastscript@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/hastscript/-/hastscript-8.0.0.tgz#4ef795ec8dee867101b9f23cc830d4baf4fd781a" + integrity sha512-dMOtzCEd3ABUeSIISmrETiKuyydk1w0pa+gE/uormcTpSYuaNJPbX1NU3JLyscSLjwAQM8bWMhhIlnCqnRvDTw== + dependencies: + "@types/hast" "^3.0.0" + comma-separated-tokens "^2.0.0" + hast-util-parse-selector "^4.0.0" + property-information "^6.0.0" + space-separated-tokens "^2.0.0" + heap@^0.2.6: version "0.2.7" resolved "https://registry.npmjs.org/heap/-/heap-0.2.7.tgz" @@ -4899,6 +5126,11 @@ html-parse-stringify@^3.0.1: dependencies: void-elements "3.1.0" +html-void-elements@^3.0.0: + version "3.0.0" + resolved 
"https://registry.yarnpkg.com/html-void-elements/-/html-void-elements-3.0.0.tgz#fc9dbd84af9e747249034d4d62602def6517f1d7" + integrity sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg== + htmlparser2@^8.0.1: version "8.0.2" resolved "https://registry.npmjs.org/htmlparser2/-/htmlparser2-8.0.2.tgz" @@ -4909,6 +5141,11 @@ htmlparser2@^8.0.1: domutils "^3.0.1" entities "^4.4.0" +http-cache-semantics@^4.0.0: + version "4.1.1" + resolved "https://registry.yarnpkg.com/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz#abe02fcb2985460bf0323be664436ec3476a6d5a" + integrity sha512-er295DKPVsV82j5kw1Gjt+ADA/XYHsajl82cGNQG2eyoPkvgUhX+nDIyelzhIWbbsXP39EHcI6l5tYs2FYqYXQ== + http-proxy-agent@^5.0.0: version "5.0.0" resolved "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-5.0.0.tgz#5129800203520d434f142bc78ff3c170800f2b43" @@ -4918,6 +5155,14 @@ http-proxy-agent@^5.0.0: agent-base "6" debug "4" +http2-wrapper@^1.0.0-beta.5.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/http2-wrapper/-/http2-wrapper-1.0.3.tgz#b8f55e0c1f25d4ebd08b3b0c2c079f9590800b3d" + integrity sha512-V+23sDMr12Wnz7iTcDeJr3O6AIxlnvT/bmaAAAP/Xda35C90p9599p0F1eHR/N1KILWSoWVAiOMFjBBXaXSMxg== + dependencies: + quick-lru "^5.1.1" + resolve-alpn "^1.0.0" + https-proxy-agent@^5.0.1: version "5.0.1" resolved "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-5.0.1.tgz#c59ef224a04fe8b754f3db0063a25ea30d0005d6" @@ -5905,6 +6150,11 @@ jsesc@~0.5.0: resolved "https://registry.npmjs.org/jsesc/-/jsesc-0.5.0.tgz" integrity sha512-uZz5UnB7u4T9LvwmFqXii7pZSouaRPorGs5who1Ip7VO0wxanFvBL7GkM6dTHlgX+jhBApRetaWpnDabOeTcnA== +json-buffer@3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/json-buffer/-/json-buffer-3.0.1.tgz#9338802a30d3b6605fbe0613e094008ca8c05a13" + integrity sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ== + json-parse-even-better-errors@^2.3.0: version "2.3.1" resolved "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz" @@ -5957,6 +6207,13 @@ katex@^0.16.0, katex@^0.16.10: dependencies: commander "^8.3.0" +keyv@^4.0.0: + version "4.5.4" + resolved "https://registry.yarnpkg.com/keyv/-/keyv-4.5.4.tgz#a879a99e29452f942439f2a405e3af8b31d4de93" + integrity sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw== + dependencies: + json-buffer "3.0.1" + khroma@^2.0.0: version "2.0.0" resolved "https://registry.npmjs.org/khroma/-/khroma-2.0.0.tgz" @@ -6128,6 +6385,11 @@ loose-envify@^1.1.0, loose-envify@^1.4.0: dependencies: js-tokens "^3.0.0 || ^4.0.0" +lowercase-keys@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/lowercase-keys/-/lowercase-keys-2.0.0.tgz#2603e78b7b4b0006cbca2fbcc8a3202558ac9479" + integrity sha512-tqNXrS78oMOE73NMxK4EMLQsQowWf8jKooH9g7xPavRT706R6bkQJ6DY2Te7QukaZsulxa30wQ7bk0pm4XiHmA== + lowlight@^1.17.0: version "1.20.0" resolved "https://registry.npmjs.org/lowlight/-/lowlight-1.20.0.tgz" @@ -6160,6 +6422,15 @@ lz-string@^1.5.0: resolved "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz#c1ab50f77887b712621201ba9fd4e3a6ed099941" integrity sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ== +magicast@^0.3.4: + version "0.3.5" + resolved "https://registry.yarnpkg.com/magicast/-/magicast-0.3.5.tgz#8301c3c7d66704a0771eb1bad74274f0ec036739" + integrity 
sha512-L0WhttDl+2BOsybvEOLK7fW3UA0OQ0IQ2d6Zl2x/a6vVRs3bAY0ECOSHHeL5jD+SbOpOCUEi0y1DgHEn9Qn1AQ== + dependencies: + "@babel/parser" "^7.25.4" + "@babel/types" "^7.25.4" + source-map-js "^1.2.0" + make-dir@^4.0.0: version "4.0.0" resolved "https://registry.npmjs.org/make-dir/-/make-dir-4.0.0.tgz#c3c2307a771277cd9638305f915c29ae741b614e" @@ -6385,6 +6656,21 @@ mdast-util-to-hast@^12.1.0: unist-util-position "^4.0.0" unist-util-visit "^4.0.0" +mdast-util-to-hast@^13.0.0: + version "13.2.0" + resolved "https://registry.yarnpkg.com/mdast-util-to-hast/-/mdast-util-to-hast-13.2.0.tgz#5ca58e5b921cc0a3ded1bc02eed79a4fe4fe41f4" + integrity sha512-QGYKEuUsYT9ykKBCMOEDLsU5JRObWQusAolFMeko/tYPufNkRffBAQjIE+99jbA87xv6FgmjLtwjh9wBWajwAA== + dependencies: + "@types/hast" "^3.0.0" + "@types/mdast" "^4.0.0" + "@ungap/structured-clone" "^1.0.0" + devlop "^1.0.0" + micromark-util-sanitize-uri "^2.0.0" + trim-lines "^3.0.0" + unist-util-position "^5.0.0" + unist-util-visit "^5.0.0" + vfile "^6.0.0" + mdast-util-to-markdown@^1.0.0, mdast-util-to-markdown@^1.3.0: version "1.5.0" resolved "https://registry.npmjs.org/mdast-util-to-markdown/-/mdast-util-to-markdown-1.5.0.tgz" @@ -6701,6 +6987,14 @@ micromark-util-character@^1.0.0: micromark-util-symbol "^1.0.0" micromark-util-types "^1.0.0" +micromark-util-character@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/micromark-util-character/-/micromark-util-character-2.1.0.tgz#31320ace16b4644316f6bf057531689c71e2aee1" + integrity sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ== + dependencies: + micromark-util-symbol "^2.0.0" + micromark-util-types "^2.0.0" + micromark-util-chunked@^1.0.0: version "1.1.0" resolved "https://registry.npmjs.org/micromark-util-chunked/-/micromark-util-chunked-1.1.0.tgz" @@ -6747,6 +7041,11 @@ micromark-util-encode@^1.0.0: resolved "https://registry.npmjs.org/micromark-util-encode/-/micromark-util-encode-1.1.0.tgz" integrity sha512-EuEzTWSTAj9PA5GOAs992GzNh2dGQO52UvAbtSOMvXTxv3Criqb6IOzJUBCmEqrrXSblJIJBbFFv6zPxpreiJw== +micromark-util-encode@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/micromark-util-encode/-/micromark-util-encode-2.0.0.tgz#0921ac7953dc3f1fd281e3d1932decfdb9382ab1" + integrity sha512-pS+ROfCXAGLWCOc8egcBvT0kf27GoWMqtdarNfDcjb6YLuV5cM3ioG45Ys2qOVqeqSbjaKg72vU+Wby3eddPsA== + micromark-util-events-to-acorn@^1.0.0: version "1.2.3" resolved "https://registry.npmjs.org/micromark-util-events-to-acorn/-/micromark-util-events-to-acorn-1.2.3.tgz" @@ -6789,6 +7088,15 @@ micromark-util-sanitize-uri@^1.0.0, micromark-util-sanitize-uri@^1.1.0: micromark-util-encode "^1.0.0" micromark-util-symbol "^1.0.0" +micromark-util-sanitize-uri@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/micromark-util-sanitize-uri/-/micromark-util-sanitize-uri-2.0.0.tgz#ec8fbf0258e9e6d8f13d9e4770f9be64342673de" + integrity sha512-WhYv5UEcZrbAtlsnPuChHUAsu/iBPOVaEVsntLBIdpibO0ddy8OzavZz3iL2xVvBZOpolujSliP65Kq0/7KIYw== + dependencies: + micromark-util-character "^2.0.0" + micromark-util-encode "^2.0.0" + micromark-util-symbol "^2.0.0" + micromark-util-subtokenize@^1.0.0: version "1.1.0" resolved "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-1.1.0.tgz" @@ -6804,11 +7112,21 @@ micromark-util-symbol@^1.0.0: resolved "https://registry.npmjs.org/micromark-util-symbol/-/micromark-util-symbol-1.1.0.tgz" integrity sha512-uEjpEYY6KMs1g7QfJ2eX1SQEV+ZT4rUD3UcF6l57acZvLNK7PBZL+ty82Z1qhK1/yXIY4bdx04FKMgR0g4IAag== 
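Most of the HTTP-client entries added above (got, cacheable-request, cacheable-lookup, decompress-response, http2-wrapper, keyv, p-cancelable, quick-lru, responselike) are the transitive closure of the new bing-translate-api@4 dependency, which declares got ^11.8.6. A hypothetical usage sketch — the call shape follows the package's documented API and is not taken from this diff:

// bing-translate-api ships no TypeScript types; treat the shapes as assumptions.
import { translate } from 'bing-translate-api'

translate('Hello', null, 'zh-Hans')
  .then((res: any) => console.log(res?.translation)) // translated text per the package docs
  .catch(console.error)
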
+micromark-util-symbol@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/micromark-util-symbol/-/micromark-util-symbol-2.0.0.tgz#12225c8f95edf8b17254e47080ce0862d5db8044" + integrity sha512-8JZt9ElZ5kyTnO94muPxIGS8oyElRJaiJO8EzV6ZSyGQ1Is8xwl4Q45qU5UOg+bGH4AikWziz0iN4sFLWs8PGw== + micromark-util-types@^1.0.0, micromark-util-types@^1.0.1: version "1.1.0" resolved "https://registry.npmjs.org/micromark-util-types/-/micromark-util-types-1.1.0.tgz" integrity sha512-ukRBgie8TIAcacscVHSiddHjO4k/q3pnedmzMQ4iwDcK0FtFCohKOlFbaOL/mPgfnPsL3C1ZyxJa4sbWrBl3jg== +micromark-util-types@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/micromark-util-types/-/micromark-util-types-2.0.0.tgz#63b4b7ffeb35d3ecf50d1ca20e68fc7caa36d95e" + integrity sha512-oNh6S2WMHWRZrmutsRmDDfkzKtxF+bc2VxLC9dvtrDIRFln627VsFP6fLMgTryGDljgLPjkrzQSDcPrjPyDJ5w== + micromark@^3.0.0: version "3.2.0" resolved "https://registry.npmjs.org/micromark/-/micromark-3.2.0.tgz" @@ -6870,6 +7188,16 @@ mimic-fn@^4.0.0: resolved "https://registry.npmjs.org/mimic-fn/-/mimic-fn-4.0.0.tgz" integrity sha512-vqiC06CuhBTUdZH+RYl8sFrL096vA45Ok5ISO6sE/Mr1jRbGH4Csnhi8f3wKVl7x8mO4Au7Ir9D3Oyv1VYMFJw== +mimic-response@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/mimic-response/-/mimic-response-1.0.1.tgz#4923538878eef42063cb8a3e3b0798781487ab1b" + integrity sha512-j5EctnkH7amfV/q5Hgmoal1g2QHFJRraOtmx0JpIqkxhBhI/lJSl1nMpQ45hVarwNETOoWEimndZ4QK0RHxuxQ== + +mimic-response@^3.1.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/mimic-response/-/mimic-response-3.1.0.tgz#2d1d59af9c1b129815accc2c46a022a5ce1fa3c9" + integrity sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ== + min-indent@^1.0.0: version "1.0.1" resolved "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz" @@ -7010,6 +7338,11 @@ normalize-range@^0.1.2: resolved "https://registry.npmjs.org/normalize-range/-/normalize-range-0.1.2.tgz" integrity sha512-bdok/XvKII3nUpklnV6P2hxtMNrCboOjAcyBuQnWEhO665FwrSNRxU+AqpsyvO6LgGYPspN+lu5CLtw4jPRKNA== +normalize-url@^6.0.1: + version "6.1.0" + resolved "https://registry.yarnpkg.com/normalize-url/-/normalize-url-6.1.0.tgz#40d0885b535deffe3f3147bec877d05fe4c5668a" + integrity sha512-DlL+XwOy3NxAQ8xuC0okPgK46iuVNAK01YN7RueYBqqFeGsBjV9XmCAzAdgt+667bCl5kPh9EqKKDwnaPG1I7A== + normalize-wheel@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/normalize-wheel/-/normalize-wheel-1.0.1.tgz#aec886affdb045070d856447df62ecf86146ec45" @@ -7129,9 +7462,9 @@ object.values@^1.1.6, object.values@^1.1.7: define-properties "^1.2.0" es-abstract "^1.22.1" -once@^1.3.0: +once@^1.3.0, once@^1.3.1, once@^1.4.0: version "1.4.0" - resolved "https://registry.npmjs.org/once/-/once-1.4.0.tgz" + resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1" integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w== dependencies: wrappy "1" @@ -7172,6 +7505,11 @@ optionator@^0.9.1: type-check "^0.4.0" word-wrap "^1.2.3" +p-cancelable@^2.0.0: + version "2.1.1" + resolved "https://registry.yarnpkg.com/p-cancelable/-/p-cancelable-2.1.1.tgz#aab7fbd416582fa32a3db49859c122487c5ed2cf" + integrity sha512-BZOr3nRQHOntUjTrH8+Lh54smKHoHyur8We1V8DSMVrl5A2malOOwuJRnKRDjSnkoeBh4at6BwEnb5I7Jl31wg== + p-limit@^2.2.0: version "2.3.0" resolved "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz" @@ -7498,6 +7836,14 @@ psl@^1.1.33: resolved 
"https://registry.npmjs.org/psl/-/psl-1.9.0.tgz#d0df2a137f00794565fcaf3b2c00cd09f8d5a5a7" integrity sha512-E/ZsdU4HLs/68gYzgGTkMicWTLPdAftJLfJFlLUAAKZGkStNU72sZjT66SnMDVOfOWY/YAoiD7Jxa9iHvngcag== +pump@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/pump/-/pump-3.0.0.tgz#b4a2116815bde2f4e1ea602354e8c75565107a64" + integrity sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww== + dependencies: + end-of-stream "^1.1.0" + once "^1.3.1" + punycode@^2.1.0: version "2.3.0" resolved "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz" @@ -7535,6 +7881,11 @@ queue-microtask@^1.2.2: resolved "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz" integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== +quick-lru@^5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" + integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== + rc-input@~1.3.5: version "1.3.6" resolved "https://registry.npmjs.org/rc-input/-/rc-input-1.3.6.tgz" @@ -7867,6 +8218,15 @@ rehype-katex@^6.0.2: katex "^0.16.0" unist-util-visit "^4.0.0" +rehype-raw@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/rehype-raw/-/rehype-raw-7.0.0.tgz#59d7348fd5dbef3807bbaa1d443efd2dd85ecee4" + integrity sha512-/aE8hCfKlQeA8LmyeyQvQF3eBiLRGNlfBJEvWH7ivp9sBqs7TNqBL5X3v157rM4IFETqDnIOO+z5M/biZbo9Ww== + dependencies: + "@types/hast" "^3.0.0" + hast-util-raw "^9.0.0" + vfile "^6.0.0" + remark-breaks@^3.0.2: version "3.0.3" resolved "https://registry.npmjs.org/remark-breaks/-/remark-breaks-3.0.3.tgz" @@ -7938,6 +8298,11 @@ resize-observer-polyfill@^1.5.1: resolved "https://registry.npmjs.org/resize-observer-polyfill/-/resize-observer-polyfill-1.5.1.tgz" integrity sha512-LwZrotdHOo12nQuZlHEmtuXdqGoOD0OhaxopaNFxWzInpEgaLWoVuAMbTzixuosCx2nEG58ngzW3vxdWoxIgdg== +resolve-alpn@^1.0.0: + version "1.2.1" + resolved "https://registry.yarnpkg.com/resolve-alpn/-/resolve-alpn-1.2.1.tgz#b7adbdac3546aaaec20b45e7d8265927072726f9" + integrity sha512-0a1F4l73/ZFZOakJnQ3FvkJ2+gSTQWz/r2KE5OdDY0TxPm5h4GkqkWWfM47T7HsbnOtcJVEF4epCVy6u7Q3K+g== + resolve-cwd@^3.0.0: version "3.0.0" resolved "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz#0f0075f1bb2544766cf73ba6a6e2adfebcb13f2d" @@ -7983,6 +8348,13 @@ resolve@^2.0.0-next.4: path-parse "^1.0.7" supports-preserve-symlinks-flag "^1.0.0" +responselike@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/responselike/-/responselike-2.0.1.tgz#9a0bc8fdc252f3fb1cca68b016591059ba1422bc" + integrity sha512-4gl03wn3hj1HP3yzgdI7d3lCkF95F21Pz4BPGvKHinyQzALR5CapwC8yIi0Rh58DEMQ/SguC03wFj2k0M/mHhw== + dependencies: + lowercase-keys "^2.0.0" + restore-cursor@^3.1.0: version "3.1.0" resolved "https://registry.npmjs.org/restore-cursor/-/restore-cursor-3.1.0.tgz" @@ -8382,7 +8754,7 @@ string-length@^4.0.1: string-width@4.2.3, string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3, string-width@^5.0.0, string-width@^5.0.1, string-width@^5.1.2: version "4.2.3" - resolved "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== dependencies: emoji-regex 
"^8.0.0" @@ -8439,14 +8811,7 @@ stringify-entities@^4.0.0: character-entities-html4 "^2.0.0" character-entities-legacy "^3.0.0" -"strip-ansi-cjs@npm:strip-ansi@^6.0.1": - version "6.0.1" - resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz" - integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== - dependencies: - ansi-regex "^5.0.1" - -strip-ansi@^6.0.0, strip-ansi@^6.0.1: +"strip-ansi-cjs@npm:strip-ansi@^6.0.1", strip-ansi@^6.0.0, strip-ansi@^6.0.1: version "6.0.1" resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz" integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== @@ -8760,9 +9125,9 @@ tslib@^1.8.1, tslib@^1.9.3: integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg== tslib@^2.0.1: - version "2.6.3" - resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.6.3.tgz#0438f810ad7a9edcde7a241c3d80db693c8cbfe0" - integrity sha512-xNvxJEOUiWPGhUuUdQgAJPKOOJfGnIyKySOc09XkKsgdUV/3E2zvwZYdejjmRgPCgcym1juLH3226yA7sEFJKQ== + version "2.7.0" + resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.7.0.tgz#d9b40c5c40ab59e8738f297df3087bf1a2690c01" + integrity sha512-gLXCKdN1/j47AiHiOkJN69hJmcbGTHI0ImLmbYLHykhgeN0jVGola9yVjFgzCUklsZQMW55o+dW7IXv3RCXDzA== tslib@^2.1.0, tslib@^2.4.0, tslib@^2.4.1, tslib@^2.5.0: version "2.5.3" @@ -8900,6 +9265,13 @@ unist-util-is@^5.0.0: dependencies: "@types/unist" "^2.0.0" +unist-util-is@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/unist-util-is/-/unist-util-is-6.0.0.tgz#b775956486aff107a9ded971d996c173374be424" + integrity sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw== + dependencies: + "@types/unist" "^3.0.0" + unist-util-position-from-estree@^1.0.0, unist-util-position-from-estree@^1.1.0: version "1.1.2" resolved "https://registry.npmjs.org/unist-util-position-from-estree/-/unist-util-position-from-estree-1.1.2.tgz" @@ -8914,6 +9286,13 @@ unist-util-position@^4.0.0: dependencies: "@types/unist" "^2.0.0" +unist-util-position@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/unist-util-position/-/unist-util-position-5.0.0.tgz#678f20ab5ca1207a97d7ea8a388373c9cf896be4" + integrity sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA== + dependencies: + "@types/unist" "^3.0.0" + unist-util-remove-position@^4.0.0: version "4.0.2" resolved "https://registry.npmjs.org/unist-util-remove-position/-/unist-util-remove-position-4.0.2.tgz" @@ -8936,6 +9315,13 @@ unist-util-stringify-position@^3.0.0: dependencies: "@types/unist" "^2.0.0" +unist-util-stringify-position@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/unist-util-stringify-position/-/unist-util-stringify-position-4.0.0.tgz#449c6e21a880e0855bf5aabadeb3a740314abac2" + integrity sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ== + dependencies: + "@types/unist" "^3.0.0" + unist-util-visit-parents@^5.0.0, unist-util-visit-parents@^5.1.1: version "5.1.3" resolved "https://registry.npmjs.org/unist-util-visit-parents/-/unist-util-visit-parents-5.1.3.tgz" @@ -8944,6 +9330,14 @@ unist-util-visit-parents@^5.0.0, unist-util-visit-parents@^5.1.1: "@types/unist" "^2.0.0" unist-util-is "^5.0.0" +unist-util-visit-parents@^6.0.0: + version "6.0.1" + resolved 
"https://registry.yarnpkg.com/unist-util-visit-parents/-/unist-util-visit-parents-6.0.1.tgz#4d5f85755c3b8f0dc69e21eca5d6d82d22162815" + integrity sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw== + dependencies: + "@types/unist" "^3.0.0" + unist-util-is "^6.0.0" + unist-util-visit@^4.0.0: version "4.1.2" resolved "https://registry.npmjs.org/unist-util-visit/-/unist-util-visit-4.1.2.tgz" @@ -8953,6 +9347,15 @@ unist-util-visit@^4.0.0: unist-util-is "^5.0.0" unist-util-visit-parents "^5.1.1" +unist-util-visit@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/unist-util-visit/-/unist-util-visit-5.0.0.tgz#a7de1f31f72ffd3519ea71814cccf5fd6a9217d6" + integrity sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg== + dependencies: + "@types/unist" "^3.0.0" + unist-util-is "^6.0.0" + unist-util-visit-parents "^6.0.0" + universalify@^0.2.0: version "0.2.0" resolved "https://registry.npmjs.org/universalify/-/universalify-0.2.0.tgz#6451760566fa857534745ab1dde952d1b1761be0" @@ -9059,6 +9462,14 @@ vfile-location@^4.0.0: "@types/unist" "^2.0.0" vfile "^5.0.0" +vfile-location@^5.0.0: + version "5.0.3" + resolved "https://registry.yarnpkg.com/vfile-location/-/vfile-location-5.0.3.tgz#cb9eacd20f2b6426d19451e0eafa3d0a846225c3" + integrity sha512-5yXvWDEgqeiYiBe1lbxYF7UMAIm/IcopxMHrMQDq3nvKcjPKIhZklUKL+AE7J7uApI4kwe2snsK+eI6UTj9EHg== + dependencies: + "@types/unist" "^3.0.0" + vfile "^6.0.0" + vfile-message@^3.0.0: version "3.1.4" resolved "https://registry.npmjs.org/vfile-message/-/vfile-message-3.1.4.tgz" @@ -9067,6 +9478,14 @@ vfile-message@^3.0.0: "@types/unist" "^2.0.0" unist-util-stringify-position "^3.0.0" +vfile-message@^4.0.0: + version "4.0.2" + resolved "https://registry.yarnpkg.com/vfile-message/-/vfile-message-4.0.2.tgz#c883c9f677c72c166362fd635f21fc165a7d1181" + integrity sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw== + dependencies: + "@types/unist" "^3.0.0" + unist-util-stringify-position "^4.0.0" + vfile@^5.0.0: version "5.3.7" resolved "https://registry.npmjs.org/vfile/-/vfile-5.3.7.tgz" @@ -9077,6 +9496,15 @@ vfile@^5.0.0: unist-util-stringify-position "^3.0.0" vfile-message "^3.0.0" +vfile@^6.0.0: + version "6.0.2" + resolved "https://registry.yarnpkg.com/vfile/-/vfile-6.0.2.tgz#ef49548ea3d270097a67011921411130ceae7deb" + integrity sha512-zND7NlS8rJYb/sPqkb13ZvbbUoExdbi4w3SfRrMq6R3FvnLQmmfpajJNITuuYm6AZ5uao9vy4BAos3EXBPf2rg== + dependencies: + "@types/unist" "^3.0.0" + unist-util-stringify-position "^4.0.0" + vfile-message "^4.0.0" + vite-code-inspector-plugin@0.13.0: version "0.13.0" resolved "https://registry.npmjs.org/vite-code-inspector-plugin/-/vite-code-inspector-plugin-0.13.0.tgz" @@ -9220,7 +9648,8 @@ word-wrap@^1.2.3: resolved "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz" integrity sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA== -"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0": +"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0", wrap-ansi@^7.0.0: + name wrap-ansi-cjs version "7.0.0" resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz" integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== @@ -9238,15 +9667,6 @@ wrap-ansi@^6.2.0: string-width "^4.1.0" strip-ansi "^6.0.0" -wrap-ansi@^7.0.0: - version "7.0.0" - resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz" - integrity 
sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - wrap-ansi@^8.1.0: version "8.1.0" resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz" @@ -9373,4 +9793,4 @@ zustand@^4.4.1, zustand@^4.5.2: zwitch@^2.0.0: version "2.0.4" resolved "https://registry.npmjs.org/zwitch/-/zwitch-2.0.4.tgz" - integrity sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A== \ No newline at end of file + integrity sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==