diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8f57cd545e..da2928d189 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
 Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
 
-Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
+Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot.
 
 ### 5. Visit dify in your browser
diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md
index 80e68a046e..a77239ff38 100644
--- a/CONTRIBUTING_VI.md
+++ b/CONTRIBUTING_VI.md
@@ -79,7 +79,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
 Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt.
 
-Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
+Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
 
 ### 5. Truy cập Dify trong trình duyệt của bạn
diff --git a/api/.env.example b/api/.env.example
index d13f4a13ad..d6661932db 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -120,7 +120,8 @@ SUPABASE_URL=your-server-url
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash
+
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -263,6 +264,11 @@ VIKINGDB_SCHEMA=http
 VIKINGDB_CONNECTION_TIMEOUT=30
 VIKINGDB_SOCKET_TIMEOUT=30
 
+# Lindorm configuration
+LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
+LINDORM_USERNAME=admin
+LINDORM_PASSWORD=admin
+
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1
 OCEANBASE_VECTOR_PORT=2881
@@ -271,6 +277,7 @@ OCEANBASE_VECTOR_PASSWORD=
 OCEANBASE_VECTOR_DATABASE=test
 OCEANBASE_MEMORY_LIMIT=6G
 
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
@@ -320,6 +327,9 @@ SSRF_DEFAULT_MAX_RETRIES=3
 BATCH_UPLOAD_LIMIT=10
 KEYWORD_DATA_SOURCE_TYPE=database
 
+# Workflow file upload limit
+WORKFLOW_FILE_UPLOAD_LIMIT=10
+
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
 CODE_EXECUTION_API_KEY=dify-sandbox
diff --git a/api/Dockerfile b/api/Dockerfile
index c71317f797..ff34603004 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -55,12 +55,7 @@ RUN apt-get update \
    && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
    && apt-get update \
    # For Security
-    && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
-    && if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
-        apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
-    else \
-        apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
-    fi \
+    && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
    # install a chinese font to support the use of tools like matplotlib
    && apt-get install -y fonts-noto-cjk \
    && apt-get autoremove -y \
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index f88019fbb6..4fedf6cd49 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -269,6 +269,11 @@ class FileUploadConfig(BaseSettings):
         default=20,
     )
 
+    WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field(
+        description="Maximum number of files allowed in a workflow upload operation",
+        default=10,
+    )
+
 
 class HttpConfig(BaseSettings):
     """
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 4be761747d..57cc805ebf 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -20,6 +20,7 @@ from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
 from configs.middleware.vdb.couchbase_config import CouchbaseConfig
 from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
+from configs.middleware.vdb.lindorm_config import LindormConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
 from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
@@ -259,6 +260,7 @@ class MiddlewareConfig(
     VikingDBConfig,
     UpstashConfig,
     TidbOnQdrantConfig,
+    LindormConfig,
     OceanBaseVectorConfig,
     BaiduVectorDBConfig,
 ):
diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py
new file mode 100644
index 0000000000..0f6c652806
--- /dev/null
+++ b/api/configs/middleware/vdb/lindorm_config.py
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class LindormConfig(BaseSettings):
+    """
+    Lindorm configs
+    """
+
+    LINDORM_URL: Optional[str] = Field(
+        description="Lindorm url",
+        default=None,
+    )
+    LINDORM_USERNAME: Optional[str] = Field(
+        description="Lindorm user",
+        default=None,
+    )
+    LINDORM_PASSWORD: Optional[str] = Field(
+        description="Lindorm password",
+        default=None,
+    )
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
new file mode 100644
index 0000000000..79869916ed
--- /dev/null
+++ b/api/controllers/common/fields.py
@@ -0,0 +1,24 @@
+from flask_restful import fields
+
+parameters__system_parameters = {
+    "image_file_size_limit": fields.Integer,
+    "video_file_size_limit": fields.Integer,
+    "audio_file_size_limit": fields.Integer,
+    "file_size_limit": fields.Integer,
+    "workflow_file_upload_limit": fields.Integer,
+}
+
+parameters_fields = {
+    "opening_statement": fields.String,
+    "suggested_questions": fields.Raw,
+    "suggested_questions_after_answer": fields.Raw,
+    "speech_to_text": fields.Raw,
+    "text_to_speech": fields.Raw,
+    "retriever_resource": fields.Raw,
+    "annotation_reply": fields.Raw,
+    "more_like_this": fields.Raw,
+    "user_input_form": fields.Raw,
+    "sensitive_word_avoidance": fields.Raw,
+    "file_upload": fields.Raw,
+    "system_parameters": fields.Nested(parameters__system_parameters),
+}
diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py
index ed24b265ef..2bae203712 100644
--- a/api/controllers/common/helpers.py
+++ b/api/controllers/common/helpers.py
@@ -2,11 +2,15 @@ import mimetypes
 import os
 import re
 import urllib.parse
+from collections.abc import Mapping
+from typing import Any
 from uuid import uuid4
 
 import httpx
 from pydantic import BaseModel
 
+from configs import dify_config
+
 
 class FileInfo(BaseModel):
     filename: str
@@ -56,3 +60,38 @@ def guess_file_info_from_response(response: httpx.Response):
         mimetype=mimetype,
         size=int(response.headers.get("Content-Length", -1)),
     )
+
+
+def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
+    return {
+        "opening_statement": features_dict.get("opening_statement"),
+        "suggested_questions": features_dict.get("suggested_questions", []),
+        "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
+        "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+        "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+        "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+        "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+        "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+        "user_input_form": user_input_form,
+        "sensitive_word_avoidance": features_dict.get(
+            "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+        ),
+        "file_upload": features_dict.get(
+            "file_upload",
+            {
+                "image": {
+                    "enabled": False,
+                    "number_limits": 3,
+                    "detail": "high",
+                    "transfer_methods": ["remote_url", "local_file"],
+                }
+            },
+        ),
+        "system_parameters": {
+            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+        },
+    }
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 07ef0ce3e5..82163a32ee 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -456,7 +456,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -620,6 +620,7 @@ class DatasetRetrievalSettingApi(Resource):
             case (
                 VectorType.MILVUS
                 | VectorType.RELYT
+                | VectorType.PGVECTOR
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
                 | VectorType.TENCENT
@@ -640,6 +641,7 @@
                 | VectorType.ELASTICSEARCH
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_ON_QDRANT
+                | VectorType.LINDORM
                 | VectorType.COUCHBASE
             ):
                 return {
@@ -682,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.ELASTICSEARCH
                 | VectorType.COUCHBASE
                 | VectorType.PGVECTOR
+                | VectorType.LINDORM
             ):
                 return {
                     "retrieval_method": [
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index 7c7580e3c6..fee52248a6 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,6 +1,7 @@
-from flask_restful import fields, marshal_with
+from flask_restful import marshal_with
 
-from configs import dify_config
+from controllers.common import fields
+from controllers.common import helpers as controller_helpers
 from controllers.console import api
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
@@ -11,43 +12,14 @@ from services.app_service import AppService
 
 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""
 
-    variable_fields = {
-        "key": fields.String,
-        "name": fields.String,
-        "description": fields.String,
-        "type": fields.String,
-        "default": fields.String,
-        "max_length": fields.Integer,
-        "options": fields.List(fields.String),
-    }
-
-    system_parameters_fields = {
-        "image_file_size_limit": fields.Integer,
-        "video_file_size_limit": fields.Integer,
-        "audio_file_size_limit": fields.Integer,
-        "file_size_limit": fields.Integer,
-    }
-
-    parameters_fields = {
-        "opening_statement": fields.String,
-        "suggested_questions": fields.Raw,
-        "suggested_questions_after_answer": fields.Raw,
-        "speech_to_text": fields.Raw,
-        "text_to_speech": fields.Raw,
-        "retriever_resource": fields.Raw,
-        "annotation_reply": fields.Raw,
-        "more_like_this": fields.Raw,
-        "user_input_form": fields.Raw,
-        "sensitive_word_avoidance": fields.Raw,
-        "file_upload": fields.Raw,
-        "system_parameters": fields.Nested(system_parameters_fields),
-    }
-
-    @marshal_with(parameters_fields)
+    @marshal_with(fields.parameters_fields)
     def get(self, installed_app: InstalledApp):
         """Retrieve app parameters."""
         app_model = installed_app.app
+        if app_model is None:
+            raise AppUnavailableError()
+
         if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             if workflow is None:
@@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource):
             user_input_form = workflow.user_input_form(to_old_structure=True)
         else:
             app_model_config = app_model.app_model_config
+            if app_model_config is None:
+                raise AppUnavailableError()
+
             features_dict = app_model_config.to_dict()
             user_input_form = features_dict.get("user_input_form", [])
 
-        return {
-            "opening_statement": features_dict.get("opening_statement"),
-            "suggested_questions": features_dict.get("suggested_questions", []),
-            "suggested_questions_after_answer": features_dict.get(
-                "suggested_questions_after_answer", {"enabled": False}
-            ),
-            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-            "user_input_form": user_input_form,
-            "sensitive_word_avoidance": features_dict.get(
-                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-            ),
-            "file_upload": features_dict.get(
-                "file_upload",
-                {
-                    "image": {
-                        "enabled": False,
-                        "number_limits": 3,
-                        "detail": "high",
-                        "transfer_methods": ["remote_url", "local_file"],
-                    }
-                },
-            ),
-            "system_parameters": {
-                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            },
-        }
+        return controller_helpers.get_parameters_from_feature_dict(
+            features_dict=features_dict, user_input_form=user_input_form
+        )
 
 
 class ExploreAppMetaApi(InstalledAppResource):
diff --git a/api/controllers/console/files/__init__.py b/api/controllers/console/files/__init__.py
index 69ee7eaabd..6c7bd8acfd 100644
--- a/api/controllers/console/files/__init__.py
+++ b/api/controllers/console/files/__init__.py
@@ -37,6 +37,7 @@ class FileApi(Resource):
             "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
             "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
             "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
         }, 200
 
     @setup_required
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 910b991de1..ad8dee09f9 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -61,6 +61,19 @@ class ToolBuiltinProviderListToolsApi(Resource):
         )
 
 
+class ToolBuiltinProviderInfoApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+
+        return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
+
+
 class ToolBuiltinProviderDeleteApi(Resource):
     @setup_required
     @login_required
@@ -604,6 +617,7 @@ api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
 
 # builtin tool provider
 api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
+api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<provider>/info")
 api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
 api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
 api.add_resource(
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 9a4cdc26cd..88b13faa52 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,6 +1,7 @@
-from flask_restful import Resource, fields, marshal_with
+from flask_restful import Resource, marshal_with
 
-from configs import dify_config
+from controllers.common import fields
+from controllers.common import helpers as controller_helpers
 from controllers.service_api import api
 from controllers.service_api.app.error import AppUnavailableError
 from controllers.service_api.wraps import validate_app_token
@@ -11,40 +12,8 @@ from services.app_service import AppService
 
 class AppParameterApi(Resource):
     """Resource for app variables."""
 
-    variable_fields = {
-        "key": fields.String,
-        "name": fields.String,
-        "description": fields.String,
-        "type": fields.String,
-        "default": fields.String,
-        "max_length": fields.Integer,
-        "options": fields.List(fields.String),
-    }
-
-    system_parameters_fields = {
-        "image_file_size_limit": fields.Integer,
-        "video_file_size_limit": fields.Integer,
-        "audio_file_size_limit": fields.Integer,
-        "file_size_limit": fields.Integer,
-    }
-
-    parameters_fields = {
-        "opening_statement": fields.String,
-        "suggested_questions": fields.Raw,
-        "suggested_questions_after_answer": fields.Raw,
-        "speech_to_text": fields.Raw,
-        "text_to_speech": fields.Raw,
-        "retriever_resource": fields.Raw,
-        "annotation_reply": fields.Raw,
-        "more_like_this": fields.Raw,
-        "user_input_form": fields.Raw,
-        "sensitive_word_avoidance": fields.Raw,
-        "file_upload": fields.Raw,
-        "system_parameters": fields.Nested(system_parameters_fields),
-    }
-
     @validate_app_token
-    @marshal_with(parameters_fields)
+    @marshal_with(fields.parameters_fields)
     def get(self, app_model: App):
         """Retrieve app parameters."""
         if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
@@ -56,43 +25,16 @@ class AppParameterApi(Resource):
             user_input_form = workflow.user_input_form(to_old_structure=True)
         else:
             app_model_config = app_model.app_model_config
+            if app_model_config is None:
+                raise AppUnavailableError()
+
             features_dict = app_model_config.to_dict()
             user_input_form = features_dict.get("user_input_form", [])
 
-        return {
-            "opening_statement": features_dict.get("opening_statement"),
-            "suggested_questions": features_dict.get("suggested_questions", []),
-            "suggested_questions_after_answer": features_dict.get(
-                "suggested_questions_after_answer", {"enabled": False}
-            ),
-            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-            "user_input_form": user_input_form,
-            "sensitive_word_avoidance": features_dict.get(
-                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-            ),
-            "file_upload": features_dict.get(
-                "file_upload",
-                {
-                    "image": {
-                        "enabled": False,
-                        "number_limits": 3,
-                        "detail": "high",
-                        "transfer_methods": ["remote_url", "local_file"],
-                    }
-                },
-            ),
-            "system_parameters": {
-                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            },
-        }
+        return controller_helpers.get_parameters_from_feature_dict(
+            features_dict=features_dict, user_input_form=user_input_form
+        )
 
 
 class AppMetaApi(Resource):
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index 974d2cff94..cc8255ccf4 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,6 +1,7 @@
-from flask_restful import fields, marshal_with
+from flask_restful import marshal_with
 
-from configs import dify_config
+from controllers.common import fields
+from controllers.common import helpers as controller_helpers
 from controllers.web import api
 from controllers.web.error import AppUnavailableError
 from controllers.web.wraps import WebApiResource
@@ -11,39 +12,7 @@ from services.app_service import AppService
 
 class AppParameterApi(WebApiResource):
     """Resource for app variables."""
 
-    variable_fields = {
-        "key": fields.String,
-        "name": fields.String,
-        "description": fields.String,
-        "type": fields.String,
-        "default": fields.String,
-        "max_length": fields.Integer,
-        "options": fields.List(fields.String),
-    }
-
-    system_parameters_fields = {
-        "image_file_size_limit": fields.Integer,
-        "video_file_size_limit": fields.Integer,
-        "audio_file_size_limit": fields.Integer,
-        "file_size_limit": fields.Integer,
-    }
-
-    parameters_fields = {
-        "opening_statement": fields.String,
-        "suggested_questions": fields.Raw,
-        "suggested_questions_after_answer": fields.Raw,
-        "speech_to_text": fields.Raw,
-        "text_to_speech": fields.Raw,
-        "retriever_resource": fields.Raw,
-        "annotation_reply": fields.Raw,
-        "more_like_this": fields.Raw,
-        "user_input_form": fields.Raw,
-        "sensitive_word_avoidance": fields.Raw,
-        "file_upload": fields.Raw,
-        "system_parameters": fields.Nested(system_parameters_fields),
-    }
-
-    @marshal_with(parameters_fields)
+    @marshal_with(fields.parameters_fields)
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
@@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource):
             user_input_form = workflow.user_input_form(to_old_structure=True)
         else:
             app_model_config = app_model.app_model_config
+            if app_model_config is None:
+                raise AppUnavailableError()
+
             features_dict = app_model_config.to_dict()
             user_input_form = features_dict.get("user_input_form", [])
 
-        return {
-            "opening_statement": features_dict.get("opening_statement"),
-            "suggested_questions": features_dict.get("suggested_questions", []),
-            "suggested_questions_after_answer": features_dict.get(
-                "suggested_questions_after_answer", {"enabled": False}
-            ),
-            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-            "user_input_form": user_input_form,
-            "sensitive_word_avoidance": features_dict.get(
-                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-            ),
-            "file_upload": features_dict.get(
-                "file_upload",
-                {
-                    "image": {
-                        "enabled": False,
-                        "number_limits": 3,
-                        "detail": "high",
-                        "transfer_methods": ["remote_url", "local_file"],
-                    }
-                },
-            ),
-            "system_parameters": {
-                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            },
-        }
+        return controller_helpers.get_parameters_from_feature_dict(
+            features_dict=features_dict, user_input_form=user_input_form
+        )
 
 
 class AppMeta(WebApiResource):
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index cb529340af..0b8a586d0c 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,6 +1,5 @@
 import urllib.parse
 
-from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 
 from controllers.common import helpers
@@ -27,7 +26,7 @@ class RemoteFileInfoApi(WebApiResource):
 
 class RemoteFileUploadApi(WebApiResource):
     @marshal_with(file_fields_with_signed_url)
-    def post(self):
+    def post(self, app_model, end_user):  # Add app_model and end_user parameters
         parser = reqparse.RequestParser()
         parser.add_argument("url", type=str, required=True, help="URL is required")
         args = parser.parse_args()
@@ -51,7 +50,7 @@ class RemoteFileUploadApi(WebApiResource):
                 filename=file_info.filename,
                 content=content,
                 mimetype=file_info.mimetype,
-                user=current_user,
+                user=end_user,  # Use end_user instead of current_user
                 source_url=url,
             )
         except Exception as e:
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py
index 42beec2535..d0f75d0b75 100644
--- a/api/core/app/app_config/features/file_upload/manager.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -1,8 +1,7 @@
 from collections.abc import Mapping
 from typing import Any
 
-from core.file.models import FileExtraConfig
-from models import FileUploadConfig
+from core.file import FileExtraConfig
 
 
 class FileUploadConfigManager:
@@ -43,6 +42,6 @@ class FileUploadConfigManager:
         if not config.get("file_upload"):
             config["file_upload"] = {}
         else:
-            FileUploadConfig.model_validate(config["file_upload"])
+            FileExtraConfig.model_validate(config["file_upload"])
 
         return config, ["file_upload"]
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index e4cb3f8527..1fc7ffe2c7 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
     QueueIterationStartEvent,
     QueueMessageReplaceEvent,
     QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 if response:
                     yield response
-            elif isinstance(event, QueueNodeFailedEvent):
+            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
                 response = self._workflow_node_finish_to_stream_response(
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 993f8e904d..97615f0472 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -23,7 +23,10 @@ class BaseAppGenerator:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        user_inputs = {
+            var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
+            for var in variables
+        }
         user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
         # Convert files in inputs to File
         entity_dictionary = {item.variable: item for item in app_config.variables}
@@ -75,50 +78,66 @@
 
         return user_inputs
 
-    def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
-        user_input_value = inputs.get(var.variable)
-        if not user_input_value:
-            if var.required:
-                raise ValueError(f"{var.variable} is required in input form")
-            else:
-                return None
+    def _validate_inputs(
+        self,
+        *,
+        variable_entity: "VariableEntity",
+        value: Any,
+    ):
+        if value is None:
+            if variable_entity.required:
+                raise ValueError(f"{variable_entity.variable} is required in input form")
+            return value
 
-        if var.type in {
+        if variable_entity.type in {
             VariableEntityType.TEXT_INPUT,
             VariableEntityType.SELECT,
             VariableEntityType.PARAGRAPH,
-        } and not isinstance(user_input_value, str):
-            raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
-        if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
+        } and not isinstance(value, str):
+            raise ValueError(
+                f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
+            )
+
+        if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
             # may raise ValueError if user_input_value is not a valid number
             try:
-                if "." in user_input_value:
-                    return float(user_input_value)
+                if "." in value:
+                    return float(value)
                 else:
-                    return int(user_input_value)
+                    return int(value)
             except ValueError:
-                raise ValueError(f"{var.variable} in input form must be a valid number")
-        if var.type == VariableEntityType.SELECT:
-            options = var.options
-            if user_input_value not in options:
-                raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
-        elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
-            if var.max_length and len(user_input_value) > var.max_length:
-                raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
-        elif var.type == VariableEntityType.FILE:
-            if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
-                raise ValueError(f"{var.variable} in input form must be a file")
-        elif var.type == VariableEntityType.FILE_LIST:
-            if not (
-                isinstance(user_input_value, list)
-                and (
-                    all(isinstance(item, dict) for item in user_input_value)
-                    or all(isinstance(item, File) for item in user_input_value)
-                )
-            ):
-                raise ValueError(f"{var.variable} in input form must be a list of files")
+                raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
 
-        return user_input_value
+        match variable_entity.type:
+            case VariableEntityType.SELECT:
+                if value not in variable_entity.options:
+                    raise ValueError(
+                        f"{variable_entity.variable} in input form must be one of the following: "
+                        f"{variable_entity.options}"
+                    )
+            case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
+                if variable_entity.max_length and len(value) > variable_entity.max_length:
+                    raise ValueError(
+                        f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
+                        "characters"
+                    )
+            case VariableEntityType.FILE:
+                if not isinstance(value, dict) and not isinstance(value, File):
+                    raise ValueError(f"{variable_entity.variable} in input form must be a file")
+            case VariableEntityType.FILE_LIST:
+                # if number of files exceeds the limit, raise ValueError
+                if not (
+                    isinstance(value, list)
+                    and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
+                ):
+                    raise ValueError(f"{variable_entity.variable} in input form must be a list of files")
+
+                if variable_entity.max_length and len(value) > variable_entity.max_length:
+                    raise ValueError(
+                        f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
+                    )
+
+        return value
 
     def _sanitize_value(self, value: Any) -> Any:
         if isinstance(value, str):
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index 419a5da806..d119d94a61 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
     QueueIterationNextEvent,
     QueueIterationStartEvent,
     QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
                 if response:
                     yield response
-            elif isinstance(event, QueueNodeFailedEvent):
+            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
                 response = self._workflow_node_finish_to_stream_response(
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index ca23bbdd47..9a01e8a253 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
     QueueIterationNextEvent,
     QueueIterationStartEvent,
     QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
     IterationRunNextEvent,
     IterationRunStartedEvent,
     IterationRunSucceededEvent,
+    NodeInIterationFailedEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
     NodeRunStartedEvent,
@@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     node_run_index=event.route_node_state.index,
                     predecessor_node_id=event.predecessor_node_id,
                     in_iteration_id=event.in_iteration_id,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
@@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
                     error=event.route_node_state.node_run_result.error
                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
                     else "Unknown error",
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
                     in_iteration_id=event.in_iteration_id,
                 )
             )
+        elif isinstance(event, NodeInIterationFailedEvent):
+            self._publish_event(
+                QueueNodeInIterationFailedEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result
+                    else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
+                    error=event.error,
+                )
+            )
         elif isinstance(event, NodeRunStreamChunkEvent):
             self._publish_event(
                 QueueTextChunkEvent(
@@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     index=event.index,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     output=event.pre_iteration_output,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
                 )
             )
         elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index bc43baf8a5..f1542ec5d8 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
     """parent parallel id if node is in parallel"""
     parent_parallel_start_node_id: Optional[str] = None
     """parent parallel start node id if node is in parallel"""
-
+    parallel_mode_run_id: Optional[str] = None
+    """iteration run in parallel mode run id"""
     node_run_index: int
     output: Optional[Any] = None  # output for the current iteration
 
@@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
     in_iteration_id: Optional[str] = None
     """iteration id if node is in iteration"""
     start_at: datetime
+    parallel_mode_run_id: Optional[str] = None
+    """iteration run in parallel mode run id"""
 
 
 class QueueNodeSucceededEvent(AppQueueEvent):
@@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     error: Optional[str] = None
 
 
+class QueueNodeInIterationFailedEvent(AppQueueEvent):
+    """
+    QueueNodeInIterationFailedEvent entity
+    """
+
+    event: QueueEvent = QueueEvent.NODE_FAILED
+
+    node_execution_id: str
+    node_id: str
+    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+
+    error: str
+
+
 class QueueNodeFailedEvent(AppQueueEvent):
     """
     QueueNodeFailedEvent entity
@@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
     inputs: Optional[dict[str, Any]] = None
     process_data: Optional[dict[str, Any]] = None
     outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
 
     error: str
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 4b5f4716ed..7e9aad54be 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
         parent_parallel_id: Optional[str] = None
         parent_parallel_start_node_id: Optional[str] = None
         iteration_id: Optional[str] = None
+        parallel_run_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.NODE_STARTED
     workflow_run_id: str
@@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
         extras: dict = {}
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
+        parallel_mode_run_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 2abee5bef5..b89edf9079 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
     QueueIterationNextEvent,
     QueueIterationStartEvent,
     QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.tool_manager import ToolManager
+from core.workflow.entities.node_entities import NodeRunMetadataKey
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData
@@ -251,6 +253,12 @@ class WorkflowCycleManage:
         workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
         workflow_node_execution.created_by_role = workflow_run.created_by_role
         workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.execution_metadata = json.dumps(
+            {
+                NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            }
+        )
         workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
         session.add(workflow_node_execution)
@@ -305,7 +313,9 @@ class WorkflowCycleManage:
 
         return workflow_node_execution
 
-    def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_failed(
+        self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
+    ) -> WorkflowNodeExecution:
         """
         Workflow node execution failed
         :param event: queue node failed event
@@ -318,16 +328,19 @@ class WorkflowCycleManage:
         outputs = WorkflowEntry.handle_special_values(event.outputs)
         finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
-
+        execution_metadata = (
+            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
+        )
         db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
             {
                 WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
                 WorkflowNodeExecution.error: event.error,
                 WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-                WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
+                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
                 WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
                 WorkflowNodeExecution.finished_at: finished_at,
                 WorkflowNodeExecution.elapsed_time: elapsed_time,
+                WorkflowNodeExecution.execution_metadata: execution_metadata,
             }
         )
 
@@ -342,6 +355,7 @@ class WorkflowCycleManage:
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.elapsed_time = elapsed_time
+        workflow_node_execution.execution_metadata = execution_metadata
 
         self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
 
@@ -448,6 +462,7 @@ class WorkflowCycleManage:
                 parent_parallel_id=event.parent_parallel_id,
                 parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                 iteration_id=event.in_iteration_id,
+                parallel_run_id=event.parallel_mode_run_id,
             ),
         )
 
@@ -464,7 +479,7 @@ class WorkflowCycleManage:
 
     def _workflow_node_finish_to_stream_response(
         self,
-        event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeFinishStreamResponse]:
@@ -608,6 +623,7 @@ class WorkflowCycleManage:
                 extras={},
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
+                parallel_mode_run_id=event.parallel_mode_run_id,
             ),
         )
 
@@ -633,7 +649,9 @@ class WorkflowCycleManage:
                 created_at=int(time.time()),
                 extras={},
                 inputs=event.inputs or {},
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                status=WorkflowNodeExecutionStatus.SUCCEEDED
+                if event.error is None
+                else WorkflowNodeExecutionStatus.FAILED,
                 error=None,
                 elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
                 total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py
index 74d052ad11..990e84867e 100644
--- a/api/core/entities/parameter_entities.py
+++ b/api/core/entities/parameter_entities.py
@@ -12,7 +12,8 @@ class CommonParameterType(Enum):
     SYSTEM_FILES = "system-files"
     BOOLEAN = "boolean"
     APP_SELECTOR = "app-selector"
-    MODEL_CONFIG = "model-config"
+    TOOL_SELECTOR = "tool-selector"
+    MODEL_SELECTOR = "model-selector"
 
 
 class AppSelectorScope(Enum):
@@ -22,7 +23,7 @@ class AppSelectorScope(Enum):
     COMPLETION = "completion"
 
 
-class ModelConfigScope(Enum):
+class ModelSelectorScope(Enum):
     LLM = "llm"
     TEXT_EMBEDDING = "text-embedding"
     RERANK = "rerank"
@@ -30,3 +31,10 @@
     SPEECH2TEXT = "speech2text"
     MODERATION = "moderation"
     VISION = "vision"
+
+
+class ToolSelectorScope(Enum):
+    ALL = "all"
+    CUSTOM = "custom"
+    BUILTIN = "builtin"
+    WORKFLOW = "workflow"
diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py
index 596a841e74..b5de05d2fb 100644
--- a/api/core/entities/provider_entities.py
+++ b/api/core/entities/provider_entities.py
@@ -3,7 +3,12 @@ from typing import Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
+from core.entities.parameter_entities import (
+    AppSelectorScope,
+    CommonParameterType,
+    ModelSelectorScope,
+    ToolSelectorScope,
+)
 from core.model_runtime.entities.model_entities import ModelType
 from core.tools.entities.common_entities import I18nObject
 
@@ -140,7 +145,8 @@ class BasicProviderConfig(BaseModel):
         SELECT = CommonParameterType.SELECT.value
         BOOLEAN = CommonParameterType.BOOLEAN.value
         APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
-        MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
+        MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
+        TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
 
         @classmethod
         def value_of(cls, value: str) -> "ProviderConfig.Type":
@@ -168,7 +174,7 @@ class ProviderConfig(BasicProviderConfig):
         value: str = Field(..., description="The value of the option")
         label: I18nObject = Field(..., description="The label of the option")
 
-    scope: AppSelectorScope | ModelConfigScope | None = None
+    scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
     required: bool = False
     default: Optional[Union[int, str]] = None
     options: Optional[list[Option]] = None
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index fb9fe8f210..e2a94073cf 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -598,7 +598,7 @@ class IndexingRunner:
             rules = DatasetProcessRule.AUTOMATIC_RULES
         else:
             rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
-        document_text = CleanProcessor.clean(text, rules)
+        document_text = CleanProcessor.clean(text, {"rules": rules})
 
         return document_text
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index bdbafc8ded..a044a948aa 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -10,8 +10,15 @@ from core.model_runtime.entities.model_entities import (
     PriceInfo,
     PriceType,
 )
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
 from core.plugin.manager.model import PluginModelManager
 
 
@@ -31,7 +38,7 @@ class AIModel(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
     @property
-    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
         """
         Map model invoke error to unified error
         The key is the error type thrown to the caller
@@ -40,9 +47,17 @@
         :return: Invoke error mapping
         """
-        raise NotImplementedError
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [InvokeBadRequestError],
+            PluginDaemonInnerError: [PluginDaemonInnerError],
+            ValueError: [ValueError],
+        }
 
-    def _transform_invoke_error(self, error: Exception) -> InvokeError:
+    def _transform_invoke_error(self, error: Exception) -> Exception:
         """
         Transform invoke error to unified error
 
@@ -52,13 +67,15 @@
         for invoke_error, model_errors in self._invoke_error_mapping.items():
             if isinstance(error, tuple(model_errors)):
                 if invoke_error == InvokeAuthorizationError:
-                    return invoke_error(
+                    return InvokeAuthorizationError(
                         description=(
                             f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
                         )
                     )
-
-                return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
+                elif issubclass(invoke_error, InvokeError):
+                    return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
+                else:
+                    return error
 
         return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml
new file mode 100644
index 0000000000..35fc8d0d11
--- /dev/null
+++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml
@@ -0,0 +1,61 @@
+model: anthropic.claude-3-5-haiku-20241022-v1:0
+label:
+  en_US: Claude 3.5 Haiku
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 200000
+# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 8192
+    min: 1
+    max: 8192
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  # docs: https://docs.anthropic.com/claude/docs/system-prompts
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # tip docs from aws has error, max value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+ - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 0000000000..a9b66b1925 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,61 @@ +model: us.anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 Haiku(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py index ca67594ce4..14aa811905 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py @@ -1,6 +1,7 @@ import logging -from core.model_runtime.entities.model_entities import ModelType +import requests + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider): :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
""" try: - model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials) + api_key = credentials.get("api_key") + if not api_key: + raise CredentialsValidateFailedError("Credentials validation failed: api_key not given") + + # send a get request to validate the credentials + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300)) + + if response.status_code != 200: + raise CredentialsValidateFailedError( + f"Credentials validation failed with status code {response.status_code}" + ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png new file mode 100644 index 0000000000..dfe8e78049 Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg new file mode 100644 index 0000000000..bb23bffcf1 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png new file mode 100644 index 0000000000..b154821db9 Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg new file mode 100644 index 0000000000..c5c608cd7c --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.py b/api/core/model_runtime/model_providers/gpustack/gpustack.py new file mode 100644 index 0000000000..321100167e --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class GPUStackProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml new file mode 100644 index 0000000000..ee4a3c159a --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml @@ -0,0 +1,120 @@ +provider: gpustack +label: + en_US: GPUStack +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: endpoint_url + label: + zh_Hans: 服务器地址 + en_US: Server URL + type: text-input + required: true + placeholder: + zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100 + en_US: Enter the GPUStack server URL, e.g. 
http://192.168.1.100 + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 输入您的 API Key + en_US: Enter your API Key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择补全类型 + en_US: Select completion type + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: "8192" + placeholder: + zh_Hans: 输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens_to_sample + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: "8192" + type: text-input + - variable: function_calling_type + show_on: + - variable: __model_type + value: llm + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: function_call + label: + en_US: Function Call + zh_Hans: Function Call + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: vision_support + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: Vision 支持 + en_US: Vision Support + type: select + required: false + default: no_support + options: + - value: support + label: + en_US: Support + zh_Hans: 支持 + - value: no_support + label: + en_US: Not Support + zh_Hans: 不支持 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/__init__.py b/api/core/model_runtime/model_providers/gpustack/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py new file mode 100644 index 0000000000..ce6780b6a7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/llm/llm.py @@ -0,0 +1,45 @@ +from collections.abc import Generator + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( + OAIAPICompatLargeLanguageModel, +) + + +class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return super()._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") + credentials["mode"] = "chat" diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py b/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
new file mode 100644
index 0000000000..5ea7532564
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
@@ -0,0 +1,148 @@
+from json import dumps
+from typing import Optional
+
+import httpx
+from requests import post
+from requests.exceptions import HTTPError
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+)
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class GPUStackRerankModel(RerankModel):
+    """
+    Model class for GPUStack rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n documents to return
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+
+        endpoint_url = credentials["endpoint_url"]
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+            "Content-Type": "application/json",
+        }
+
+        data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
+
+        try:
+            response = post(
+                str(URL(endpoint_url) / "v1" / "rerank"),
+                headers=headers,
+                data=dumps(data),
+                timeout=10,
+            )
+            response.raise_for_status()
+            results = response.json()
+
+            rerank_documents = []
+            for result in results["results"]:
+                index = result["index"]
+                if "document" in result:
+                    text = result["document"]["text"]
+                else:
+                    text = docs[index]
+
+                rerank_document = RerankDocument(
+                    index=index,
+                    text=text,
+                    score=result["relevance_score"],
+                )
+
+                if score_threshold is None or result["relevance_score"] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(model=model, docs=rerank_documents)
+        # the request is made with `requests`, so catch its HTTPError rather than httpx's
+        except HTTPError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py new file mode 100644 index 0000000000..eb324491a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py @@ -0,0 +1,35 @@ +from typing import Optional + +from yarl import URL + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.text_embedding_entities import ( + TextEmbeddingResult, +) +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel): + """ + Model class for GPUStack text embedding model. 
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + return super()._invoke(model, credentials, texts, user, input_type) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg new file mode 100644 index 0000000000..f8b745cb13 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml new file mode 100644 index 0000000000..7c305735b9 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -0,0 +1,63 @@ +model: grok-beta +label: + en_US: Grok beta +model_type: llm +features: + - multi-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." 
+ zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py new file mode 100644 index 0000000000..3f5325a857 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -0,0 +1,37 @@ +from collections.abc import Generator +from typing import Optional, Union + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py new file mode 100644 index 0000000000..e3f2b8eeba --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="grok-beta", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml new file mode 100644 index 0000000000..90d1cbfe7e --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.yaml @@ -0,0 +1,38 @@ +provider: x +label: + en_US: xAI +description: + en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe. 
+icon_small: + en_US: x-ai-logo.svg +icon_large: + en_US: x-ai-logo.svg +help: + title: + en_US: Get your token from xAI + zh_Hans: 从 xAI 获取 token + url: + en_US: https://x.ai/api +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API Base + type: text-input + required: false + default: https://api.x.ai/v1 + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py index b25282cde2..6654c05e2d 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/manager/base.py @@ -53,7 +53,7 @@ class BasePluginManager: ) except requests.exceptions.ConnectionError as e: logger.exception(f"Request to Plugin Daemon Service failed: {e}") - raise ValueError("Request to Plugin Daemon Service failed") + raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") return response @@ -157,8 +157,17 @@ class BasePluginManager: Make a stream request to the plugin daemon inner API and yield the response as a model. """ for line in self._stream_request(method, path, params, headers, data, files): - line_data = json.loads(line) - rep = PluginDaemonBasicResponse[type](**line_data) + line_data = None + try: + line_data = json.loads(line) + rep = PluginDaemonBasicResponse[type](**line_data) + except Exception as e: + # TODO modify this when line_data has code and message + if line_data and "error" in line_data: + raise ValueError(line_data["error"]) + else: + raise ValueError(line) + if rep.code != 0: if rep.code == -500: try: diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/manager/model.py index 2081dcc298..4188148812 100644 --- a/api/core/plugin/manager/model.py +++ b/api/core/plugin/manager/model.py @@ -437,6 +437,7 @@ class PluginModelManager(BasePluginManager): voices = [] for voice in resp.voices: voices.append({"name": voice.name, "value": voice.value}) + return voices return [] diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 57af05861c..277513f3f2 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -103,7 +103,7 @@ class RetrievalService: if exceptions: exception_message = ";\n".join(exceptions) - raise Exception(exception_message) + raise ValueError(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( diff --git a/api/core/rag/datasource/vdb/lindorm/__init__.py b/api/core/rag/datasource/vdb/lindorm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py new file mode 100644 index 0000000000..abd8261a69 --- /dev/null +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -0,0 +1,498 @@ +import copy +import json +import logging +from collections.abc import Iterable +from typing import Any, Optional + +from opensearchpy import OpenSearch +from opensearchpy.helpers import bulk +from pydantic import BaseModel, model_validator +from tenacity import retry, stop_after_attempt, wait_fixed + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from 
core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logging.getLogger("lindorm").setLevel(logging.WARN)
+
+
+class LindormVectorStoreConfig(BaseModel):
+    hosts: str
+    username: Optional[str] = None
+    password: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        # use .get() so a missing key raises a clear ValueError instead of a KeyError
+        if not values.get("hosts"):
+            raise ValueError("config URL is required")
+        if not values.get("username"):
+            raise ValueError("config USERNAME is required")
+        if not values.get("password"):
+            raise ValueError("config PASSWORD is required")
+        return values
+
+    def to_opensearch_params(self) -> dict[str, Any]:
+        params = {
+            "hosts": self.hosts,
+        }
+        if self.username and self.password:
+            params["http_auth"] = (self.username, self.password)
+        return params
+
+
+class LindormVectorStore(BaseVector):
+    def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
+        super().__init__(collection_name.lower())
+        self._client_config = config
+        self._client = OpenSearch(**config.to_opensearch_params())
+        self.kwargs = kwargs
+
+    def get_type(self) -> str:
+        return VectorType.LINDORM
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self.create_collection(len(embeddings[0]), **kwargs)
+        self.add_texts(texts, embeddings)
+
+    def refresh(self):
+        self._client.indices.refresh(index=self._collection_name)
+
+    def __filter_existed_ids(
+        self,
+        texts: list[str],
+        metadatas: list[dict],
+        ids: list[str],
+        bulk_size: int = 1024,
+    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
+        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
+        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
+            try:
+                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
+                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
+            except Exception as e:
+                logger.error(f"Error fetching batch {batch_ids}: {e}")
+                return set()
+
+        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
+        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
+            try:
+                existing_docs = self._client.mget(
+                    body={
+                        "docs": [
+                            {"_index": self._collection_name, "_id": id, "routing": routing}
+                            for id, routing in zip(batch_ids, route_ids)
+                        ]
+                    },
+                    _source=False,
+                )
+                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
+            except Exception as e:
+                logger.error(f"Error fetching batch {batch_ids}: {e}")
+                return set()
+
+        if ids is None:
+            return texts, metadatas, ids
+
+        if len(texts) != len(ids):
+            raise RuntimeError(f"the number of texts ({len(texts)}) does not match the number of ids ({len(ids)})")
+
+        filtered_texts = []
+        filtered_metadatas = []
+        filtered_ids = []
+
+        def batch(iterable, n):
+            length = len(iterable)
+            for idx in range(0, length, n):
+                yield iterable[idx : min(idx + n, length)]
+
+        for ids_batch, texts_batch, metadatas_batch in zip(
+            batch(ids, bulk_size),
+            batch(texts, bulk_size),
+            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
+        ):
+            existing_ids_set = 
__fetch_existing_ids(ids_batch) + for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch): + if doc_id not in existing_ids_set: + filtered_texts.append(text) + filtered_ids.append(doc_id) + if metadatas is not None: + filtered_metadatas.append(metadata) + + return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + actions = [] + uuids = self._get_uuids(documents) + for i in range(len(documents)): + action = { + "_op_type": "index", + "_index": self._collection_name.lower(), + "_id": uuids[i], + "_source": { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY.value: documents[i].metadata, + }, + } + actions.append(action) + bulk(self._client, actions) + self.refresh() + + def get_ids_by_metadata_field(self, key: str, value: str): + query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}} + response = self._client.search(index=self._collection_name, body=query) + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} + results = self._client.search(index=self._collection_name, body=query_str) + ids = [hit["_id"] for hit in results["hits"]["hits"]] + if ids: + self.delete_by_ids(ids) + + def delete_by_ids(self, ids: list[str]) -> None: + for id in ids: + if self._client.exists(index=self._collection_name, id=id): + self._client.delete(index=self._collection_name, id=id) + else: + logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") + + def delete(self) -> None: + try: + if self._client.indices.exists(index=self._collection_name): + self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) + logger.info("Delete index success") + else: + logger.warning(f"Index '{self._collection_name}' does not exist. 
No deletion performed.") + except Exception as e: + logger.error(f"Error occurred while deleting the index: {e}") + raise e + + def text_exists(self, id: str) -> bool: + try: + self._client.get(index=self._collection_name, id=id) + return True + except: + return False + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + # Make sure query_vector is a list + if not isinstance(query_vector, list): + raise ValueError("query_vector should be a list of floats") + + # Check whether query_vector is a floating-point number list + if not all(isinstance(x, float) for x in query_vector): + raise ValueError("All elements in query_vector should be floats") + + top_k = kwargs.get("top_k", 10) + query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) + try: + response = self._client.search(index=self._collection_name, body=query) + except Exception as e: + logger.error(f"Error executing search: {e}") + raise + + docs_and_scores = [] + for hit in response["hits"]["hits"]: + docs_and_scores.append( + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 + if score > score_threshold: + doc.metadata["score"] = score + docs.append(doc) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + must = kwargs.get("must") + must_not = kwargs.get("must_not") + should = kwargs.get("should") + minimum_should_match = kwargs.get("minimum_should_match", 0) + top_k = kwargs.get("top_k", 10) + filters = kwargs.get("filter") + routing = kwargs.get("routing") + full_text_query = default_text_search_query( + query_text=query, + k=top_k, + text_field=Field.CONTENT_KEY.value, + must=must, + must_not=must_not, + should=should, + minimum_should_match=minimum_should_match, + filters=filters, + routing=routing, + ) + response = self._client.search(index=self._collection_name, body=full_text_query) + docs = [] + for hit in response["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) + + return docs + + def create_collection(self, dimension: int, **kwargs): + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name} already exists.") + return + if self._client.indices.exists(index=self._collection_name): + logger.info("{self._collection_name.lower()} already exists.") + return + if len(self.kwargs) == 0 and len(kwargs) != 0: + self.kwargs = copy.deepcopy(kwargs) + vector_field = kwargs.pop("vector_field", Field.VECTOR.value) + shards = kwargs.pop("shards", 2) + + engine = kwargs.pop("engine", "lvector") + method_name = kwargs.pop("method_name", "hnsw") + data_type = kwargs.pop("data_type", "float") + space_type = kwargs.pop("space_type", "cosinesimil") + + hnsw_m = kwargs.pop("hnsw_m", 24) + hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500) + ivfpq_m = kwargs.pop("ivfpq_m", dimension) + nlist = kwargs.pop("nlist", 1000) + centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 
+            centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
+            centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
+            centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
+            mapping = default_text_mapping(
+                dimension,
+                method_name,
+                shards=shards,
+                engine=engine,
+                data_type=data_type,
+                space_type=space_type,
+                vector_field=vector_field,
+                hnsw_m=hnsw_m,
+                hnsw_ef_construction=hnsw_ef_construction,
+                nlist=nlist,
+                ivfpq_m=ivfpq_m,
+                centroids_use_hnsw=centroids_use_hnsw,
+                centroids_hnsw_m=centroids_hnsw_m,
+                centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
+                centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+                **kwargs,
+            )
+            self._client.indices.create(index=self._collection_name.lower(), body=mapping)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+            # logger.info(f"create index success: {self._collection_name}")
+
+
+def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
+    routing_field = kwargs.get("routing_field")
+    excludes_from_source = kwargs.get("excludes_from_source")
+    analyzer = kwargs.get("analyzer", "ik_max_word")
+    text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
+    engine = kwargs["engine"]
+    shard = kwargs["shards"]
+    space_type = kwargs["space_type"]
+    data_type = kwargs["data_type"]
+    vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+
+    if method_name == "ivfpq":
+        ivfpq_m = kwargs["ivfpq_m"]
+        nlist = kwargs["nlist"]
+        # honor caller-supplied centroid settings instead of silently overriding them
+        centroids_use_hnsw = kwargs.get("centroids_use_hnsw", nlist > 10000)
+        centroids_hnsw_m = kwargs.get("centroids_hnsw_m", 24)
+        centroids_hnsw_ef_construct = kwargs.get("centroids_hnsw_ef_construct", 500)
+        centroids_hnsw_ef_search = kwargs.get("centroids_hnsw_ef_search", 100)
+        parameters = {
+            "m": ivfpq_m,
+            "nlist": nlist,
+            "centroids_use_hnsw": centroids_use_hnsw,
+            "centroids_hnsw_m": centroids_hnsw_m,
+            "centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
+            "centroids_hnsw_ef_search": centroids_hnsw_ef_search,
+        }
+    elif method_name == "hnsw":
+        neighbor = kwargs["hnsw_m"]
+        ef_construction = kwargs["hnsw_ef_construction"]
+        parameters = {"m": neighbor, "ef_construction": ef_construction}
+    elif method_name == "flat":
+        parameters = {}
+    else:
+        raise RuntimeError(f"unexpected method_name: {method_name}")
+
+    mapping = {
+        "settings": {"index": {"number_of_shards": shard, "knn": True}},
+        "mappings": {
+            "properties": {
+                vector_field: {
+                    "type": "knn_vector",
+                    "dimension": dimension,
+                    "data_type": data_type,
+                    "method": {
+                        "engine": engine,
+                        "name": method_name,
+                        "space_type": space_type,
+                        "parameters": parameters,
+                    },
+                },
+                text_field: {"type": "text", "analyzer": analyzer},
+            }
+        },
+    }
+
+    if excludes_from_source:
+        mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}
{"excludes": ["vector_field"]} + + if method_name == "ivfpq" and routing_field is not None: + mapping["settings"]["index"]["knn_routing"] = True + mapping["settings"]["index"]["knn.offline.construction"] = True + + if method_name == "flat" and routing_field is not None: + mapping["settings"]["index"]["knn_routing"] = True + + return mapping + + +def default_text_search_query( + query_text: str, + k: int = 4, + text_field: str = Field.CONTENT_KEY.value, + must: Optional[list[dict]] = None, + must_not: Optional[list[dict]] = None, + should: Optional[list[dict]] = None, + minimum_should_match: int = 0, + filters: Optional[list[dict]] = None, + routing: Optional[str] = None, + **kwargs, +) -> dict: + if routing is not None: + routing_field = kwargs.get("routing_field", "routing_field") + query_clause = { + "bool": { + "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}] + } + } + else: + query_clause = {"match": {text_field: query_text}} + # build the simplest search_query when only query_text is specified + if not must and not must_not and not should and not filters: + search_query = {"size": k, "query": query_clause} + return search_query + + # build complex search_query when either of must/must_not/should/filter is specified + if must: + if not isinstance(must, list): + raise RuntimeError(f"unexpected [must] clause with {type(filters)}") + if query_clause not in must: + must.append(query_clause) + else: + must = [query_clause] + + boolean_query = {"must": must} + + if must_not: + if not isinstance(must_not, list): + raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}") + boolean_query["must_not"] = must_not + + if should: + if not isinstance(should, list): + raise RuntimeError(f"unexpected [should] clause with {type(filters)}") + boolean_query["should"] = should + if minimum_should_match != 0: + boolean_query["minimum_should_match"] = minimum_should_match + + if filters: + if not isinstance(filters, list): + raise RuntimeError(f"unexpected [filter] clause with {type(filters)}") + boolean_query["filter"] = filters + + search_query = {"size": k, "query": {"bool": boolean_query}} + return search_query + + +def default_vector_search_query( + query_vector: list[float], + k: int = 4, + min_score: str = "0.0", + ef_search: Optional[str] = None, # only for hnsw + nprobe: Optional[str] = None, # "2000" + reorder_factor: Optional[str] = None, # "20" + client_refactor: Optional[str] = None, # "true" + vector_field: str = Field.VECTOR.value, + filters: Optional[list[dict]] = None, + filter_type: Optional[str] = None, + **kwargs, +) -> dict: + if filters is not None: + filter_type = "post_filter" if filter_type is None else filter_type + if not isinstance(filter, list): + raise RuntimeError(f"unexpected filter with {type(filters)}") + final_ext = {"lvector": {}} + if min_score != "0.0": + final_ext["lvector"]["min_score"] = min_score + if ef_search: + final_ext["lvector"]["ef_search"] = ef_search + if nprobe: + final_ext["lvector"]["nprobe"] = nprobe + if reorder_factor: + final_ext["lvector"]["reorder_factor"] = reorder_factor + if client_refactor: + final_ext["lvector"]["client_refactor"] = client_refactor + + search_query = { + "size": k, + "_source": True, # force return '_source' + "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}}, + } + + if filters is not None: + # when using filter, transform filter from List[Dict] to Dict as valid format + filters = {"bool": {"must": filters}} if len(filters) > 1 else 
filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + if filter_type: + final_ext["lvector"]["filter_type"] = filter_type + + if final_ext != {"lvector": {}}: + search_query["ext"] = final_ext + return search_query + + +class LindormVectorStoreFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name)) + lindorm_config = LindormVectorStoreConfig( + hosts=dify_config.LINDORM_URL, + username=dify_config.LINDORM_USERNAME, + password=dify_config.LINDORM_PASSWORD, + ) + return LindormVectorStore(collection_name, lindorm_config) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index c8cb007ae8..6d2e04fc02 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -134,6 +134,10 @@ class Vector: from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory return TidbOnQdrantVectorFactory + case VectorType.LINDORM: + from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory + + return LindormVectorStoreFactory case VectorType.OCEANBASE: from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index e3b37ece88..8e53e3ae84 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -16,6 +16,7 @@ class VectorType(str, Enum): TENCENT = "tencent" ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" + LINDORM = "lindorm" COUCHBASE = "couchbase" BAIDU = "baidu" VIKINGDB = "vikingdb" diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index ae3c25125c..d4434ea28f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -14,6 +14,7 @@ import requests from docx import Document as DocxDocument from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor): image_count += 1 if rel.is_external: url = rel.reltype - response = requests.get(url, stream=True) + response = ssrf_proxy.get(url, stream=True) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 94a0c783e1..26c9caaa43 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -5,7 +5,12 @@ from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator -from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope +from core.entities.parameter_entities import ( + AppSelectorScope, + CommonParameterType, + 
ModelSelectorScope,
+    ToolSelectorScope,
+)
 from core.entities.provider_entities import ProviderConfig
 from core.tools.entities.common_entities import I18nObject
@@ -209,6 +214,9 @@ class ToolParameter(BaseModel):
         SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
         FILE = CommonParameterType.FILE.value
         FILES = CommonParameterType.FILES.value
+        APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
+        TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
+        MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
         # deprecated, should not use.
         SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
@@ -258,11 +266,26 @@ class ToolParameter(BaseModel):
                     return float(value)
                 else:
                     return int(value)
+            case ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILES:
+                if not isinstance(value, list):
+                    return [value]
+                return value
+            case ToolParameter.ToolParameterType.FILE:
+                if isinstance(value, list):
+                    if len(value) != 1:
+                        raise ValueError(
+                            "This parameter only accepts one file but got multiple files while invoking."
+                        )
+                    else:
+                        return value[0]
+                return value
             case (
-                ToolParameter.ToolParameterType.SYSTEM_FILES
-                | ToolParameter.ToolParameterType.FILE
-                | ToolParameter.ToolParameterType.FILES
+                ToolParameter.ToolParameterType.TOOL_SELECTOR
+                | ToolParameter.ToolParameterType.MODEL_SELECTOR
+                | ToolParameter.ToolParameterType.APP_SELECTOR
             ):
+                if not isinstance(value, dict):
+                    raise ValueError("The selector must be a dictionary.")
                 return value
             case _:
                 return str(value)
@@ -280,7 +303,7 @@ class ToolParameter(BaseModel):
     human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
     placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
     type: ToolParameterType = Field(..., description="The type of the parameter")
-    scope: AppSelectorScope | ModelConfigScope | None = None
+    scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
     form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
     llm_description: Optional[str] = None
     required: Optional[bool] = False
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index 0131bb342b..7e10cddc71 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
     PARALLEL_START_NODE_ID = "parallel_start_node_id"
     PARENT_PARALLEL_ID = "parent_parallel_id"
     PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
+    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
 class NodeRunResult(BaseModel):
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 86d89e0a32..bacea191dd 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -59,6 +59,8 @@ class BaseNodeEvent(GraphEngineEvent):
 class NodeRunStartedEvent(BaseNodeEvent):
     predecessor_node_id: Optional[str] = None
     """predecessor node id"""
+    parallel_mode_run_id: Optional[str] = None
+    """run id of the parallel-mode iteration, if the node runs inside one"""
@@ -81,6 +83,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")
+
+class NodeInIterationFailedEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+
 ###########################################
 # Parallel Branch Events
 ###########################################
@@ -129,6 +135,8 @@ class BaseIterationEvent(GraphEngineEvent):
     """parent parallel id if node is in parallel"""
     parent_parallel_start_node_id: Optional[str] = None
     """parent parallel start node id if node is in parallel"""
+    parallel_mode_run_id: Optional[str] = None
+    """run id of the iteration when executed in parallel mode"""
 class IterationRunStartedEvent(BaseIterationEvent):
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 8f58af00ef..f07ad4de11 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -4,6 +4,7 @@ import time
 import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
+from copy import copy, deepcopy
 from typing import Any, Optional
 from flask import Flask, current_app
@@ -724,6 +725,16 @@ class GraphEngine:
         """
         return time.perf_counter() - start_at > max_execution_time
+    def create_copy(self):
+        """
+        Create a copy of the graph engine.
+        :return: a copy of the graph engine with its own variable pool instance
+        """
+        new_instance = copy(self)
+        new_instance.graph_runtime_state = copy(self.graph_runtime_state)
+        new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
+        return new_instance
+
 class GraphRunFailedError(Exception):
     def __init__(self, error: str):
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 9d7d9027c3..de70af58dd 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData
 from core.workflow.nodes.enums import NodeType
 from models.workflow import WorkflowNodeExecutionStatus
+from .exc import (
+    CodeNodeError,
+    DepthLimitError,
+    OutputValidationError,
+)
+
 class CodeNode(BaseNode[CodeNodeData]):
     _node_data_cls = CodeNodeData
@@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]):
             # Transform result
             result = self._transform_result(result, self.node_data.outputs)
-        except (CodeExecutionError, ValueError) as e:
+        except (CodeExecutionError, CodeNodeError) as e:
             return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
         return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
@@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]):
             if value is None:
                 return None
             else:
-                raise ValueError(f"Output variable `{variable}` must be a string")
+                raise OutputValidationError(f"Output variable `{variable}` must be a string")
         if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
-            raise ValueError(
+            raise OutputValidationError(
                 f"The length of output variable `{variable}` must be"
                 f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
             )
@@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]):
             if value is None:
                 return None
             else:
-                raise ValueError(f"Output variable `{variable}` must be a number")
+                raise OutputValidationError(f"Output variable `{variable}` must be a number")
         if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
-            raise ValueError(
+            raise OutputValidationError(
                 f"Output variable `{variable}` is out of range,"
                 f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
             )
@@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]):
         if isinstance(value, float):
             # raise error if precision is too high
             if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
-                raise ValueError(
+                raise OutputValidationError(
                     f"Output variable `{variable}` has too high precision,"
                     f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
                 )
@@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]):
         :return:
         """
         if depth > dify_config.CODE_MAX_DEPTH:
-            raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
+            raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
         transformed_result = {}
         if output_schema is None:
@@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]):
                         depth=depth + 1,
                     )
                 else:
-                    raise ValueError(
+                    raise OutputValidationError(
                         f"Output {prefix}.{output_name} is not a valid array."
                         f" make sure all elements are of the same type."
                     )
             elif output_value is None:
                 pass
             else:
-                raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
+                raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
         return result
@@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]):
         for output_name, output_config in output_schema.items():
             dot = "." if prefix else ""
             if output_name not in result:
-                raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
+                raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
             if output_config.type == "object":
                 # check if output is object
@@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]):
                     if isinstance(result.get(output_name), type(None)):
                         transformed_result[output_name] = None
                     else:
-                        raise ValueError(
+                        raise OutputValidationError(
                             f"Output {prefix}{dot}{output_name} is not an object,"
                             f" got {type(result.get(output_name))} instead."
                         )
@@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]):
                    if isinstance(result[output_name], type(None)):
                         transformed_result[output_name] = None
                     else:
-                        raise ValueError(
+                        raise OutputValidationError(
                             f"Output {prefix}{dot}{output_name} is not an array,"
                             f" got {type(result.get(output_name))} instead."
                         )
                 else:
                     if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
-                        raise ValueError(
+                        raise OutputValidationError(
                             f"The length of output variable `{prefix}{dot}{output_name}` must be"
                             f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
                         )
@@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]):
                     if isinstance(result[output_name], type(None)):
                         transformed_result[output_name] = None
                     else:
-                        raise ValueError(
+                        raise OutputValidationError(
                             f"Output {prefix}{dot}{output_name} is not an array,"
                             f" got {type(result.get(output_name))} instead."
                         )
                 else:
                     if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
-                        raise ValueError(
+                        raise OutputValidationError(
                             f"The length of output variable `{prefix}{dot}{output_name}` must be"
                             f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
                         )
@@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]):
                     if isinstance(result[output_name], type(None)):
                         transformed_result[output_name] = None
                     else:
-                        raise ValueError(
+                        raise OutputValidationError(
                             f"Output {prefix}{dot}{output_name} is not an array,"
                             f" got {type(result.get(output_name))} instead."
) else: if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) @@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: pass else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name}[{i}] is not an object," f" got {type(value)} instead at index {i}." ) @@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]): for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f"Output type {output_config.type} is not supported.") + raise OutputValidationError(f"Output type {output_config.type} is not supported.") parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError("Not all output parameters are validated.") + raise CodeNodeError("Not all output parameters are validated.") return transformed_result diff --git a/api/core/workflow/nodes/code/exc.py b/api/core/workflow/nodes/code/exc.py new file mode 100644 index 0000000000..d6334fd554 --- /dev/null +++ b/api/core/workflow/nodes/code/exc.py @@ -0,0 +1,16 @@ +class CodeNodeError(ValueError): + """Base class for code node errors.""" + + pass + + +class OutputValidationError(CodeNodeError): + """Raised when there is an output validation error.""" + + pass + + +class DepthLimitError(CodeNodeError): + """Raised when the depth limit is reached.""" + + pass diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/core/workflow/nodes/document_extractor/exc.py index c9d4bb8ef6..5caf00ebc5 100644 --- a/api/core/workflow/nodes/document_extractor/exc.py +++ b/api/core/workflow/nodes/document_extractor/exc.py @@ -1,4 +1,4 @@ -class DocumentExtractorError(Exception): +class DocumentExtractorError(ValueError): """Base exception for errors related to the DocumentExtractorNode.""" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index c2f51ad1e5..c90017d5e1 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -6,12 +6,14 @@ import docx import pandas as pd import pypdfium2 import yaml +from unstructured.partition.api import partition_via_api from unstructured.partition.email import partition_email from unstructured.partition.epub import partition_epub from unstructured.partition.msg import partition_msg from unstructured.partition.ppt import partition_ppt from unstructured.partition.pptx import partition_pptx +from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment @@ -196,10 +198,8 @@ def _download_file_content(file: File) -> bytes: response = ssrf_proxy.get(file.remote_url) response.raise_for_status() return response.content - elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - return file_manager.download(file) else: - raise ValueError(f"Unsupported transfer method: {file.transfer_method}") + return file_manager.download(file) except Exception as e: raise FileDownloadError(f"Error downloading file: {str(e)}") from e @@ -263,7 +263,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str: def _extract_text_from_pptx(file_content: bytes) -> str: try: with io.BytesIO(file_content) as file: - elements = 
partition_pptx(file=file) + if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: + elements = partition_via_api( + file=file, + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY, + ) + else: + elements = partition_pptx(file=file) return "\n".join([getattr(element, "text", "") for element in elements]) except Exception as e: raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py new file mode 100644 index 0000000000..7a5ab7dbc1 --- /dev/null +++ b/api/core/workflow/nodes/http_request/exc.py @@ -0,0 +1,18 @@ +class HttpRequestNodeError(ValueError): + """Custom error for HTTP request node.""" + + +class AuthorizationConfigError(HttpRequestNodeError): + """Raised when authorization config is missing or invalid.""" + + +class FileFetchError(HttpRequestNodeError): + """Raised when a file cannot be fetched.""" + + +class InvalidHttpMethodError(HttpRequestNodeError): + """Raised when an invalid HTTP method is used.""" + + +class ResponseSizeError(HttpRequestNodeError): + """Raised when the response size exceeds the allowed threshold.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 6872478299..d90dfcc766 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -18,6 +18,12 @@ from .entities import ( HttpRequestNodeTimeout, Response, ) +from .exc import ( + AuthorizationConfigError, + FileFetchError, + InvalidHttpMethodError, + ResponseSizeError, +) BODY_TYPE_TO_CONTENT_TYPE = { "json": "application/json", @@ -51,7 +57,7 @@ class Executor: # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": if node_data.authorization.config is None: - raise ValueError("authorization config is required") + raise AuthorizationConfigError("authorization config is required") node_data.authorization.config.api_key = variable_pool.convert_template( node_data.authorization.config.api_key ).text @@ -82,8 +88,10 @@ class Executor: self.url = self.variable_pool.convert_template(self.node_data.url).text def _init_params(self): - params = self.variable_pool.convert_template(self.node_data.params).text - self.params = _plain_text_to_dict(params) + params = _plain_text_to_dict(self.node_data.params) + for key in params: + params[key] = self.variable_pool.convert_template(params[key]).text + self.params = params def _init_headers(self): headers = self.variable_pool.convert_template(self.node_data.headers).text @@ -116,7 +124,7 @@ class Executor: file_selector = data[0].file file_variable = self.variable_pool.get_file(file_selector) if file_variable is None: - raise ValueError(f"cannot fetch file with selector {file_selector}") + raise FileFetchError(f"cannot fetch file with selector {file_selector}") file = file_variable.value self.content = file_manager.download(file) case "x-www-form-urlencoded": @@ -155,12 +163,12 @@ class Executor: headers = deepcopy(self.headers) or {} if self.auth.type == "api-key": if self.auth.config is None: - raise ValueError("self.authorization config is required") + raise AuthorizationConfigError("self.authorization config is required") if authorization.config is None: - raise ValueError("authorization config is required") + raise AuthorizationConfigError("authorization config is required") if 
self.auth.config.api_key is None: - raise ValueError("api_key is required") + raise AuthorizationConfigError("api_key is required") if not authorization.config.header: authorization.config.header = "Authorization" @@ -183,7 +191,7 @@ class Executor: else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE ) if executor_response.size > threshold_size: - raise ValueError( + raise ResponseSizeError( f'{"File" if executor_response.is_file else "Text"} size is too large,' f' max size is {threshold_size / 1024 / 1024:.2f} MB,' f' but current size is {executor_response.readable_size}.' @@ -196,7 +204,7 @@ class Executor: do http request depending on api bundle """ if self.method not in {"get", "head", "post", "put", "delete", "patch"}: - raise ValueError(f"Invalid http method {self.method}") + raise InvalidHttpMethodError(f"Invalid http method {self.method}") request_args = { "url": self.url, diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index a037bee665..61c661e587 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -20,6 +20,7 @@ from .entities import ( HttpRequestNodeTimeout, Response, ) +from .exc import HttpRequestNodeError HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): "request": http_executor.to_log(), }, ) - except Exception as e: + except HttpRequestNodeError as e: logger.warning(f"http request node {self.node_id} failed to run: {e}") return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 4afc870e50..ebcb6f82fb 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Optional from pydantic import Field @@ -5,6 +6,12 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +class ErrorHandleMode(str, Enum): + TERMINATED = "terminated" + CONTINUE_ON_ERROR = "continue-on-error" + REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" + + class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. 
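The hunk above introduces the three `ErrorHandleMode` values; their actual behavior is implemented in `iteration_node.py` further down. As a minimal sketch of how the modes plausibly differ when folding per-iteration results into one output list (the `aggregate_outputs` helper is hypothetical and only illustrates the semantics suggested by the mode names; it is not part of this diff):

```python
from enum import Enum


class ErrorHandleMode(str, Enum):
    TERMINATED = "terminated"
    CONTINUE_ON_ERROR = "continue-on-error"
    REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"


def aggregate_outputs(results: list[tuple[bool, object]], mode: ErrorHandleMode) -> list[object]:
    """Illustrative only: fold (ok, value) pairs from each iteration into one output list."""
    outputs: list[object] = []
    for ok, value in results:
        if ok:
            outputs.append(value)
        elif mode is ErrorHandleMode.TERMINATED:
            # abort the whole iteration on the first failed item
            raise RuntimeError("iteration terminated on first error")
        elif mode is ErrorHandleMode.CONTINUE_ON_ERROR:
            # keep a placeholder so output indexes stay aligned with inputs
            outputs.append(None)
        # REMOVE_ABNORMAL_OUTPUT: silently drop the failed item
    return outputs


results = [(True, 1), (False, None), (True, 3)]
print(aggregate_outputs(results, ErrorHandleMode.CONTINUE_ON_ERROR))       # [1, None, 3]
print(aggregate_outputs(results, ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT))  # [1, 3]
```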
@@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
     parent_loop_id: Optional[str] = None  # redundant field, not used currently
     iterator_selector: list[str]  # variable selector
     output_selector: list[str]  # output selector
+    is_parallel: bool = False  # whether to run iterations in parallel
+    parallel_nums: int = 10  # number of parallel workers
+    error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED  # how to handle errors during iteration
 
 
 class IterationStartNodeData(BaseNodeData):
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index af79da9215..d121b0530a 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,12 +1,20 @@
 import logging
+import uuid
 from collections.abc import Generator, Mapping, Sequence
+from concurrent.futures import Future, wait
 from datetime import datetime, timezone
-from typing import Any, cast
+from queue import Empty, Queue
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+from flask import Flask, current_app
 
 from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.variables import IntegerSegment
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.node_entities import (
+    NodeRunMetadataKey,
+    NodeRunResult,
+)
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     BaseGraphEvent,
     BaseNodeEvent,
@@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
     IterationRunNextEvent,
     IterationRunStartedEvent,
     IterationRunSucceededEvent,
+    NodeInIterationFailedEvent,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
 )
@@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
-from core.workflow.nodes.iteration.entities import IterationNodeData
+from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from models.workflow import WorkflowNodeExecutionStatus
 
+if TYPE_CHECKING:
+    from core.workflow.graph_engine.graph_engine import GraphEngine
 
 logger = logging.getLogger(__name__)
 
@@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
     _node_data_cls = IterationNodeData
     _node_type = NodeType.ITERATION
 
+    @classmethod
+    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+        return {
+            "type": "iteration",
+            "config": {
+                "is_parallel": False,
+                "parallel_nums": 10,
+                "error_handle_mode": ErrorHandleMode.TERMINATED.value,
+            },
+        }
+
     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
         """
         Run the node.
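The `_run` rewrite in the next hunk fans iterations out over a `GraphEngineThreadPool` and drains events from a shared `Queue`, using `None` as a stop sentinel once every iteration has reported an `IterationRunNextEvent`. A distilled, runnable sketch of that producer/consumer shape (`fan_out`, `worker`, and the `"next"` marker are my stand-ins for the patch's real event types, not its API):

```python
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue


def fan_out(items, worker, max_workers: int = 10):
    """Run worker(index, item, q) on a pool and stream queue events until a None sentinel."""
    q: Queue = Queue()
    finished = 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(worker, i, item, q) for i, item in enumerate(items)]
        while True:
            try:
                event = q.get(timeout=1)
            except Empty:
                continue  # poll again; mirrors the patch's timeout-and-retry loop
            if event is None:
                break
            yield event
            if event == "next":  # hypothetical per-iteration completion marker
                finished += 1
                if finished == len(futures):
                    q.put(None)  # every iteration reported back: stop draining


def worker(i, item, q):
    q.put(f"result:{item}")
    q.put("next")


print(list(fan_out(["a", "b", "c"], worker, max_workers=2)))
```

The `timeout=1` / `Empty` polling keeps the consumer responsive even if a worker stalls, which is why the patch uses the same idiom instead of a blocking `get()`.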
@@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine - from core.workflow.graph_engine.graph_engine import GraphEngine + from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool graph_engine = GraphEngine( tenant_id=self.tenant_id, @@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]): index=0, pre_iteration_output=None, ) - outputs: list[Any] = [] try: - for _ in range(len(iterator_list_value)): - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id - - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): + if self.node_data.is_parallel: + futures: list[Future] = [] + q = Queue() + thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + for index, item in enumerate(iterator_list_value): + future: Future = thread_pool.submit( + self._run_single_iter_parallel, + current_app._get_current_object(), + q, + iterator_list_value, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, + index, + item, + ) + future.add_done_callback(thread_pool.task_done_callback) + futures.append(future) + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + if isinstance(event, IterationRunNextEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + yield event + if isinstance(event, RunCompletedEvent): + q.put(None) + for f in futures: + if not f.done(): + f.cancel() + yield event + if isinstance(event, IterationRunFailedEvent): + q.put(None) + yield event + except Empty: continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} - - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(index_variable, IntegerSegment): - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Invalid index variable type: {type(index_variable)}", - ) - ) - return - metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value - event.route_node_state.node_run_result.metadata = metadata - - yield event - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=event.error, - ) - ) - return - else: - event = cast(InNodeEvent, event) - yield event - - # append to iteration output variable list - current_iteration_output_variable = variable_pool.get(self.node_data.output_selector) - if current_iteration_output_variable is None: - yield 
RunCompletedEvent(
-                        run_result=NodeRunResult(
-                            status=WorkflowNodeExecutionStatus.FAILED,
-                            error=f"Iteration output variable {self.node_data.output_selector} not found",
-                        )
+                # wait for all threads to finish
+                wait(futures)
+            else:
+                for _ in range(len(iterator_list_value)):
+                    yield from self._run_single_iter(
+                        iterator_list_value,
+                        variable_pool,
+                        inputs,
+                        outputs,
+                        start_at,
+                        graph_engine,
+                        iteration_graph,
                     )
-                    return
-                current_iteration_output = current_iteration_output_variable.to_object()
-                outputs.append(current_iteration_output)
-
-                # remove all nodes outputs from variable pool
-                for node_id in iteration_graph.node_ids:
-                    variable_pool.remove([node_id])
-
-                # move to next iteration
-                current_index_variable = variable_pool.get([self.node_id, "index"])
-                if not isinstance(current_index_variable, IntegerSegment):
-                    raise ValueError(f"iteration {self.node_id} current index not found")
-
-                next_index = current_index_variable.value + 1
-                variable_pool.add([self.node_id, "index"], next_index)
-
-                if next_index < len(iterator_list_value):
-                    variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
-
-                yield IterationRunNextEvent(
-                    iteration_id=self.id,
-                    iteration_node_id=self.node_id,
-                    iteration_node_type=self.node_type,
-                    iteration_node_data=self.node_data,
-                    index=next_index,
-                    pre_iteration_output=jsonable_encoder(current_iteration_output),
-                )
-
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
@@ -330,3 +310,231 @@
         }
 
         return variable_mapping
+
+    def _handle_event_metadata(
+        self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
+    ) -> NodeRunStartedEvent | BaseNodeEvent:
+        """
+        add iteration metadata to event.
+        """
+        if not isinstance(event, BaseNodeEvent):
+            return event
+        if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
+            event.parallel_mode_run_id = parallel_mode_run_id
+            return event
+        if event.route_node_state.node_run_result:
+            metadata = event.route_node_state.node_run_result.metadata
+            if not metadata:
+                metadata = {}
+
+            if NodeRunMetadataKey.ITERATION_ID not in metadata:
+                metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
+                if self.node_data.is_parallel:
+                    metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
+                else:
+                    metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
+                event.route_node_state.node_run_result.metadata = metadata
+        return event
+
+    def _run_single_iter(
+        self,
+        iterator_list_value: list[str],
+        variable_pool: VariablePool,
+        inputs: dict[str, list],
+        outputs: list,
+        start_at: datetime,
+        graph_engine: "GraphEngine",
+        iteration_graph: Graph,
+        parallel_mode_run_id: Optional[str] = None,
+    ) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        """
+        run a single iteration
+        """
+        try:
+            rst = graph_engine.run()
+            # get current iteration index
+            current_index = variable_pool.get([self.node_id, "index"]).value
+            if current_index is None:
+                raise ValueError(f"iteration {self.node_id} current index not found")
+            next_index = int(current_index) + 1
+
+            for event in rst:
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
+                    event.in_iteration_id = self.node_id
+
+                if (
+                    isinstance(event, BaseNodeEvent)
+                    and event.node_type == NodeType.ITERATION_START
+                    and not isinstance(event, NodeRunStreamChunkEvent)
+                ):
+                    continue
+
+                if isinstance(event, NodeRunSucceededEvent):
+                    yield self._handle_event_metadata(event, current_index,
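                        # parallel_mode_run_id is a per-iteration uuid in parallel mode; it lets
                        # downstream consumers group events belonging to the same iteration run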
parallel_mode_run_id) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + if isinstance(event, NodeRunFailedEvent): + if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + outputs.insert(current_index, None) + variable_pool.add([self.node_id, "index"], next_index) + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield metadata_event + + current_iteration_output = variable_pool.get(self.node_data.output_selector).value + outputs.insert(current_index, current_iteration_output) + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # move to next iteration + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, 
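                # one IterationRunNextEvent per completed index keeps streamed progress in step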
+ iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + ) + + except Exception as e: + logger.exception(f"Iteration run failed:{str(e)}") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + + def _run_single_iter_parallel( + self, + flask_app: Flask, + q: Queue, + iterator_list_value: list[str], + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + index: int, + item: Any, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration in parallel mode + """ + with flask_app.app_context(): + parallel_mode_run_id = uuid.uuid4().hex + graph_engine_copy = graph_engine.create_copy() + variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool + variable_pool_copy.add([self.node_id, "index"], index) + variable_pool_copy.add([self.node_id, "item"], item) + for event in self._run_single_iter( + iterator_list_value=iterator_list_value, + variable_pool=variable_pool_copy, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine_copy, + iteration_graph=iteration_graph, + parallel_mode_run_id=parallel_mode_run_id, + ): + q.put(event) diff --git a/api/core/workflow/nodes/list_operator/exc.py b/api/core/workflow/nodes/list_operator/exc.py new file mode 100644 index 0000000000..f88aa0be29 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/exc.py @@ -0,0 +1,16 @@ +class ListOperatorError(ValueError): + """Base class for all ListOperator errors.""" + + pass + + +class InvalidFilterValueError(ListOperatorError): + pass + + +class InvalidKeyError(ListOperatorError): + pass + + +class InvalidConditionError(ListOperatorError): + pass diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index d7e4c64313..49e7ca85fd 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal +from typing import Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus from .entities import ListOperatorNodeData +from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError class ListOperatorNode(BaseNode[ListOperatorNodeData]): @@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) - if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): + if not variable.value: + inputs = {"variable": []} + process_data = {"variable": []} + outputs = 
{"result": [], "first_record": None, "last_record": None} + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" @@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) if isinstance(variable, ArrayFileSegment): + inputs = {"variable": [item.to_dict() for item in variable.value]} process_data["variable"] = [item.to_dict() for item in variable.value] else: + inputs = {"variable": variable.value} process_data["variable"] = variable.value - # Filter - if self.node_data.filter_by.enabled: - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise ValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise ValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) + try: + # Filter + if self.node_data.filter_by.enabled: + variable = self._apply_filter(variable) - # Order - if self.node_data.order_by.enabled: + # Order + if self.node_data.order_by.enabled: + variable = self._apply_order(variable) + + # Slice + if self.node_data.limit.enabled: + variable = self._apply_slice(variable) + + outputs = { + "result": variable.value, + "first_record": variable.value[0] if variable.value else None, + "last_record": variable.value[-1] if variable.value else None, + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + except ListOperatorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + + def _apply_filter( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = 
self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + if isinstance(condition.value, str): + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + else: + value = condition.value + filter_func = _get_file_filter_func( + key=condition.key, + condition=condition.comparison_operator, + value=value, ) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) + return variable - # Slice - if self.node_data.limit.enabled: - result = variable.value[: self.node_data.limit.size] + def _apply_order( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + if isinstance(variable, ArrayStringSegment): + result = _order_string(order=self.node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + result = _order_number(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + result = _order_file( + order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + ) + variable = variable.model_copy(update={"value": result}) + return variable - outputs = { - "result": variable.value, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) + def _apply_slice( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + result = variable.value[: self.node_data.limit.size] + return variable.model_copy(update={"value": result}) def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: @@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: case "size": return lambda x: x.size case _: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: @@ -118,14 +157,14 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: return lambda x: x.type case "extension": return lambda x: x.extension or "" - case "mimetype": + case "mime_type": return lambda x: x.mime_type or "" case "transfer_method": 
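                # each case maps a File attribute to a plain-string extractor; note the
                # `mimetype` -> `mime_type` key fix just above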
return lambda x: x.transfer_method case "url": return lambda x: x.remote_url or "" case _: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: @@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo case "not empty": return lambda x: x != "" case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: @@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab case "not in": return lambda x: not _in(value)(x) case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: @@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ case "≥": return _ge(value) case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: @@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str extract_func = _get_file_extract_number_func(key=key) return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) else: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _contains(value: str): @@ -256,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq extract_func = _get_file_extract_number_func(key=order_by) return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") else: - raise ValueError(f"Invalid order key: {order_by}") + raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py new file mode 100644 index 0000000000..f858be2515 --- /dev/null +++ b/api/core/workflow/nodes/llm/exc.py @@ -0,0 +1,26 @@ +class LLMNodeError(ValueError): + """Base class for LLM Node errors.""" + + +class VariableNotFoundError(LLMNodeError): + """Raised when a required variable is not found.""" + + +class InvalidContextStructureError(LLMNodeError): + """Raised when the context structure is invalid.""" + + +class InvalidVariableTypeError(LLMNodeError): + """Raised when the variable type is invalid.""" + + +class ModelNotExistError(LLMNodeError): + """Raised when the specified model does not exist.""" + + +class LLMModeRequiredError(LLMNodeError): + """Raised when LLM mode is required but not provided.""" + + +class NoPromptFoundError(LLMNodeError): + """Raised when no prompt is found in the LLM configuration.""" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b4728e6abf..47b0e25d9c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -56,6 +56,15 @@ from .entities import ( LLMNodeData, ModelConfig, ) +from .exc import ( + InvalidContextStructureError, + InvalidVariableTypeError, + LLMModeRequiredError, + LLMNodeError, + ModelNotExistError, + NoPromptFoundError, + VariableNotFoundError, +) if TYPE_CHECKING: from core.file.models import File @@ -103,7 
+112,7 @@ class LLMNode(BaseNode[LLMNodeData]): yield event if context: - node_inputs["#context#"] = context # type: ignore + node_inputs["#context#"] = context # fetch model config model_instance, model_config = self._fetch_model_config(self.node_data.model) @@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]): if self.node_data.memory: query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) if not query: - raise ValueError("Query not found") + raise VariableNotFoundError("Query not found") query = query.text else: query = None @@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]): usage = event.usage finish_reason = event.finish_reason break - except Exception as e: + except LLMNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") def parse_dict(input_dict: Mapping[str, Any]) -> str: """ @@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): inputs[variable_selector.variable] = "" inputs[variable_selector.variable] = variable.to_object() @@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): continue inputs[variable_selector.variable] = variable.to_object() @@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]): return variable.value elif isinstance(variable, NoneSegment | ArrayAnySegment): return [] - raise ValueError(f"Invalid variable type: {type(variable)}") + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: @@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]): context_str += item + "\n" else: if "content" not in item: - raise ValueError(f"Invalid context structure: {item}") + raise InvalidContextStructureError(f"Invalid context structure: {item}") context_str += item["content"] + "\n" @@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") @@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]): # get model mode model_mode = node_data_model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise LLMModeRequiredError("LLM mode is required.") model_schema = model_type_instance.get_model_schema(model_name, 
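            # a missing schema now surfaces as ModelNotExistError rather than a bare ValueError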
model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, @@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]): filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError( + raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) @@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() else: - raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") + raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py new file mode 100644 index 0000000000..6511aba185 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -0,0 +1,50 @@ +class ParameterExtractorNodeError(ValueError): + """Base error for ParameterExtractorNode.""" + + +class InvalidModelTypeError(ParameterExtractorNodeError): + """Raised when the model is not a Large Language Model.""" + + +class ModelSchemaNotFoundError(ParameterExtractorNodeError): + """Raised when the model schema is not found.""" + + +class InvalidInvokeResultError(ParameterExtractorNodeError): + """Raised when the invoke result is invalid.""" + + +class InvalidTextContentTypeError(ParameterExtractorNodeError): + """Raised when the text content type is invalid.""" + + +class InvalidNumberOfParametersError(ParameterExtractorNodeError): + """Raised when the number of parameters is invalid.""" + + +class RequiredParameterMissingError(ParameterExtractorNodeError): + """Raised when a required parameter is missing.""" + + +class InvalidSelectValueError(ParameterExtractorNodeError): + """Raised when a select value is invalid.""" + + +class InvalidNumberValueError(ParameterExtractorNodeError): + """Raised when a number value is invalid.""" + + +class InvalidBoolValueError(ParameterExtractorNodeError): + """Raised when a bool value is invalid.""" + + +class InvalidStringValueError(ParameterExtractorNodeError): + """Raised when a string value is invalid.""" + + +class InvalidArrayValueError(ParameterExtractorNodeError): + """Raised when an array value is invalid.""" + + +class InvalidModelModeError(ParameterExtractorNodeError): + """Raised when the model mode is invalid.""" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 49546e9356..b64bde8ac5 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -32,6 +32,21 @@ from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus from .entities import ParameterExtractorNodeData +from .exc import ( + InvalidArrayValueError, + InvalidBoolValueError, + InvalidInvokeResultError, + InvalidModelModeError, + InvalidModelTypeError, + InvalidNumberOfParametersError, + InvalidNumberValueError, + InvalidSelectValueError, + InvalidStringValueError, + InvalidTextContentTypeError, + 
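    # every class here subclasses ParameterExtractorNodeError (defined in exc.py above),
    # so callers can catch the whole family with a single except clause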
ModelSchemaNotFoundError, + ParameterExtractorNodeError, + RequiredParameterMissingError, +) from .prompts import ( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, @@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema( @@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode): credentials=model_config.credentials, ) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") # fetch memory memory = self._fetch_memory( @@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode): process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) process_data["llm_text"] = text - except Exception as e: + except ParameterExtractorNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, @@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode): try: result = self._validate_result(data=node_data, result=result or {}) - except Exception as e: + except ParameterExtractorNodeError as e: error = str(e) # transform result into standard format @@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode): # handle invoke result if not isinstance(invoke_result, LLMResult): - raise ValueError(f"Invalid invoke result: {invoke_result}") + raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content if not isinstance(text, str): - raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") + raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode): files=files, ) else: - raise ValueError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {model_mode}") def _generate_prompt_engineering_completion_prompt( self, @@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode): Validate result. 
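        Checks that the extracted dict has exactly the declared parameters, that every
        required parameter is present, and that each value matches its declared type
        (select / number / bool / string / array[...]).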
""" if len(data.parameters) != len(result): - raise ValueError("Invalid number of parameters") + raise InvalidNumberOfParametersError("Invalid number of parameters") for parameter in data.parameters: if parameter.required and parameter.name not in result: - raise ValueError(f"Parameter {parameter.name} is required") + raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: - raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): - raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): - raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") if parameter.type == "string" and not isinstance(result.get(parameter.name), str): - raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") if parameter.type.startswith("array"): parameters = result.get(parameter.name) if not isinstance(parameters, list): - raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in parameters: if nested_type == "number" and not isinstance(item, int | float): - raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") if nested_type == "string" and not isinstance(item, str): - raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") if nested_type == "object" and not isinstance(item, dict): - raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: @@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode): user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode): .replace("}γγγ", "") ) else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _calculate_rest_token( self, @@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language 
Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index d0c8c7e84f..0191102b90 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -91,6 +91,8 @@ def build_segment(value: Any, /) -> Segment: return ArrayObjectSegment(value=value) case SegmentType.FILE: return ArrayFileSegment(value=value) + case SegmentType.NONE: + return ArrayAnySegment(value=value) case _: raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 1cddc24b2c..afaacc0568 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -8,6 +8,7 @@ upload_config_fields = { "image_file_size_limit": fields.Integer, "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, } file_fields = { diff --git a/api/models/__init__.py b/api/models/__init__.py index 1d8bae6cfa..cd6c7674da 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -6,7 +6,6 @@ from .model import ( AppMode, Conversation, EndUser, - FileUploadConfig, InstalledApp, Message, MessageAnnotation, @@ -50,6 +49,5 @@ __all__ = [ "Tenant", "Conversation", "MessageAnnotation", - "FileUploadConfig", "ToolFile", ] diff --git a/api/models/model.py b/api/models/model.py index 8a619d3f30..0b3793cc62 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -121,7 +121,7 @@ class App(Base): return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self): if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() @@ -1320,7 +1320,7 @@ class Site(Base): privacy_policy = db.Column(db.String(255)) show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") + _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1331,6 +1331,16 @@ class Site(Base): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = db.Column(db.String(255)) + @property + def custom_disclaimer(self): + return self._custom_disclaimer + + @custom_disclaimer.setter + def custom_disclaimer(self, value: str): + if len(value) > 512: + raise ValueError("Custom disclaimer cannot exceed 512 characters.") + self._custom_disclaimer = value + @staticmethod def generate_code(n): while True: diff --git a/api/models/workflow.py b/api/models/workflow.py index bc4434ae5a..1ad94d163b 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,6 @@ import json from 
collections.abc import Mapping, Sequence -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union @@ -111,7 +111,9 @@ class Workflow(Base): db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp() + ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" ) diff --git a/api/poetry.lock b/api/poetry.lock index d7af124794..a697c02c45 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2532,6 +2532,19 @@ files = [ {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, ] +[[package]] +name = "fire" +version = "0.7.0" +description = "A library for automatically generating command line interfaces." +optional = false +python-versions = "*" +files = [ + {file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"}, +] + +[package.dependencies] +termcolor = "*" + [[package]] name = "flasgger" version = "0.9.7.1" @@ -2697,6 +2710,19 @@ files = [ {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, ] +[[package]] +name = "fontmeta" +version = "1.6.1" +description = "An Utility to get ttf/otf font metadata" +optional = false +python-versions = "*" +files = [ + {file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"}, +] + +[package.dependencies] +fonttools = "*" + [[package]] name = "fonttools" version = "4.54.1" @@ -5279,6 +5305,22 @@ files = [ {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, ] +[[package]] +name = "mplfonts" +version = "0.0.8" +description = "Fonts manager for matplotlib" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"}, + {file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"}, +] + +[package.dependencies] +fire = ">=0.4.0" +fontmeta = ">=1.6.1" +matplotlib = ">=3.4" + [[package]] name = "mpmath" version = "1.3.0" @@ -9300,6 +9342,20 @@ files = [ [package.dependencies] tencentcloud-sdk-python-common = "3.0.1257" +[[package]] +name = "termcolor" +version = "2.5.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.9" +files = [ + {file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"}, + {file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "threadpoolctl" version = "3.5.0" @@ -10046,13 +10102,13 @@ files = [ [[package]] name = "vanna" -version = "0.7.3" +version = "0.7.5" description = "Generate SQL queries from natural language" optional = false python-versions = ">=3.9" files = [ - {file = "vanna-0.7.3-py3-none-any.whl", hash = 
"sha256:82ba39e5d6c503d1c8cca60835ed401d20ec3a3da98d487f529901dcb30061d6"}, - {file = "vanna-0.7.3.tar.gz", hash = "sha256:4590dd94d2fe180b4efc7a83c867b73144ef58794018910dc226857cfb703077"}, + {file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"}, + {file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"}, ] [package.dependencies] @@ -10073,7 +10129,7 @@ sqlparse = "*" tabulate = "*" [package.extras] -all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "zhipuai"] +all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "xinference-client", "zhipuai"] anthropic = ["anthropic"] azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"] bedrock = ["boto3", "botocore"] @@ -10081,6 +10137,8 @@ bigquery = ["google-cloud-bigquery"] chromadb = ["chromadb"] clickhouse = ["clickhouse_connect"] duckdb = ["duckdb"] +faiss-cpu = ["faiss-cpu"] +faiss-gpu = ["faiss-gpu"] gemini = ["google-generativeai"] google = ["google-cloud-aiplatform", "google-generativeai"] hf = ["transformers"] @@ -10091,6 +10149,7 @@ mysql = ["PyMySQL"] ollama = ["httpx", "ollama"] openai = ["openai"] opensearch = ["opensearch-dsl", "opensearch-py"] +pgvector = ["langchain-postgres (>=0.0.12)"] pinecone = ["fastembed", "pinecone-client"] postgres = ["db-dtypes", "psycopg2-binary"] qdrant = ["fastembed", "qdrant-client"] @@ -10099,6 +10158,7 @@ snowflake = ["snowflake-connector-python"] test = ["tox"] vllm = ["vllm"] weaviate = ["weaviate-client"] +xinference-client = ["xinference-client"] zhipuai = ["zhipuai"] [[package]] @@ -10940,4 +11000,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b" +content-hash = "e4794898403da4ad7b51f248a6c07632a949114c1b569406d3aa6a94c62510a5" diff --git a/api/pyproject.toml b/api/pyproject.toml index ef3dc14131..e42ca1b5a4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -206,13 +206,14 @@ cloudscraper = "1.2.71" duckduckgo-search = "~6.3.0" jsonpath-ng = "1.6.1" matplotlib = "~3.8.2" +mplfonts = "~0.0.8" newspaper3k = "0.2.8" nltk = "3.9.1" numexpr = "~2.9.0" pydub = "~0.25.1" qrcode = "~7.4.2" twilio = "~9.0.4" -vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } +vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] } wikipedia = "1.4.0" yfinance = "~0.2.40" diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 
67d0706828..9efe120b7a 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -14,7 +14,7 @@ from models.dataset import Embedding @app.celery.task(queue="dataset") def clean_embedding_cache_task(): click.echo(click.style("Start clean embedding cache.", fg="green")) - clean_days = int(dify_config.CLEAN_DAY_SETTING) + clean_days = int(dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac05cbc4f5..50da547fd8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -986,9 +986,6 @@ class DocumentService: raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") - # update document name - if document_data.get("name"): - document.name = document_data["name"] # save process rule if document_data.get("process_rule"): process_rule = document_data["process_rule"] @@ -1065,6 +1062,10 @@ class DocumentService: document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name + + # update document name + if document_data.get("name"): + document.name = document_data["name"] # update document to be waiting document.indexing_status = "waiting" document.completed_at = None diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 35cc59b22e..3acae74db6 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -62,6 +62,37 @@ class BuiltinToolManageService: return result + @staticmethod + def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): + """ + get builtin tool provider info + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + tool_provider_configurations = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + # check if user has added the provider + builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) + + credentials = {} + if builtin_provider is not None: + # get credentials + credentials = builtin_provider.credentials + credentials = tool_provider_configurations.decrypt(credentials) + + entity = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=builtin_provider, + decrypt_credentials=True, + ) + + entity.original_credentials = {} + + return entity + @staticmethod def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): """ @@ -255,6 +286,7 @@ class BuiltinToolManageService: @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: + full_provider_name = provider_name provider_id_entity = ToolProviderID(provider_name) provider_name = provider_id_entity.provider_name if provider_id_entity.organization != "langgenius": @@ -264,7 +296,8 @@ class BuiltinToolManageService: db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == 
provider_name) | (BuiltinToolProvider.provider == provider_name), + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name), ) .first() ) diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index fa4a2eb36c..95cca83b44 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -96,5 +96,13 @@ VESSL_AI_MODEL_NAME= VESSL_AI_API_KEY= VESSL_AI_ENDPOINT_URL= +# GPUStack Credentials +GPUSTACK_SERVER_URL= +GPUSTACK_API_KEY= + # Gitee AI Credentials GITEE_AI_API_KEY= + +# xAI Credentials +XAI_API_KEY= +XAI_API_BASE= diff --git a/api/tests/integration_tests/model_runtime/gpustack/__init__.py b/api/tests/integration_tests/model_runtime/gpustack/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py new file mode 100644 index 0000000000..f56ad0dadc --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py @@ -0,0 +1,49 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import ( + GPUStackTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = GPUStackTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-m3", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="bge-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GPUStackTextEmbeddingModel() + + result = model.invoke( + model="bge-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "context_size": 8192, + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_llm.py b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py new file mode 100644 index 0000000000..326b7b16f0 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py @@ -0,0 +1,162 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = GPUStackLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + "mode": "chat", + }, + ) + + model.validate_credentials( + model="llama-3.2-1b-instruct", + credentials={ + 
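            # GPUSTACK_SERVER_URL and GPUSTACK_API_KEY are read from tests/integration_tests/.env
            # (see the .env.example additions earlier in this patch)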
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + ) + + +def test_invoke_completion_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "completion", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_chat_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_chat_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = GPUStackLanguageModel() + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 80 + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 10 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py new file mode 100644 index 0000000000..f5c2d2d21c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py @@ -0,0 +1,107 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.rerank.rerank import ( + GPUStackRerankModel, +) + + +def test_validate_credentials_for_rerank_model(): + model = GPUStackRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_rerank_model(): + model = GPUStackRerankModel() + + response = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=-0.75, + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 3 + + +def test__invoke(): + model = GPUStackRerankModel() + + # Test case 1: Empty docs + result = model._invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[], + top_n=3, + score_threshold=0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 0 + + # Test case 2: Expected docs + result = model._invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + 
"Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=-0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 3 + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/x/__init__.py b/api/tests/integration_tests/model_runtime/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/x/test_llm.py b/api/tests/integration_tests/model_runtime/x/test_llm.py new file mode 100644 index 0000000000..647a2f6480 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/x/test_llm.py @@ -0,0 +1,204 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = XAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials( + model="gpt-3.5-turbo", + credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"}, + ) + + model.validate_credentials( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI 
assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = XAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/vdb/lindorm/__init__.py b/api/tests/integration_tests/vdb/lindorm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py new file mode 100644 index 0000000000..f8f43ba6ef --- /dev/null +++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py @@ -0,0 +1,35 @@ +import environs + +from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis + +env = environs.Env() + + +class Config: + SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070") + SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN") + SEARCH_PWD = env.str("SEARCH_PWD", "PWD") + + +class TestLindormVectorStore(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = LindormVectorStore( + collection_name=self.collection_name, + config=LindormVectorStoreConfig( + hosts=Config.SEARCH_ENDPOINT, + username=Config.SEARCH_USERNAME, + password=Config.SEARCH_PWD, + ), + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + assert ids[0] == self.example_doc_id + + +def test_lindorm_vector(setup_mock_redis): + TestLindormVectorStore().run_all_tests() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py new file mode 100644 index 0000000000..a6bf43ab0c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -0,0 +1,52 @@ +import pytest + +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.apps.base_app_generator import BaseAppGenerator + + +def test_validate_inputs_with_zero(): + base_app_generator = BaseAppGenerator() + + var = VariableEntity( + variable="test_var", + label="test_var", + type=VariableEntityType.NUMBER, + required=True, + ) + + # Test with input 0 + result = base_app_generator._validate_inputs( + variable_entity=var, + value=0, + ) + + assert result == 0 + + # Test with input "0" (string) + result = base_app_generator._validate_inputs( + variable_entity=var, + value="0", + ) + + assert result == 0 + + +def test_validate_input_with_none_for_required_variable(): + base_app_generator = BaseAppGenerator() + + for var_type in VariableEntityType: + var = VariableEntity( + variable="test_var", + label="test_var", + type=var_type, + required=True, + ) + + # Test with input None + with pytest.raises(ValueError) as exc_info: + base_app_generator._validate_inputs( + variable_entity=var, + value=None, + ) + + assert str(exc_info.value) == "test_var is required in input form" diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 72d277fad4..882a87239b 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -13,6 +13,7 @@ from core.variables import ( StringVariable, ) from core.variables.exc import VariableError +from core.variables.segments import ArrayAnySegment from 
factories import variable_factory @@ -156,3 +157,9 @@ def test_variable_cannot_large_than_200_kb(): "value": "a" * 1024 * 201, } ) + + +def test_array_none_variable(): + var = variable_factory.build_segment([None, None, None, None]) + assert isinstance(var, ArrayAnySegment) + assert var.value == [None, None, None, None] diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py new file mode 100644 index 0000000000..12c469a81a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -0,0 +1,198 @@ +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request import ( + BodyData, + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, +) +from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.executor import Executor + + +def test_executor_with_json_body_and_number_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "number"], 42) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Number Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"number": {{#pre_node_id.number#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"number": 42} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '{"number": 42}' in raw_request + + +def test_executor_with_json_body_and_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value="{{#pre_node_id.object#}}", + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} 
+ assert executor.params == {} + assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_executor_with_json_body_and_nested_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"object": {{#pre_node_id.object#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"object": {' in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_extract_selectors_from_template_with_newline(): + variable_pool = VariablePool() + variable_pool.add(("node_id", "custom_query"), "line1\nline2") + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="test: {{#node_id.custom_query#}}", + body=HttpRequestNodeBody( + type="none", + data=[], + ), + ) + + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + assert executor.params == {"test": "line1\nline2"} diff --git a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py similarity index 52% rename from api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py rename to api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 720037d05f..741a3a1894 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -1,5 +1,3 @@ -import json - import httpx from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,8 +14,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeBody, HttpRequestNodeData, ) -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout -from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict +from core.workflow.nodes.http_request.executor import _plain_text_to_dict from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -203,167 +200,3 @@ def test_http_request_node_form_with_file(monkeypatch): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["body"] == "" - - -def test_executor_with_json_body_and_number_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "number"], 42) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Number Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value='{"number": {{#pre_node_id.number#}}}', - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"number": 42} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '{"number": 42}' in raw_request - - -def test_executor_with_json_body_and_object_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Object Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value="{{#pre_node_id.object#}}", - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} - assert executor.data is None - assert executor.files is 
None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '"name": "John Doe"' in raw_request - assert '"age": 30' in raw_request - assert '"email": "john@example.com"' in raw_request - - -def test_executor_with_json_body_and_nested_object_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Nested Object Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value='{"object": {{#pre_node_id.object#}}}', - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '"object": {' in raw_request - assert '"name": "John Doe"' in raw_request - assert '"age": 30' in raw_request - assert '"email": "john@example.com"' in raw_request diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index d755faee8a..29bd4d6c6c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.entities import ErrorHandleMode from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.enums import UserFrom @@ -185,8 +186,6 @@ def test_run(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() @@ -404,18 +403,458 @@ def test_run_parallel(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() count = 0 for item in result: - # print(type(item), item) count += 1 if isinstance(item, 
RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} assert count == 32 + + +def test_iteration_run_in_parallel_mode(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + parallel_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + sequential_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + parallel_result = parallel_iteration_node._run() + sequential_result = sequential_iteration_node._run() + assert parallel_iteration_node.node_data.parallel_nums == 10 + assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + count = 0 + parallel_arr = [] + sequential_arr = [] + for item in parallel_result: + count += 1 + parallel_arr.append(item) + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 32 + + for item in sequential_result: + sequential_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 64 + + +def test_iteration_run_error_handle(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "iteration-start", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "tt", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "tt2", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt2", "output"], + "output_type": "array[string]", + 
"start_node_id": "if-else", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1.split(arg2) }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, + ], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + ], + }, + "id": "tt2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "1", + "variable_selector": ["iteration-1", "item"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["1", "1"]) + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + }, + ) + # execute continue on error node + result = iteration_node._run() + result_arr = [] + count = 0 + for item in result: + result_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": [None, None]} + + assert count == 14 + # execute remove abnormal output + iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + result = iteration_node._run() + count = 0 + for item in result: + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == 
WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": []} + assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 53e3c93fcc..0f5c8bf51b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock import pytest -from core.file import File -from core.file.models import FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy -from core.workflow.nodes.list_operator.node import ListOperatorNode +from core.workflow.nodes.list_operator.exc import InvalidKeyError +from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from models.workflow import WorkflowNodeExecutionStatus @@ -109,3 +109,46 @@ def test_filter_files_by_type(list_operator_node): assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id + + +def test_get_file_extract_string_func(): + # Create a File object + file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + remote_url="https://example.com/test_file.txt", + related_id="test_related_id", + ) + + # Test each case + assert _get_file_extract_string_func(key="name")(file) == "test_file.txt" + assert _get_file_extract_string_func(key="type")(file) == "document" + assert _get_file_extract_string_func(key="extension")(file) == ".txt" + assert _get_file_extract_string_func(key="mime_type")(file) == "text/plain" + assert _get_file_extract_string_func(key="transfer_method")(file) == "local_file" + assert _get_file_extract_string_func(key="url")(file) == "https://example.com/test_file.txt" + + # Test with empty values + empty_file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename=None, + extension=None, + mime_type=None, + remote_url=None, + related_id="test_related_id", + ) + + assert _get_file_extract_string_func(key="name")(empty_file) == "" + assert _get_file_extract_string_func(key="extension")(empty_file) == "" + assert _get_file_extract_string_func(key="mime_type")(empty_file) == "" + assert _get_file_extract_string_func(key="url")(empty_file) == "" + + # Test invalid key + with pytest.raises(InvalidKeyError): + _get_file_extract_string_func(key="invalid_key") diff --git a/dev/reformat b/dev/reformat index ad83e897d9..94a7f3e6fe 100755 --- a/dev/reformat +++ b/dev/reformat @@ -9,10 +9,10 @@ if ! command -v ruff &> /dev/null || ! 
command -v dotenv-linter &> /dev/null; then fi # run ruff linter -ruff check --fix ./api +poetry run -C api ruff check --fix ./api # run ruff formatter -ruff format ./api +poetry run -C api ruff format ./api # run dotenv-linter linter -dotenv-linter ./api/.env.example ./web/.env.example +poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker/.env.example b/docker/.env.example index 34b2136302..aa5e102bd0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -222,7 +222,6 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false -REDIS_DB=0 # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. @@ -531,6 +530,12 @@ VIKINGDB_SCHEMA=http VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30 + +# Lindorm configuration, only available when VECTOR_STORE is `lindorm` +LINDORM_URL=http://ld-***************-proxy-search-pub.lindorm.aliyuncs.com:30070 +LINDORM_USERNAME=username +LINDORM_PASSWORD=password + # OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` OCEANBASE_VECTOR_HOST=oceanbase-vector OCEANBASE_VECTOR_PORT=2881 @@ -645,7 +650,6 @@ MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. RESEND_API_KEY=your-resend-api-key -RESEND_API_URL=https://api.resend.com # SMTP server configuration, used when MAIL_TYPE is `smtp` SMTP_SERVER= @@ -686,6 +690,7 @@ WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 +WORKFLOW_FILE_UPLOAD_LIMIT=10 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 112e9a2702..a26838af10 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,4 +1,5 @@ x-shared-env: &shared-api-worker-env + WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_FILE: ${LOG_FILE:-} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} @@ -167,6 +168,9 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} + LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} + LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx index 2b9aeba5da..544e43ab27 100644 --- a/web/app/account/avatar.tsx +++ b/web/app/account/avatar.tsx @@ -23,8 +23,9 @@ export default function AppSelector() { params: {}, }) - if (localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') + localStorage.removeItem('setup_status') + localStorage.removeItem('console_token') + localStorage.removeItem('refresh_token') router.push('/signin') } diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index 05339c7216..0a20f4b376 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -33,6 +33,10 @@ import { LoveMessage } from
'@/app/components/base/icons/src/vender/features' // type import type { AutomaticRes } from '@/service/debug' import { Generator } from '@/app/components/base/icons/src/vender/other' +import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' +import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' export type IGetAutomaticResProps = { mode: AppType @@ -68,7 +72,10 @@ const GetAutomaticRes: FC = ({ onFinished, }) => { const { t } = useTranslation() - + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const tryList = [ { icon: RiTerminalBoxLine, @@ -191,6 +198,19 @@ const GetAutomaticRes: FC = ({
{t('appDebug.generate.title')}
{t('appDebug.generate.description')}
+
+ + +
{t('appDebug.generate.tryIt')}
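Note on the `get-automatic-res.tsx` hunk above: it threads the workspace's current text-generation model into the prompt-generator dialog through `useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)` and renders it with `ModelIcon` and `ModelName` (the added JSX is only partially visible above). A minimal sketch of the pattern; the prop names on `ModelIcon`/`ModelName` are assumptions, since they do not appear in the hunk:

```tsx
// Sketch: surfacing the current text-generation model in a dialog.
// Assumes ModelIcon takes { provider, modelName } and ModelName takes
// { modelItem } — the exact props are not visible in the hunk above.
import type { FC } from 'react'
import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon'
import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'

const CurrentModelBadge: FC = () => {
  // Resolves the default text-generation model configured for the workspace.
  const { currentProvider, currentModel }
    = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

  if (!currentProvider || !currentModel)
    return null

  return (
    <div className='flex items-center'>
      <ModelIcon provider={currentProvider} modelName={currentModel.model} />
      <ModelName modelItem={currentModel} />
    </div>
  )
}

export default CurrentModelBadge
```

Resolving the provider/model pair once at the dialog level keeps the badge in sync with the workspace default without any prop drilling from the parent configuration page.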
diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx index b63e3e2693..85c522ca0f 100644 --- a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -105,6 +105,15 @@ export const GetCodeGeneratorResModal: FC = (
{t('appDebug.codegen.loading')}
) + const renderNoData = ( +
+ +
+
{t('appDebug.codegen.noDataLine1')}
+
{t('appDebug.codegen.noDataLine2')}
+
+
+ ) return ( = (
{isLoading && renderLoading} + {!isLoading && !res && renderNoData} {(!isLoading && res) && (
{t('appDebug.codegen.resTitle')}
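The `get-code-generator-res.tsx` hunk above adds a third render branch: the modal previously distinguished only "loading" and "has result", so a request that completed without a result rendered nothing. The new `renderNoData` fills that gap with the `appDebug.codegen.noDataLine1/2` copy. A reduced sketch of the branch order, with the markup simplified and the result shape assumed:

```tsx
// Sketch of the three-way render branch added above; markup is
// simplified and the real class names are omitted.
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'

type ResultPanelProps = {
  isLoading: boolean
  res: { code: string } | null // shape assumed for the sketch
}

const ResultPanel: FC<ResultPanelProps> = ({ isLoading, res }) => {
  const { t } = useTranslation()

  if (isLoading)
    return <div>{t('appDebug.codegen.loading')}</div>

  // New branch: the request finished but there is nothing to show yet.
  if (!res) {
    return (
      <div>
        <div>{t('appDebug.codegen.noDataLine1')}</div>
        <div>{t('appDebug.codegen.noDataLine2')}</div>
      </div>
    )
  }

  return (
    <div>
      <div>{t('appDebug.codegen.resTitle')}</div>
      <pre>{res.code}</pre>
    </div>
  )
}

export default ResultPanel
```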
diff --git a/web/app/components/base/file-uploader/constants.ts b/web/app/components/base/file-uploader/constants.ts index 629fe2566b..a749d73c74 100644 --- a/web/app/components/base/file-uploader/constants.ts +++ b/web/app/components/base/file-uploader/constants.ts @@ -3,5 +3,6 @@ export const IMG_SIZE_LIMIT = 10 * 1024 * 1024 export const FILE_SIZE_LIMIT = 15 * 1024 * 1024 export const AUDIO_SIZE_LIMIT = 50 * 1024 * 1024 export const VIDEO_SIZE_LIMIT = 100 * 1024 * 1024 +export const MAX_FILE_UPLOAD_LIMIT = 10 export const FILE_URL_REGEX = /^(https?|ftp):\/\// diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index a78c414913..c735754ffe 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -18,6 +18,7 @@ import { AUDIO_SIZE_LIMIT, FILE_SIZE_LIMIT, IMG_SIZE_LIMIT, + MAX_FILE_UPLOAD_LIMIT, VIDEO_SIZE_LIMIT, } from '@/app/components/base/file-uploader/constants' import { useToastContext } from '@/app/components/base/toast' @@ -33,12 +34,14 @@ export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => const docSizeLimit = Number(fileUploadConfig?.file_size_limit) * 1024 * 1024 || FILE_SIZE_LIMIT const audioSizeLimit = Number(fileUploadConfig?.audio_file_size_limit) * 1024 * 1024 || AUDIO_SIZE_LIMIT const videoSizeLimit = Number(fileUploadConfig?.video_file_size_limit) * 1024 * 1024 || VIDEO_SIZE_LIMIT + const maxFileUploadLimit = Number(fileUploadConfig?.workflow_file_upload_limit) || MAX_FILE_UPLOAD_LIMIT return { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit, + maxFileUploadLimit, } } @@ -216,7 +219,7 @@ export const useFile = (fileConfig: FileUpload) => { handleAddFile(uploadingFile) startProgressTimer(uploadingFile.id) - uploadRemoteFileInfo(url).then((res) => { + uploadRemoteFileInfo(url, !!params.token).then((res) => { const newFile = { ...uploadingFile, type: res.mime_type, diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index ee5cee977b..c70cf24661 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -125,7 +125,7 @@ const Select: FC = ({
- {filteredItems.length > 0 && ( + {(filteredItems.length > 0 && open) && ( {filteredItems.map((item: Item) => ( { const router = useRouter() const searchParams = useSearchParams() - const pathname = usePathname() - const { getNewAccessToken } = useRefreshToken() - const consoleToken = searchParams.get('access_token') - const refreshToken = searchParams.get('refresh_token') + const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') + const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + const pathname = usePathname() const [init, setInit] = useState(false) const isSetupFinished = useCallback(async () => { @@ -41,25 +39,6 @@ const SwrInitor = ({ } }, []) - const setRefreshToken = useCallback(async () => { - try { - if (!(consoleToken || refreshToken || consoleTokenFromLocalStorage || refreshTokenFromLocalStorage)) - return Promise.reject(new Error('No token found')) - - if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage) - await getNewAccessToken() - - if (consoleToken && refreshToken) { - localStorage.setItem('console_token', consoleToken) - localStorage.setItem('refresh_token', refreshToken) - await getNewAccessToken() - } - } - catch (error) { - return Promise.reject(error) - } - }, [consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage, getNewAccessToken]) - useEffect(() => { (async () => { try { @@ -68,9 +47,15 @@ const SwrInitor = ({ router.replace('/install') return } - await setRefreshToken() - if (searchParams.has('access_token') || searchParams.has('refresh_token')) + if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) { + router.replace('/signin') + return + } + if (searchParams.has('access_token') || searchParams.has('refresh_token')) { + consoleToken && localStorage.setItem('console_token', consoleToken) + refreshToken && localStorage.setItem('refresh_token', refreshToken) router.replace(pathname) + } setInit(true) } @@ -78,7 +63,7 @@ const SwrInitor = ({ router.replace('/signin') } })() - }, [isSetupFinished, setRefreshToken, router, pathname, searchParams]) + }, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage]) return init ? 
( diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 794996d1bd..ffa14b347b 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -340,7 +340,9 @@ export const NODES_INITIAL_DATA = { ...ListFilterDefault.defaultValue, }, } - +export const MAX_ITERATION_PARALLEL_NUM = 10 +export const MIN_ITERATION_PARALLEL_NUM = 1 +export const DEFAULT_ITER_TIMES = 1 export const NODE_WIDTH = 240 export const X_OFFSET = 60 export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index af2a1500ba..375a269377 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -644,6 +644,11 @@ export const useNodesInteractions = () => { newNode.data.isInIteration = true newNode.data.iteration_id = prevNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX + if (newNode.data.type === BlockEnum.Answer || newNode.data.type === BlockEnum.Tool || newNode.data.type === BlockEnum.Assigner) { + const parentIterNodeIndex = nodes.findIndex(node => node.id === prevNode.parentId) + const iterNodeData: IterationNodeType = nodes[parentIterNodeIndex].data + iterNodeData._isShowTips = true + } } const newEdge: Edge = { diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 0bbb1adab8..26654ef71e 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -14,6 +14,7 @@ import { NodeRunningStatus, WorkflowRunningStatus, } from '../types' +import { DEFAULT_ITER_TIMES } from '../constants' import { useWorkflowUpdate } from './use-workflow-interactions' import { useStore as useAppStore } from '@/app/components/app/store' import type { IOtherOptions } from '@/service/base' @@ -170,11 +171,13 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterParallelLogMap, } = workflowStore.getState() const { edges, setEdges, } = store.getState() + setIterParallelLogMap(new Map()) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.task_id = task_id draft.result = { @@ -244,6 +247,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -259,10 +264,21 @@ export const useWorkflowRun = () => { const tracing = draft.tracing! 
const iterations = tracing.find(trace => trace.node_id === node?.parentId) const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] - currIteration?.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) + if (!data.parallel_run_id) { + currIteration?.push({ + ...data, + status: NodeRunningStatus.Running, + } as any) + } + else { + if (!iterParallelLogMap.has(data.parallel_run_id)) + iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) + else + iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) + setIterParallelLogMap(iterParallelLogMap) + if (iterations) + iterations.details = Array.from(iterParallelLogMap.values()) + } })) } else { @@ -309,6 +325,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -317,21 +335,21 @@ export const useWorkflowRun = () => { const nodes = getNodes() const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId if (nodeParentId) { - setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const tracing = draft.tracing! - const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + if (!data.execution_metadata.parallel_mode_run_id) { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node - if (iterations && iterations.details) { - const iterationIndex = data.execution_metadata?.iteration_index || 0 - if (!iterations.details[iterationIndex]) - iterations.details[iterationIndex] = [] + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] - const currIteration = iterations.details[iterationIndex] - const nodeIndex = currIteration.findIndex(node => - node.node_id === data.node_id && ( - node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), - ) - if (data.status === NodeRunningStatus.Succeeded) { + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) if (nodeIndex !== -1) { currIteration[nodeIndex] = { ...currIteration[nodeIndex], @@ -344,8 +362,40 @@ export const useWorkflowRun = () => { } as any) } } - } - })) + })) + } + else { + // open parallel mode + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! 
+ const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + + if (iterations && iterations.details) { + const iterRunID = data.execution_metadata?.parallel_mode_run_id + + const currIteration = iterParallelLogMap.get(iterRunID) + const nodeIndex = currIteration?.findIndex(node => + node.node_id === data.node_id && ( + node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id), + ) + if (currIteration) { + if (nodeIndex !== undefined && nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + setIterParallelLogMap(iterParallelLogMap) + iterations.details = Array.from(iterParallelLogMap.values()) + } + })) + } } else { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { @@ -379,6 +429,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -388,6 +439,7 @@ export const useWorkflowRun = () => { transform, } = store.getState() const nodes = getNodes() + setIterTimes(DEFAULT_ITER_TIMES) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -431,6 +483,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterTimes, + setIterTimes, } = workflowStore.getState() const { data } = params @@ -445,13 +499,14 @@ export const useWorkflowRun = () => { if (iteration.details!.length >= iteration.metadata.iterator_length!) return } - iteration?.details!.push([]) + if (!data.parallel_mode_run_id) + iteration?.details!.push([]) })) const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - - currentNode.data._iterationIndex = data.index > 0 ? data.index : 1 + currentNode.data._iterationIndex = iterTimes + setIterTimes(iterTimes + 1) }) setNodes(newNodes) @@ -464,6 +519,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -480,7 +536,7 @@ export const useWorkflowRun = () => { }) } })) - + setIterTimes(DEFAULT_ITER_TIMES) const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! 
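The rewritten handlers in use-workflow-run.ts introduce the core bookkeeping for parallel iterations: sequential runs keep appending to details[iteration_index], while parallel runs are bucketed by parallel_mode_run_id in the store's iterParallelLogMap, and iterations.details is rebuilt from the map on every event. The same grouping strategy reappears in formatNodeList in web/app/components/workflow/run/index.tsx further down. Here is a minimal sketch of the idea, using simplified shapes rather than Dify's real NodeTracing type (NodeEvent and groupIterationDetails are illustrative names, not part of the diff):

```ts
// Simplified sketch of the bucketing strategy (hypothetical types and names).
type NodeEvent = {
  node_id: string
  status: string
  iteration_index?: number
  parallel_mode_run_id?: string
}

function groupIterationDetails(events: NodeEvent[]): NodeEvent[][] {
  const details: NodeEvent[][] = []
  const parallelLogMap = new Map<string, NodeEvent[]>()

  for (const event of events) {
    const runId = event.parallel_mode_run_id
    if (runId) {
      // Parallel mode: one bucket per parallel run, keyed by parallel_mode_run_id.
      if (!parallelLogMap.has(runId))
        parallelLogMap.set(runId, [event])
      else
        parallelLogMap.get(runId)!.push(event)
    }
    else {
      // Sequential mode: one bucket per iteration_index.
      const index = event.iteration_index ?? 0
      if (!details[index])
        details[index] = []
      details[index].push(event)
    }
  }

  // In parallel mode the map, not iteration_index, is the source of truth.
  return parallelLogMap.size > 0 ? Array.from(parallelLogMap.values()) : details
}

// Interleaved events from two parallel runs still land in two stable buckets.
const buckets = groupIterationDetails([
  { node_id: 'llm', status: 'running', parallel_mode_run_id: 'run-a' },
  { node_id: 'llm', status: 'running', parallel_mode_run_id: 'run-b' },
  { node_id: 'llm', status: 'succeeded', parallel_mode_run_id: 'run-a' },
])
console.log(buckets.length) // 2
```

Because Map preserves insertion order, each parallel run keeps a stable bucket even when events from different runs interleave on the stream, which is what lets the tracing panel render one group per run.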
diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index 6459cf8056..b2f815a325 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -12,15 +12,15 @@ import Tooltip from '@/app/components/base/tooltip' type Props = { className?: string title: JSX.Element | string | DefaultTFuncReturn + tooltip?: React.ReactNode isSubTitle?: boolean - tooltip?: string supportFold?: boolean children?: JSX.Element | string | null operations?: JSX.Element inline?: boolean } -const Filed: FC = ({ +const Field: FC = ({ className, title, isSubTitle, @@ -60,4 +60,4 @@ const Filed: FC = ({ ) } -export default React.memo(Filed) +export default React.memo(Field) diff --git a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx index 82a3a906cf..42a7213f80 100644 --- a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx +++ b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx @@ -39,7 +39,13 @@ const FileUploadSetting: FC = ({ allowed_file_extensions, } = payload const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) - const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileUploadConfigResponse) + const { + imgSizeLimit, + docSizeLimit, + audioSizeLimit, + videoSizeLimit, + maxFileUploadLimit, + } = useFileSizeLimit(fileUploadConfigResponse) const handleSupportFileTypeChange = useCallback((type: SupportUploadFileTypes) => { const newPayload = produce(payload, (draft) => { @@ -156,7 +162,7 @@ const FileUploadSetting: FC = ({ diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index a66127875a..912025837f 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -106,32 +106,29 @@ const useOneStepRun = ({ const availableNodesIncludeParent = getBeforeNodesInSameBranchIncludeParent(id) const allOutputVars = toNodeOutputVars(availableNodes, isChatMode, undefined, undefined, conversationVariables) const getVar = (valueSelector: ValueSelector): Var | undefined => { - let res: Var | undefined const isSystem = valueSelector[0] === 'sys' - const targetVar = isSystem ? allOutputVars.find(item => !!item.isStartNode) : allOutputVars.find(v => v.nodeId === valueSelector[0]) + const targetVar = allOutputVars.find(item => isSystem ? !!item.isStartNode : item.nodeId === valueSelector[0]) if (!targetVar) return undefined + if (isSystem) return targetVar.vars.find(item => item.variable.split('.')[1] === valueSelector[1]) let curr: any = targetVar.vars - if (!curr) - return + for (let i = 1; i < valueSelector.length; i++) { + const key = valueSelector[i] + const isLast = i === valueSelector.length - 1 - valueSelector.slice(1).forEach((key, i) => { - const isLast = i === valueSelector.length - 2 - // conversation variable is start with 'conversation.' 
- curr = curr?.find((v: any) => v.variable.replace('conversation.', '') === key) - if (isLast) { - res = curr - } - else { - if (curr?.type === VarType.object || curr?.type === VarType.file) - curr = curr.children - } - }) + if (Array.isArray(curr)) + curr = curr.find((v: any) => v.variable.replace('conversation.', '') === key) - return res + if (isLast) + return curr + else if (curr?.type === VarType.object || curr?.type === VarType.file) + curr = curr.children + } + + return undefined } const checkValid = checkValidFns[data.type] diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index bd5921c735..e864c419e2 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,6 +25,7 @@ import { useToolIcon, } from '../../hooks' import { useNodeIterationInteractions } from '../iteration/use-interactions' +import type { IterationNodeType } from '../iteration/types' import { NodeSourceHandle, NodeTargetHandle, @@ -34,6 +35,7 @@ import NodeControl from './components/node-control' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' +import Tooltip from '@/app/components/base/tooltip' type BaseNodeProps = { children: ReactElement @@ -166,9 +168,27 @@ const BaseNode: FC = ({ />
- {data.title} +
+ {data.title} +
+ { + data.type === BlockEnum.Iteration && (data as IterationNodeType).is_parallel && ( + +
+ {t('workflow.nodes.iteration.parallelModeEnableTitle')} +
+ {t('workflow.nodes.iteration.parallelModeEnableDesc')} +
} + >
+ {t('workflow.nodes.iteration.parallelModeUpper')} +
+ + ) + } { data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && ( diff --git a/web/app/components/workflow/nodes/if-else/use-config.ts b/web/app/components/workflow/nodes/if-else/use-config.ts index d1210431a0..41e41f6b8b 100644 --- a/web/app/components/workflow/nodes/if-else/use-config.ts +++ b/web/app/components/workflow/nodes/if-else/use-config.ts @@ -78,24 +78,24 @@ const useConfig = (id: string, payload: IfElseNodeType) => { }) const handleAddCase = useCallback(() => { - const newInputs = produce(inputs, () => { - if (inputs.cases) { + const newInputs = produce(inputs, (draft) => { + if (draft.cases) { const case_id = uuid4() - inputs.cases.push({ + draft.cases.push({ case_id, logical_operator: LogicalOperator.and, conditions: [], }) - if (inputs._targetBranches) { - const elseCaseIndex = inputs._targetBranches.findIndex(branch => branch.id === 'false') + if (draft._targetBranches) { + const elseCaseIndex = draft._targetBranches.findIndex(branch => branch.id === 'false') if (elseCaseIndex > -1) { - inputs._targetBranches = branchNameCorrect([ - ...inputs._targetBranches.slice(0, elseCaseIndex), + draft._targetBranches = branchNameCorrect([ + ...draft._targetBranches.slice(0, elseCaseIndex), { id: case_id, name: '', }, - ...inputs._targetBranches.slice(elseCaseIndex), + ...draft._targetBranches.slice(elseCaseIndex), ]) } } diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 3afa52d06e..cdef268adb 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -1,7 +1,10 @@ -import { BlockEnum } from '../../types' +import { BlockEnum, ErrorHandleMode } from '../../types' import type { NodeDefault } from '../../types' import type { IterationNodeType } from './types' -import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' +import { + ALL_CHAT_AVAILABLE_BLOCKS, + ALL_COMPLETION_AVAILABLE_BLOCKS, +} from '@/app/components/workflow/constants' const i18nPrefix = 'workflow' const nodeDefault: NodeDefault = { @@ -10,25 +13,45 @@ const nodeDefault: NodeDefault = { iterator_selector: [], output_selector: [], _children: [], + _isShowTips: false, + is_parallel: false, + parallel_nums: 10, + error_handle_mode: ErrorHandleMode.Terminated, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS - : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) + : ALL_COMPLETION_AVAILABLE_BLOCKS.filter( + type => type !== BlockEnum.End, + ) return nodes }, getAvailableNextNodes(isChatMode: boolean) { - const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + const nodes = isChatMode + ? 
ALL_CHAT_AVAILABLE_BLOCKS + : ALL_COMPLETION_AVAILABLE_BLOCKS return nodes }, checkValid(payload: IterationNodeType, t: any) { let errorMessages = '' - if (!errorMessages && (!payload.iterator_selector || payload.iterator_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.input`) }) + if ( + !errorMessages + && (!payload.iterator_selector || payload.iterator_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.input`), + }) + } - if (!errorMessages && (!payload.output_selector || payload.output_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.output`) }) + if ( + !errorMessages + && (!payload.output_selector || payload.output_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.output`), + }) + } return { isValid: !errorMessages, diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index 48a005a261..fda033b87a 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,12 +8,16 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' import cn from '@/utils/classnames' import type { NodeProps } from '@/app/components/workflow/types' +import Toast from '@/app/components/base/toast' + +const i18nPrefix = 'workflow.nodes.iteration' const Node: FC> = ({ id, @@ -22,11 +26,20 @@ const Node: FC> = ({ const { zoom } = useViewport() const nodesInitialized = useNodesInitialized() const { handleNodeIterationRerender } = useNodeIterationInteractions() + const { t } = useTranslation() useEffect(() => { if (nodesInitialized) handleNodeIterationRerender(id) - }, [nodesInitialized, id, handleNodeIterationRerender]) + if (data.is_parallel && data._isShowTips) { + Toast.notify({ + type: 'warning', + message: t(`${i18nPrefix}.answerNodeWarningDesc`), + duration: 5000, + }) + data._isShowTips = false + } + }, [nodesInitialized, id, handleNodeIterationRerender, data, t]) return (
> = ({ data, }) => { const { t } = useTranslation() - + const responseMethod = [ + { + value: ErrorHandleMode.Terminated, + name: t(`${i18nPrefix}.ErrorMethod.operationTerminated`), + }, + { + value: ErrorHandleMode.ContinueOnError, + name: t(`${i18nPrefix}.ErrorMethod.continueOnError`), + }, + { + value: ErrorHandleMode.RemoveAbnormalOutput, + name: t(`${i18nPrefix}.ErrorMethod.removeAbnormalOutput`), + }, + ] const { readOnly, inputs, @@ -47,6 +66,9 @@ const Panel: FC> = ({ setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } = useConfig(id, data) return ( @@ -87,6 +109,39 @@ const Panel: FC> = ({ />
+
+ {t(`${i18nPrefix}.parallelPanelDesc`)}
} inline> + + + + { inputs.is_parallel && (
+ {t(`${i18nPrefix}.MaxParallelismDesc`)}
}> +
+ { changeParallelNums(Number(e.target.value)) }} /> + +
+ + + ) + }
+ +
+ +
+ + + +
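The new parallel settings persisted by changeParallel, changeErrorResponseMode and changeParallelNums are bounded by the MIN_ITERATION_PARALLEL_NUM and MAX_ITERATION_PARALLEL_NUM constants added in constants.ts. A hypothetical clamp illustrating the intended range (clampParallelNums is not part of the diff; the real panel wires these bounds into its slider and number input):

```ts
// Hypothetical helper; constants match web/app/components/workflow/constants.ts.
const MIN_ITERATION_PARALLEL_NUM = 1
const MAX_ITERATION_PARALLEL_NUM = 10

function clampParallelNums(value: number): number {
  if (Number.isNaN(value))
    return MIN_ITERATION_PARALLEL_NUM
  // Keep parallel_nums an integer inside [1, 10].
  return Math.min(MAX_ITERATION_PARALLEL_NUM, Math.max(MIN_ITERATION_PARALLEL_NUM, Math.round(value)))
}

console.log(clampParallelNums(25)) // 10
console.log(clampParallelNums(0)) // 1
```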
+ {isShowSingleRun && ( { @@ -184,6 +185,25 @@ const useConfig = (id: string, payload: IterationNodeType) => { }) }, [iteratorInputKey, runInputData, setRunInputData]) + const changeParallel = useCallback((value: boolean) => { + const newInputs = produce(inputs, (draft) => { + draft.is_parallel = value + }) + setInputs(newInputs) + }, [inputs, setInputs]) + + const changeErrorResponseMode = useCallback((item: Item) => { + const newInputs = produce(inputs, (draft) => { + draft.error_handle_mode = item.value as ErrorHandleMode + }) + setInputs(newInputs) + }, [inputs, setInputs]) + const changeParallelNums = useCallback((num: number) => { + const newInputs = produce(inputs, (draft) => { + draft.parallel_nums = num + }) + setInputs(newInputs) + }, [inputs, setInputs]) return { readOnly, inputs, @@ -210,6 +230,9 @@ const useConfig = (id: string, payload: IterationNodeType) => { setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } } diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 58a4561e2c..5d932a1ba2 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -9,6 +9,8 @@ import { produce, setAutoFreeze } from 'immer' import { uniqBy } from 'lodash-es' import { useWorkflowRun } from '../../hooks' import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' +import { useWorkflowStore } from '../../store' +import { DEFAULT_ITER_TIMES } from '../../constants' import type { ChatItem, Inputs, @@ -43,6 +45,7 @@ export const useChat = ( const { notify } = useToastContext() const { handleRun } = useWorkflowRun() const hasStopResponded = useRef(false) + const workflowStore = useWorkflowStore() const conversationId = useRef('') const taskIdRef = useRef('') const [chatList, setChatList] = useState(prevChatList || []) @@ -52,6 +55,9 @@ export const useChat = ( const [suggestedQuestions, setSuggestQuestions] = useState([]) const suggestedQuestionsAbortControllerRef = useRef(null) + const { + setIterTimes, + } = workflowStore.getState() useEffect(() => { setAutoFreeze(false) return () => { @@ -102,15 +108,16 @@ export const useChat = ( handleResponding(false) if (stopChat && taskIdRef.current) stopChat(taskIdRef.current) - + setIterTimes(DEFAULT_ITER_TIMES) if (suggestedQuestionsAbortControllerRef.current) suggestedQuestionsAbortControllerRef.current.abort() - }, [handleResponding, stopChat]) + }, [handleResponding, setIterTimes, stopChat]) const handleRestart = useCallback(() => { conversationId.current = '' taskIdRef.current = '' handleStop() + setIterTimes(DEFAULT_ITER_TIMES) const newChatList = config?.opening_statement ? 
[{ id: `${Date.now()}`, @@ -126,6 +133,7 @@ export const useChat = ( config, handleStop, handleUpdateChatList, + setIterTimes, ]) const updateCurrentQA = useCallback(({ diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 9e636e902b..89db43fa35 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -60,36 +60,67 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe }, [notify, getResultCallback]) const formatNodeList = useCallback((list: NodeTracing[]) => { - const allItems = list.reverse() + const allItems = [...list].reverse() const result: NodeTracing[] = [] - allItems.forEach((item) => { - const { node_type, execution_metadata } = item - if (node_type !== BlockEnum.Iteration) { - const isInIteration = !!execution_metadata?.iteration_id + const groupMap = new Map() - if (isInIteration) { - const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) - const iterationDetails = iterationNode?.details - const currentIterationIndex = execution_metadata?.iteration_index ?? 0 - - if (Array.isArray(iterationDetails)) { - if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) - iterationDetails[currentIterationIndex] = [item] - else - iterationDetails[currentIterationIndex].push(item) - } - return - } - // not in iteration - result.push(item) - - return - } + const processIterationNode = (item: NodeTracing) => { result.push({ ...item, details: [], }) + } + const updateParallelModeGroup = (runId: string, item: NodeTracing, iterationNode: NodeTracing) => { + if (!groupMap.has(runId)) + groupMap.set(runId, [item]) + else + groupMap.get(runId)!.push(item) + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + + iterationNode.details = Array.from(groupMap.values()) + } + const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { + const { details } = iterationNode + if (details) { + if (!details[index]) + details[index] = [item] + else + details[index].push(item) + } + + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + } + const processNonIterationNode = (item: NodeTracing) => { + const { execution_metadata } = item + if (!execution_metadata?.iteration_id) { + result.push(item) + return + } + + const iterationNode = result.find(node => node.node_id === execution_metadata.iteration_id) + if (!iterationNode || !Array.isArray(iterationNode.details)) + return + + const { parallel_mode_run_id, iteration_index = 0 } = execution_metadata + + if (parallel_mode_run_id) + updateParallelModeGroup(parallel_mode_run_id, item, iterationNode) + else + updateSequentialModeGroup(iteration_index, item, iterationNode) + } + + allItems.forEach((item) => { + item.node_type === BlockEnum.Iteration + ? 
processIterationNode(item) + : processNonIterationNode(item) }) + return result }, []) diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index 7e2f6cbc00..c4cd909f2e 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' import { RiArrowRightSLine, RiCloseLine, + RiErrorWarningLine, } from '@remixicon/react' import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' import TracingPanel from './tracing-panel' @@ -27,7 +28,7 @@ const IterationResultPanel: FC = ({ noWrap, }) => { const { t } = useTranslation() - const [expandedIterations, setExpandedIterations] = useState>([]) + const [expandedIterations, setExpandedIterations] = useState>({}) const toggleIteration = useCallback((index: number) => { setExpandedIterations(prev => ({ @@ -71,10 +72,19 @@ const IterationResultPanel: FC = ({ {t(`${i18nPrefix}.iteration`)} {index + 1} - + { + iteration.some(item => item.status === 'failed') + ? ( + + ) + : (< RiArrowRightSLine className={ + cn( + 'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0', + expandedIterations[index] && 'transform rotate-90', + )} /> + ) + } + {expandedIterations[index] &&
= ({ return iteration_length } + const getErrorCount = (details: NodeTracing[][] | undefined) => { + if (!details || details.length === 0) + return 0 + return details.reduce((acc, iteration) => { + if (iteration.some(item => item.status === 'failed')) + acc++ + return acc + }, 0) + } useEffect(() => { setCollapseState(!nodeInfo.expand) }, [nodeInfo.expand, setCollapseState]) @@ -136,7 +145,12 @@ const NodePanel: FC = ({ onClick={handleOnShowIterationDetail} > -
{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}
+
{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}{getErrorCount(nodeInfo.details) > 0 && ( + <> + {t('workflow.nodes.iteration.comma')} + {t('workflow.nodes.iteration.error', { count: getErrorCount(nodeInfo.details) })} + + )}
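The error badge rendered above derives from the same details: NodeTracing[][] structure that formatNodeList builds: an iteration counts as failed as soon as any trace inside it failed. A self-contained sketch of that reduce, with simplified shapes (Trace and countFailedIterations are illustrative, not the component's real code):

```ts
// Illustrative only; the real component uses getErrorCount over NodeTracing[][].
type Trace = { status: string }

// An iteration counts as failed if any node trace inside it failed.
function countFailedIterations(details: Trace[][]): number {
  return details.reduce((acc, iteration) =>
    iteration.some(item => item.status === 'failed') ? acc + 1 : acc, 0)
}

const details: Trace[][] = [
  [{ status: 'succeeded' }],
  [{ status: 'succeeded' }, { status: 'failed' }],
  [{ status: 'failed' }],
]
console.log(`${details.length} Iterations, ${countFailedIterations(details)} Errors`)
// -> "3 Iterations, 2 Errors"
```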
{justShowIterationNavArrow ? ( diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index 480a7eb57c..da028834f7 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -21,6 +21,7 @@ import type { WorkflowRunningData, } from './types' import { WorkflowContext } from './context' +import type { NodeTracing } from '@/types/workflow' // #TODO chatVar# // const MOCK_DATA = [ @@ -166,6 +167,10 @@ type Shape = { setShowImportDSLModal: (showImportDSLModal: boolean) => void showTips: string setShowTips: (showTips: string) => void + iterTimes: number + setIterTimes: (iterTimes: number) => void + iterParallelLogMap: Map + setIterParallelLogMap: (iterParallelLogMap: Map) => void } export const createWorkflowStore = () => { @@ -281,6 +286,11 @@ export const createWorkflowStore = () => { setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), showTips: '', setShowTips: showTips => set(() => ({ showTips })), + iterTimes: 1, + setIterTimes: iterTimes => set(() => ({ iterTimes })), + iterParallelLogMap: new Map(), + setIterParallelLogMap: iterParallelLogMap => set(() => ({ iterParallelLogMap })), + })) } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 811ec0d70c..255b8fea3b 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -36,8 +36,12 @@ export enum ControlMode { Pointer = 'pointer', Hand = 'hand', } - -export interface Branch { +export enum ErrorHandleMode { + Terminated = 'terminated', + ContinueOnError = 'continue-on-error', + RemoveAbnormalOutput = 'remove-abnormal-output', +} +export type Branch = { id: string name: string } diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 91656e3bbc..aaf333f4d7 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -19,7 +19,7 @@ import type { ToolWithProvider, ValueSelector, } from './types' -import { BlockEnum } from './types' +import { BlockEnum, ErrorHandleMode } from './types' import { CUSTOM_NODE, ITERATION_CHILDREN_Z_INDEX, @@ -267,8 +267,13 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { }) } - if (node.data.type === BlockEnum.Iteration) - node.data._children = iterationNodeMap[node.id] || [] + if (node.data.type === BlockEnum.Iteration) { + const iterationNodeData = node.data as IterationNodeType + iterationNodeData._children = iterationNodeMap[node.id] || [] + iterationNodeData.is_parallel = iterationNodeData.is_parallel || false + iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10 + iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated + } return node }) diff --git a/web/app/signin/normalForm.tsx b/web/app/signin/normalForm.tsx index c0f2d89b37..f4f46c68ba 100644 --- a/web/app/signin/normalForm.tsx +++ b/web/app/signin/normalForm.tsx @@ -12,11 +12,9 @@ import cn from '@/utils/classnames' import { getSystemFeatures, invitationCheck } from '@/service/common' import { defaultSystemFeatures } from '@/types/feature' import Toast from '@/app/components/base/toast' -import useRefreshToken from '@/hooks/use-refresh-token' import { IS_CE_EDITION } from '@/config' const NormalForm = () => { - const { getNewAccessToken } = useRefreshToken() const { t } = useTranslation() const router = useRouter() const searchParams = useSearchParams() @@ -38,7 +36,6 @@ const NormalForm = () => { if 
(consoleToken && refreshToken) { localStorage.setItem('console_token', consoleToken) localStorage.setItem('refresh_token', refreshToken) - getNewAccessToken() router.replace('/apps') return } @@ -71,7 +68,7 @@ const NormalForm = () => { setSystemFeatures(defaultSystemFeatures) } finally { setIsLoading(false) } - }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, getNewAccessToken]) + }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink]) useEffect(() => { init() }, [init]) diff --git a/web/hooks/use-refresh-token.ts b/web/hooks/use-refresh-token.ts deleted file mode 100644 index 53dc4faf00..0000000000 --- a/web/hooks/use-refresh-token.ts +++ /dev/null @@ -1,99 +0,0 @@ -'use client' -import { useCallback, useEffect, useRef } from 'react' -import { jwtDecode } from 'jwt-decode' -import dayjs from 'dayjs' -import utc from 'dayjs/plugin/utc' -import { useRouter } from 'next/navigation' -import type { CommonResponse } from '@/models/common' -import { fetchNewToken } from '@/service/common' -import { fetchWithRetry } from '@/utils' - -dayjs.extend(utc) - -const useRefreshToken = () => { - const router = useRouter() - const timer = useRef() - const advanceTime = useRef(5 * 60 * 1000) - - const getExpireTime = useCallback((token: string) => { - if (!token) - return 0 - const decoded = jwtDecode(token) - return (decoded.exp || 0) * 1000 - }, []) - - const getCurrentTimeStamp = useCallback(() => { - return dayjs.utc().valueOf() - }, []) - - const handleError = useCallback(() => { - localStorage?.removeItem('is_refreshing') - localStorage?.removeItem('console_token') - localStorage?.removeItem('refresh_token') - router.replace('/signin') - }, []) - - const getNewAccessToken = useCallback(async () => { - const currentAccessToken = localStorage?.getItem('console_token') - const currentRefreshToken = localStorage?.getItem('refresh_token') - if (!currentAccessToken || !currentRefreshToken) { - handleError() - return new Error('No access token or refresh token found') - } - if (localStorage?.getItem('is_refreshing') === '1') { - clearTimeout(timer.current) - timer.current = setTimeout(() => { - getNewAccessToken() - }, 1000) - return null - } - const currentTokenExpireTime = getExpireTime(currentAccessToken) - if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) { - localStorage?.setItem('is_refreshing', '1') - const [e, res] = await fetchWithRetry(fetchNewToken({ - body: { refresh_token: currentRefreshToken }, - }) as Promise) - if (e) { - handleError() - return e - } - const { access_token, refresh_token } = res.data - localStorage?.setItem('is_refreshing', '0') - localStorage?.setItem('console_token', access_token) - localStorage?.setItem('refresh_token', refresh_token) - const newTokenExpireTime = getExpireTime(access_token) - clearTimeout(timer.current) - timer.current = setTimeout(() => { - getNewAccessToken() - }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp()) - } - else { - const newTokenExpireTime = getExpireTime(currentAccessToken) - clearTimeout(timer.current) - timer.current = setTimeout(() => { - getNewAccessToken() - }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp()) - } - return null - }, [getExpireTime, getCurrentTimeStamp, handleError]) - - const handleVisibilityChange = useCallback(() => { - if (document.visibilityState === 'visible') - getNewAccessToken() - }, []) - - useEffect(() => { - window.addEventListener('visibilitychange', handleVisibilityChange) - return () => { - 
window.removeEventListener('visibilitychange', handleVisibilityChange) - clearTimeout(timer.current) - localStorage?.removeItem('is_refreshing') - } - }, []) - - return { - getNewAccessToken, - } -} - -export default useRefreshToken diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index bde0250fcc..d05070c308 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteration', iteration_other: '{{count}} Iterationen', currentIteration: 'Aktuelle Iteration', + ErrorMethod: { + operationTerminated: 'beendet', + removeAbnormalOutput: 'remove-abnormale_ausgabe', + continueOnError: 'Fehler "Fortfahren bei"', + }, + MaxParallelismTitle: 'Maximale Parallelität', + parallelMode: 'Paralleler Modus', + errorResponseMethod: 'Methode der Fehlerantwort', + error_one: '{{Anzahl}} Fehler', + error_other: '{{Anzahl}} Irrtümer', + MaxParallelismDesc: 'Die maximale Parallelität wird verwendet, um die Anzahl der Aufgaben zu steuern, die gleichzeitig in einer einzigen Iteration ausgeführt werden.', + parallelPanelDesc: 'Im parallelen Modus unterstützen Aufgaben in der Iteration die parallele Ausführung.', + parallelModeEnableDesc: 'Im parallelen Modus unterstützen Aufgaben innerhalb von Iterationen die parallele Ausführung. Sie können dies im Eigenschaftenbereich auf der rechten Seite konfigurieren.', + answerNodeWarningDesc: 'Warnung im parallelen Modus: Antwortknoten, Zuweisungen von Konversationsvariablen und persistente Lese-/Schreibvorgänge innerhalb von Iterationen können Ausnahmen verursachen.', + parallelModeEnableTitle: 'Paralleler Modus aktiviert', + parallelModeUpper: 'PARALLELER MODUS', + comma: ',', }, note: { editor: { diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts index b2144262f6..e17afc38bf 100644 --- a/web/i18n/en-US/app-debug.ts +++ b/web/i18n/en-US/app-debug.ts @@ -224,6 +224,8 @@ const translation = { description: 'The Code Generator uses configured models to generate high-quality code based on your instructions. Please provide clear and detailed instructions.', instruction: 'Instructions', instructionPlaceholder: 'Enter detailed description of the code you want to generate.', + noDataLine1: 'Describe your use case on the left,', + noDataLine2: 'the code preview will show here.', generate: 'Generate', generatedCodeTitle: 'Generated Code', loading: 'Generating code...', diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 8b5f96453c..7f237b1a49 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -556,6 +556,23 @@ const translation = { iteration_one: '{{count}} Iteration', iteration_other: '{{count}} Iterations', currentIteration: 'Current Iteration', + comma: ', ', + error_one: '{{count}} Error', + error_other: '{{count}} Errors', + parallelMode: 'Parallel Mode', + parallelModeUpper: 'PARALLEL MODE', + parallelModeEnableTitle: 'Parallel Mode Enabled', + parallelModeEnableDesc: 'In parallel mode, tasks within iterations support parallel execution. 
You can configure this in the properties panel on the right.', + parallelPanelDesc: 'In parallel mode, tasks in the iteration support parallel execution.', + MaxParallelismTitle: 'Maximum parallelism', + MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.', + errorResponseMethod: 'Error response method', + ErrorMethod: { + operationTerminated: 'terminated', + continueOnError: 'continue-on-error', + removeAbnormalOutput: 'remove-abnormal-output', + }, + answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.', }, note: { addNote: 'Add Note', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 59a330e7f4..6c9af49c4d 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteración', iteration_other: '{{count}} Iteraciones', currentIteration: 'Iteración actual', + ErrorMethod: { + operationTerminated: 'Terminado', + continueOnError: 'Continuar en el error', + removeAbnormalOutput: 'eliminar-salida-anormal', + }, + comma: ',', + errorResponseMethod: 'Método de respuesta a errores', + error_one: '{{conteo}} Error', + parallelPanelDesc: 'En el modo paralelo, las tareas de la iteración admiten la ejecución en paralelo.', + MaxParallelismTitle: 'Máximo paralelismo', + error_other: '{{conteo}} Errores', + parallelMode: 'Modo paralelo', + parallelModeEnableDesc: 'En el modo paralelo, las tareas dentro de las iteraciones admiten la ejecución en paralelo. Puede configurar esto en el panel de propiedades a la derecha.', + parallelModeUpper: 'MODO PARALELO', + MaxParallelismDesc: 'El paralelismo máximo se utiliza para controlar el número de tareas ejecutadas simultáneamente en una sola iteración.', + answerNodeWarningDesc: 'Advertencia de modo paralelo: Los nodos de respuesta, las asignaciones de variables de conversación y las operaciones de lectura/escritura persistentes dentro de las iteraciones pueden provocar excepciones.', + parallelModeEnableTitle: 'Modo paralelo habilitado', }, note: { addNote: 'Agregar nota', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index b1f9384159..4b00390663 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} تکرار', iteration_other: '{{count}} تکرارها', currentIteration: 'تکرار فعلی', + ErrorMethod: { + continueOnError: 'ادامه در خطا', + operationTerminated: 'فسخ', + removeAbnormalOutput: 'حذف خروجی غیرطبیعی', + }, + error_one: '{{تعداد}} خطا', + error_other: '{{تعداد}} خطاهای', + parallelMode: 'حالت موازی', + errorResponseMethod: 'روش پاسخ به خطا', + parallelModeEnableTitle: 'حالت موازی فعال است', + parallelModeUpper: 'حالت موازی', + comma: ',', + parallelModeEnableDesc: 'در حالت موازی، وظایف درون تکرارها از اجرای موازی پشتیبانی می کنند. 
می توانید این را در پانل ویژگی ها در سمت راست پیکربندی کنید.', + MaxParallelismTitle: 'حداکثر موازی سازی', + parallelPanelDesc: 'در حالت موازی، وظایف در تکرار از اجرای موازی پشتیبانی می کنند.', + MaxParallelismDesc: 'حداکثر موازی سازی برای کنترل تعداد وظایف اجرا شده به طور همزمان در یک تکرار واحد استفاده می شود.', + answerNodeWarningDesc: 'هشدار حالت موازی: گره های پاسخ، تکالیف متغیر مکالمه و عملیات خواندن/نوشتن مداوم در تکرارها ممکن است باعث استثنائات شود.', }, note: { addNote: 'افزودن یادداشت', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index e56932455f..e736e2cb07 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Itération', iteration_other: '{{count}} Itérations', currentIteration: 'Itération actuelle', + ErrorMethod: { + operationTerminated: 'Terminé', + removeAbnormalOutput: 'remove-abnormal-output', + continueOnError: 'continuer sur l’erreur', + }, + comma: ',', + error_one: '{{compte}} Erreur', + error_other: '{{compte}} Erreurs', + parallelModeEnableDesc: 'En mode parallèle, les tâches au sein des itérations prennent en charge l’exécution parallèle. Vous pouvez le configurer dans le panneau des propriétés à droite.', + parallelModeUpper: 'MODE PARALLÈLE', + parallelPanelDesc: 'En mode parallèle, les tâches de l’itération prennent en charge l’exécution parallèle.', + MaxParallelismDesc: 'Le parallélisme maximal est utilisé pour contrôler le nombre de tâches exécutées simultanément en une seule itération.', + errorResponseMethod: 'Méthode de réponse aux erreurs', + MaxParallelismTitle: 'Parallélisme maximal', + answerNodeWarningDesc: 'Avertissement en mode parallèle : les nœuds de réponse, les affectations de variables de conversation et les opérations de lecture/écriture persistantes au sein des itérations peuvent provoquer des exceptions.', + parallelModeEnableTitle: 'Mode parallèle activé', + parallelMode: 'Mode parallèle', }, note: { addNote: 'Ajouter note', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 1473f78ccd..4112643488 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -577,6 +577,23 @@ const translation = { iteration_one: '{{count}} इटरेशन', iteration_other: '{{count}} इटरेशन्स', currentIteration: 'वर्तमान इटरेशन', + ErrorMethod: { + operationTerminated: 'समाप्त', + continueOnError: 'जारी रखें-पर-त्रुटि', + removeAbnormalOutput: 'निकालें-असामान्य-आउटपुट', + }, + comma: ',', + error_other: '{{गिनती}} त्रुटियों', + error_one: '{{गिनती}} चूक', + parallelMode: 'समानांतर मोड', + parallelModeUpper: 'समानांतर मोड', + errorResponseMethod: 'त्रुटि प्रतिक्रिया विधि', + MaxParallelismTitle: 'अधिकतम समांतरता', + parallelModeEnableTitle: 'समानांतर मोड सक्षम किया गया', + parallelModeEnableDesc: 'समानांतर मोड में, पुनरावृत्तियों के भीतर कार्य समानांतर निष्पादन का समर्थन करते हैं। आप इसे दाईं ओर गुण पैनल में कॉन्फ़िगर कर सकते हैं।', + parallelPanelDesc: 'समानांतर मोड में, पुनरावृत्ति में कार्य समानांतर निष्पादन का समर्थन करते हैं।', + MaxParallelismDesc: 'अधिकतम समांतरता का उपयोग एकल पुनरावृत्ति में एक साथ निष्पादित कार्यों की संख्या को नियंत्रित करने के लिए किया जाता है।', + answerNodeWarningDesc: 'समानांतर मोड चेतावनी: उत्तर नोड्स, वार्तालाप चर असाइनमेंट, और पुनरावृत्तियों के भीतर लगातार पढ़ने/लिखने की कार्रवाई अपवाद पैदा कर सकती है।', }, note: { addNote: 'नोट जोड़ें', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index 19fa7bfbb5..756fb665af 100644 --- a/web/i18n/it-IT/workflow.ts +++ 
b/web/i18n/it-IT/workflow.ts @@ -584,6 +584,23 @@ const translation = { iteration_one: '{{count}} Iterazione', iteration_other: '{{count}} Iterazioni', currentIteration: 'Iterazione Corrente', + ErrorMethod: { + operationTerminated: 'Terminato', + continueOnError: 'continua sull\'errore', + removeAbnormalOutput: 'rimuovi-output-anomalo', + }, + error_one: '{{conteggio}} Errore', + parallelMode: 'Modalità parallela', + MaxParallelismTitle: 'Parallelismo massimo', + error_other: '{{conteggio}} Errori', + parallelModeEnableDesc: 'In modalità parallela, le attività all\'interno delle iterazioni supportano l\'esecuzione parallela. È possibile configurare questa opzione nel pannello delle proprietà a destra.', + MaxParallelismDesc: 'Il parallelismo massimo viene utilizzato per controllare il numero di attività eseguite contemporaneamente in una singola iterazione.', + errorResponseMethod: 'Metodo di risposta all\'errore', + parallelModeEnableTitle: 'Modalità parallela abilitata', + parallelModeUpper: 'MODALITÀ PARALLELA', + comma: ',', + parallelPanelDesc: 'In modalità parallela, le attività nell\'iterazione supportano l\'esecuzione parallela.', + answerNodeWarningDesc: 'Avviso in modalità parallela: i nodi di risposta, le assegnazioni di variabili di conversazione e le operazioni di lettura/scrittura persistenti all\'interno delle iterazioni possono causare eccezioni.', }, note: { addNote: 'Aggiungi Nota', diff --git a/web/i18n/ja-JP/app-debug.ts b/web/i18n/ja-JP/app-debug.ts index 620d9b2f55..05e81a2ae2 100644 --- a/web/i18n/ja-JP/app-debug.ts +++ b/web/i18n/ja-JP/app-debug.ts @@ -224,6 +224,8 @@ const translation = { description: 'コードジェネレーターは、設定されたモデルを使用して指示に基づいて高品質なコードを生成します。明確で詳細な指示を提供してください。', instruction: '指示', instructionPlaceholder: '生成したいコードの詳細な説明を入力してください。', + noDataLine1: '左側に使用例を記入してください,', + noDataLine2: 'コードのプレビューがこちらに表示されます。', generate: '生成', generatedCodeTitle: '生成されたコード', loading: 'コードを生成中...', diff --git a/web/i18n/ja-JP/app.ts b/web/i18n/ja-JP/app.ts index 76c7d1c4f4..48a35c61af 100644 --- a/web/i18n/ja-JP/app.ts +++ b/web/i18n/ja-JP/app.ts @@ -39,10 +39,10 @@ const translation = { workflowWarning: '現在ベータ版です', chatbotType: 'チャットボットのオーケストレーション方法', basic: '基本', - basicTip: '初心者向け。後で Chatflow に切り替えることができます', + basicTip: '初心者向け。後で「チャットフロー」に切り替えることができます', basicFor: '初心者向け', basicDescription: '基本オーケストレートは、組み込みのプロンプトを変更する機能がなく、簡単な設定を使用してチャットボット アプリをオーケストレートします。初心者向けです。', - advanced: 'Chatflow', + advanced: 'チャットフロー', advancedFor: '上級ユーザー向け', advancedDescription: 'ワークフロー オーケストレートは、ワークフロー形式でチャットボットをオーケストレートし、組み込みのプロンプトを編集する機能を含む高度なカスタマイズを提供します。経験豊富なユーザー向けです。', captionName: 'アプリのアイコンと名前', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index b6c7786081..a82ba71e48 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -558,6 +558,23 @@ const translation = { iteration_one: '{{count}} イテレーション', iteration_other: '{{count}} イテレーション', currentIteration: '現在のイテレーション', + ErrorMethod: { + operationTerminated: '終了', + continueOnError: 'エラー時に続行', + removeAbnormalOutput: 'アブノーマルアウトプットの削除', + }, + comma: ',', + error_other: '{{カウント}}エラー', + error_one: '{{カウント}}エラー', + parallelModeUpper: 'パラレルモード', + parallelMode: 'パラレルモード', + MaxParallelismTitle: '最大並列処理', + errorResponseMethod: 'エラー応答方式', + parallelPanelDesc: '並列モードでは、イテレーションのタスクは並列実行をサポートします。', + parallelModeEnableDesc: '並列モードでは、イテレーション内のタスクは並列実行をサポートします。これは、右側のプロパティパネルで構成できます。', + parallelModeEnableTitle: 'パラレルモード有効', + MaxParallelismDesc: '最大並列処理は、1 回の反復で同時に実行されるタスクの数を制御するために使用されます。', + 
answerNodeWarningDesc: '並列モードの警告: 応答ノード、会話変数の割り当て、およびイテレーション内の永続的な読み取り/書き込み操作により、例外が発生する可能性があります。', }, note: { addNote: 'コメントを追加', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index b62aff2068..589831401c 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} 반복', iteration_other: '{{count}} 반복', currentIteration: '현재 반복', + ErrorMethod: { + operationTerminated: '종료', + continueOnError: '오류 발생 시 계속', + removeAbnormalOutput: '비정상 출력 제거', + }, + comma: ',', + error_one: '{{개수}} 오류', + parallelMode: '병렬 모드', + errorResponseMethod: '오류 응답 방법', + parallelModeUpper: '병렬 모드', + MaxParallelismTitle: '최대 병렬 처리', + error_other: '{{개수}} 오류', + parallelModeEnableTitle: 'Parallel Mode Enabled(병렬 모드 사용)', + parallelPanelDesc: '병렬 모드에서 반복의 작업은 병렬 실행을 지원합니다.', + parallelModeEnableDesc: '병렬 모드에서는 반복 내의 작업이 병렬 실행을 지원합니다. 오른쪽의 속성 패널에서 이를 구성할 수 있습니다.', + MaxParallelismDesc: '최대 병렬 처리는 단일 반복에서 동시에 실행되는 작업 수를 제어하는 데 사용됩니다.', + answerNodeWarningDesc: '병렬 모드 경고: 응답 노드, 대화 변수 할당 및 반복 내의 지속적인 읽기/쓰기 작업으로 인해 예외가 발생할 수 있습니다.', }, note: { editor: { diff --git a/web/i18n/pl-PL/app-debug.ts b/web/i18n/pl-PL/app-debug.ts index 7cf6c77cb4..cf7232e563 100644 --- a/web/i18n/pl-PL/app-debug.ts +++ b/web/i18n/pl-PL/app-debug.ts @@ -355,7 +355,7 @@ const translation = { openingStatement: { title: 'Wstęp do rozmowy', add: 'Dodaj', - writeOpner: 'Napisz wstęp', + writeOpener: 'Napisz wstęp', placeholder: 'Tutaj napisz swoją wiadomość wprowadzającą, możesz użyć zmiennych, spróbuj wpisać {{variable}}.', openingQuestion: 'Pytania otwierające', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index aace1b2642..f118f7945c 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteracja', iteration_other: '{{count}} Iteracje', currentIteration: 'Bieżąca iteracja', + ErrorMethod: { + continueOnError: 'kontynuacja w przypadku błędu', + operationTerminated: 'Zakończone', + removeAbnormalOutput: 'usuń-nieprawidłowe-wyjście', + }, + comma: ',', + parallelModeUpper: 'TRYB RÓWNOLEGŁY', + parallelModeEnableTitle: 'Włączony tryb równoległy', + MaxParallelismTitle: 'Maksymalna równoległość', + error_one: '{{liczba}} Błąd', + error_other: '{{liczba}} Błędy', + parallelPanelDesc: 'W trybie równoległym zadania w iteracji obsługują wykonywanie równoległe.', + parallelMode: 'Tryb równoległy', + MaxParallelismDesc: 'Maksymalna równoległość służy do kontrolowania liczby zadań wykonywanych jednocześnie w jednej iteracji.', + parallelModeEnableDesc: 'W trybie równoległym zadania w iteracjach obsługują wykonywanie równoległe. 
Możesz to skonfigurować w panelu właściwości po prawej stronie.', + answerNodeWarningDesc: 'Ostrzeżenie w trybie równoległym: węzły odpowiedzi, przypisania zmiennych konwersacji i trwałe operacje odczytu/zapisu w iteracjach mogą powodować wyjątki.', + errorResponseMethod: 'Metoda odpowiedzi na błąd', }, note: { editor: { diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index f0f2fec0e2..44afda5cd4 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteração', iteration_other: '{{count}} Iterações', currentIteration: 'Iteração atual', + ErrorMethod: { + continueOnError: 'continuar em erro', + removeAbnormalOutput: 'saída anormal de remoção', + operationTerminated: 'Terminada', + }, + MaxParallelismTitle: 'Paralelismo máximo', + parallelModeEnableTitle: 'Modo paralelo ativado', + errorResponseMethod: 'Método de resposta de erro', + error_other: '{{contagem}} Erros', + parallelMode: 'Modo paralelo', + parallelModeUpper: 'MODO PARALELO', + error_one: '{{contagem}} Erro', + parallelModeEnableDesc: 'No modo paralelo, as tarefas dentro das iterações dão suporte à execução paralela. Você pode configurar isso no painel de propriedades à direita.', + comma: ',', + MaxParallelismDesc: 'O paralelismo máximo é usado para controlar o número de tarefas executadas simultaneamente em uma única iteração.', + answerNodeWarningDesc: 'Aviso de modo paralelo: nós de resposta, atribuições de variáveis de conversação e operações persistentes de leitura/gravação em iterações podem causar exceções.', + parallelPanelDesc: 'No modo paralelo, as tarefas na iteração dão suporte à execução paralela.', }, note: { editor: { diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index ab0100d347..d8cd84f730 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iterație', iteration_other: '{{count}} Iterații', currentIteration: 'Iterație curentă', + ErrorMethod: { + operationTerminated: 'Încheiată', + continueOnError: 'continuare-la-eroare', + removeAbnormalOutput: 'elimină-ieșire-anormală', + }, + parallelModeEnableTitle: 'Modul paralel activat', + errorResponseMethod: 'Metoda de răspuns la eroare', + comma: ',', + parallelModeEnableDesc: 'În modul paralel, sarcinile din iterații acceptă execuția paralelă. 
Puteți configura acest lucru în panoul de proprietăți din dreapta.', + parallelModeUpper: 'MOD PARALEL', + MaxParallelismTitle: 'Paralelism maxim', + parallelMode: 'Mod paralel', + error_other: '{{număr}} Erori', + error_one: '{{număr}} Eroare', + parallelPanelDesc: 'În modul paralel, activitățile din iterație acceptă execuția paralelă.', + MaxParallelismDesc: 'Paralelismul maxim este utilizat pentru a controla numărul de sarcini executate simultan într-o singură iterație.', + answerNodeWarningDesc: 'Avertisment modul paralel: Nodurile de răspuns, atribuirea variabilelor de conversație și operațiunile persistente de citire/scriere în iterații pot cauza excepții.', }, note: { editor: { diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index 27735fbb7d..c822f8c3e5 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Итерация', iteration_other: '{{count}} Итераций', currentIteration: 'Текущая итерация', + ErrorMethod: { + operationTerminated: 'Прекращено', + continueOnError: 'продолжить по ошибке', + removeAbnormalOutput: 'удалить аномальный вывод', + }, + comma: ',', + error_other: '{{Количество}} Ошибки', + errorResponseMethod: 'Метод реагирования на ошибку', + MaxParallelismTitle: 'Максимальный параллелизм', + parallelModeUpper: 'ПАРАЛЛЕЛЬНЫЙ РЕЖИМ', + error_one: '{{Количество}} Ошибка', + parallelModeEnableTitle: 'Параллельный режим включен', + parallelMode: 'Параллельный режим', + parallelPanelDesc: 'В параллельном режиме задачи в итерации поддерживают параллельное выполнение.', + parallelModeEnableDesc: 'В параллельном режиме задачи в итерациях поддерживают параллельное выполнение. Вы можете настроить это на панели свойств справа.', + MaxParallelismDesc: 'Максимальный параллелизм используется для управления количеством задач, выполняемых одновременно в одной итерации.', + answerNodeWarningDesc: 'Предупреждение о параллельном режиме: узлы ответов, присвоение переменных диалога и постоянные операции чтения и записи в итерациях могут вызывать исключения.', }, note: { addNote: 'Добавить заметку', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 82718ebc03..e6e25f6d0e 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -558,6 +558,23 @@ const translation = { iteration_one: '{{count}} Yineleme', iteration_other: '{{count}} Yineleme', currentIteration: 'Mevcut Yineleme', + ErrorMethod: { + operationTerminated: 'Sonlandırıldı', + continueOnError: 'Hata Üzerine Devam Et', + removeAbnormalOutput: 'anormal çıktıyı kaldır', + }, + parallelModeUpper: 'PARALEL MOD', + parallelMode: 'Paralel Mod', + MaxParallelismTitle: 'Maksimum paralellik', + error_one: '{{sayı}} Hata', + errorResponseMethod: 'Hata yanıtı yöntemi', + comma: ',', + parallelModeEnableTitle: 'Paralel Mod Etkin', + error_other: '{{sayı}} Hata', + parallelPanelDesc: 'Paralel modda, yinelemedeki görevler paralel yürütmeyi destekler.', + answerNodeWarningDesc: 'Paralel mod uyarısı: Yinelemeler içindeki yanıt düğümleri, konuşma değişkeni atamaları ve kalıcı okuma/yazma işlemleri özel durumlara neden olabilir.', + parallelModeEnableDesc: 'Paralel modda, yinelemeler içindeki görevler paralel yürütmeyi destekler. 
Bunu sağdaki özellikler panelinde yapılandırabilirsiniz.', + MaxParallelismDesc: 'Maksimum paralellik, tek bir yinelemede aynı anda yürütülen görevlerin sayısını kontrol etmek için kullanılır.', }, note: { addNote: 'Not Ekle', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index 1828b6499f..663b5e4c13 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Ітерація', iteration_other: '{{count}} Ітерацій', currentIteration: 'Поточна ітерація', + ErrorMethod: { + operationTerminated: 'Припинено', + continueOnError: 'Продовжити після помилки', + removeAbnormalOutput: 'видалити-ненормальний-вивід', + }, + error_one: '{{count}} Помилка', + comma: ',', + MaxParallelismTitle: 'Максимальна паралельність', + parallelModeUpper: 'ПАРАЛЕЛЬНИЙ РЕЖИМ', + error_other: '{{count}} Помилки', + parallelMode: 'Паралельний режим', + parallelModeEnableTitle: 'Увімкнено паралельний режим', + errorResponseMethod: 'Метод реагування на помилку', + parallelPanelDesc: 'У паралельному режимі завдання в ітерації підтримують паралельне виконання.', + parallelModeEnableDesc: 'У паралельному режимі завдання всередині ітерацій підтримують паралельне виконання. Ви можете налаштувати це на панелі властивостей праворуч.', + MaxParallelismDesc: 'Максимальний паралелізм використовується для контролю числа завдань, що виконуються одночасно за одну ітерацію.', + answerNodeWarningDesc: 'Попередження в паралельному режимі: вузли відповідей, призначення змінних розмови та постійні операції читання/запису в межах ітерацій можуть спричинити винятки.', }, note: { editor: { diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 2866af8a2a..1176fdd2b5 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Lặp', iteration_other: '{{count}} Lặp', currentIteration: 'Lặp hiện tại', + ErrorMethod: { + operationTerminated: 'Chấm dứt', + removeAbnormalOutput: 'loại bỏ-bất thường-đầu ra', + continueOnError: 'Tiếp tục lỗi', + }, + comma: ',', + error_other: '{{đếm}} Lỗi', + error_one: '{{đếm}} Lỗi', + MaxParallelismTitle: 'Song song tối đa', + parallelPanelDesc: 'Ở chế độ song song, các tác vụ trong quá trình lặp hỗ trợ thực thi song song.', + parallelMode: 'Chế độ song song', + parallelModeEnableTitle: 'Đã bật Chế độ song song', + errorResponseMethod: 'Phương pháp phản hồi lỗi', + MaxParallelismDesc: 'Tính song song tối đa được sử dụng để kiểm soát số lượng tác vụ được thực hiện đồng thời trong một lần lặp.', + answerNodeWarningDesc: 'Cảnh báo chế độ song song: Các nút trả lời, bài tập biến hội thoại và các thao tác đọc/ghi liên tục trong các lần lặp có thể gây ra ngoại lệ.', + parallelModeEnableDesc: 'Trong chế độ song song, các tác vụ trong các lần lặp hỗ trợ thực thi song song. 
Bạn có thể định cấu hình điều này trong bảng thuộc tính ở bên phải.', + parallelModeUpper: 'CHẾ ĐỘ SONG SONG', }, note: { editor: { diff --git a/web/i18n/zh-Hans/app-debug.ts b/web/i18n/zh-Hans/app-debug.ts index 3e801bcf62..9e21945755 100644 --- a/web/i18n/zh-Hans/app-debug.ts +++ b/web/i18n/zh-Hans/app-debug.ts @@ -224,6 +224,8 @@ const translation = { description: '代码生成器使用配置的模型根据您的指令生成高质量的代码。请提供清晰详细的说明。', instruction: '指令', instructionPlaceholder: '请输入您想要生成的代码的详细描述。', + noDataLine1: '在左侧描述您的用例,', + noDataLine2: '代码预览将在此处显示。', generate: '生成', generatedCodeTitle: '生成的代码', loading: '正在生成代码...', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 519f25d34e..6d574bc6f5 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -556,6 +556,23 @@ const translation = { iteration_one: '{{count}}个迭代', iteration_other: '{{count}}个迭代', currentIteration: '当前迭代', + comma: ',', + error_one: '{{count}}个失败', + error_other: '{{count}}个失败', + parallelMode: '并行模式', + parallelModeUpper: '并行模式', + parallelModeEnableTitle: '并行模式启用', + parallelModeEnableDesc: '启用并行模式时迭代内的任务支持并行执行。你可以在右侧的属性面板中进行配置。', + parallelPanelDesc: '在并行模式下,迭代中的任务支持并行执行。', + MaxParallelismTitle: '最大并行度', + MaxParallelismDesc: '最大并行度用于控制单次迭代中同时执行的任务数量。', + errorResponseMethod: '错误响应方法', + ErrorMethod: { + operationTerminated: '错误时终止', + continueOnError: '忽略错误并继续', + removeAbnormalOutput: '移除错误输出', + }, + answerNodeWarningDesc: '并行模式警告:在迭代中,回答节点、会话变量赋值和工具持久读/写操作可能会导致异常。', }, note: { addNote: '添加注释', diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index d65b3999d2..f3fbfdedc2 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}}個迭代', iteration_other: '{{count}}個迭代', currentIteration: '當前迭代', + ErrorMethod: { + operationTerminated: '終止', + removeAbnormalOutput: 'remove-abnormal-output', + continueOnError: '出錯時繼續', + }, + comma: ',', + parallelMode: '並行模式', + parallelModeEnableTitle: 'Parallel Mode 已啟用', + MaxParallelismTitle: '最大並行度', + parallelModeUpper: '並行模式', + parallelPanelDesc: '在並行模式下,反覆運算中的任務支援並行執行。', + error_one: '{{count}}錯誤', + errorResponseMethod: '錯誤回應方法', + parallelModeEnableDesc: '在並行模式下,反覆運算中的任務支援並行執行。您可以在右側的 properties 面板中進行配置。', + answerNodeWarningDesc: '並行模式警告:反覆運算中的應答節點、對話變數賦值和持久讀/寫操作可能會導致異常。', + error_other: '{{count}}錯誤', + MaxParallelismDesc: '最大並行度用於控制在單個反覆運算中同時執行的任務數。', }, note: { editor: { diff --git a/web/models/common.ts b/web/models/common.ts index 9ab27a6018..dc2b1120b9 100644 --- a/web/models/common.ts +++ b/web/models/common.ts @@ -216,7 +216,7 @@ export type FileUploadConfigResponse = { file_size_limit: number // default is 15MB audio_file_size_limit?: number // default is 50MB video_file_size_limit?: number // default is 100MB - + workflow_file_upload_limit?: number // default is 10 } export type InvitationResult = { diff --git a/web/service/base.ts b/web/service/base.ts index 8efb97cff3..6e7d90fd3a 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -1,4 +1,5 @@ import { API_PREFIX, IS_CE_EDITION, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config' +import { refreshAccessTokenOrRelogin } from './refresh-token' import Toast from '@/app/components/base/toast' import type { AnnotationReply, MessageEnd, MessageReplace, ThoughtItem } from '@/app/components/base/chat/chat/type' import type { VisionFile } from '@/types/app' @@ -499,7 +500,9 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search export const 
ssePost = ( url: string, fetchOptions: FetchOptionType, - { + otherOptions: IOtherOptions, +) => { + const { isPublicAPI = false, onData, onCompleted, @@ -522,8 +525,7 @@ export const ssePost = ( onTextReplace, onError, getAbortController, - }: IOtherOptions, -) => { + } = otherOptions const abortController = new AbortController() const options = Object.assign({}, baseOptions, { @@ -547,21 +549,29 @@ export const ssePost = ( globalThis.fetch(urlWithPrefix, options as RequestInit) .then((res) => { if (!/^(2|3)\d{2}$/.test(String(res.status))) { - res.json().then((data: any) => { - if (isPublicAPI) { - if (data.code === 'web_sso_auth_required') - requiredWebSSOLogin() + if (res.status === 401) { + refreshAccessTokenOrRelogin(TIME_OUT).then(() => { + ssePost(url, fetchOptions, otherOptions) + }).catch(() => { + res.json().then((data: any) => { + if (isPublicAPI) { + if (data.code === 'web_sso_auth_required') + requiredWebSSOLogin() - if (data.code === 'unauthorized') { - removeAccessToken() - globalThis.location.reload() - } - if (res.status === 401) - return - } - Toast.notify({ type: 'error', message: data.message || 'Server Error' }) - }) - onError?.('Server Error') + if (data.code === 'unauthorized') { + removeAccessToken() + globalThis.location.reload() + } + } + }) + }) + } + else { + res.json().then((data) => { + Toast.notify({ type: 'error', message: data.message || 'Server Error' }) + }) + onError?.('Server Error') + } return } return handleStream(res, (str: string, isFirstMessage: boolean, moreInfo: IOnDataMoreInfo) => { @@ -584,7 +594,54 @@ export const ssePost = ( // base request export const request = (url: string, options = {}, otherOptions?: IOtherOptions) => { - return baseFetch(url, options, otherOptions || {}) + return new Promise((resolve, reject) => { + const otherOptionsForBaseFetch = otherOptions || {} + baseFetch(url, options, otherOptionsForBaseFetch).then(resolve).catch((errResp) => { + if (errResp?.status === 401) { + return refreshAccessTokenOrRelogin(TIME_OUT).then(() => { + baseFetch(url, options, otherOptionsForBaseFetch).then(resolve).catch(reject) + }).catch(() => { + const { + isPublicAPI = false, + silent, + } = otherOptionsForBaseFetch + const bodyJson = errResp.json() + if (isPublicAPI) { + return bodyJson.then((data: ResponseError) => { + if (data.code === 'web_sso_auth_required') + requiredWebSSOLogin() + + if (data.code === 'unauthorized') { + removeAccessToken() + globalThis.location.reload() + } + + return Promise.reject(data) + }) + } + const loginUrl = `${globalThis.location.origin}/signin` + bodyJson.then((data: ResponseError) => { + if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent) + Toast.notify({ type: 'error', message: data.message, duration: 4000 }) + else if (data.code === 'not_init_validated' && IS_CE_EDITION) + globalThis.location.href = `${globalThis.location.origin}/init` + else if (data.code === 'not_setup' && IS_CE_EDITION) + globalThis.location.href = `${globalThis.location.origin}/install` + else if (location.pathname !== '/signin' || !IS_CE_EDITION) + globalThis.location.href = loginUrl + else if (!silent) + Toast.notify({ type: 'error', message: data.message }) + }).catch(() => { + // Handle any other errors + globalThis.location.href = loginUrl + }) + }) + } + else { + reject(errResp) + } + }) + }) } // request methods diff --git a/web/service/common.ts b/web/service/common.ts index 9acbd75940..01b3a60991 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -320,8 +320,8 @@ export const 
diff --git a/web/service/common.ts b/web/service/common.ts
index 9acbd75940..01b3a60991 100644
--- a/web/service/common.ts
+++ b/web/service/common.ts
@@ -320,8 +320,8 @@
 export const verifyForgotPasswordToken: Fetcher = ({ url, body }) =>
   post(url, { body })
 
-export const uploadRemoteFileInfo = (url: string) => {
-  return post<{ id: string; name: string; size: number; mime_type: string; url: string }>('/remote-files/upload', { body: { url } })
+export const uploadRemoteFileInfo = (url: string, isPublic?: boolean) => {
+  return post<{ id: string; name: string; size: number; mime_type: string; url: string }>('/remote-files/upload', { body: { url } }, { isPublicAPI: isPublic })
 }
 
 export const sendEMailLoginCode = (email: string, language = 'en-US') =>
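`uploadRemoteFileInfo` now threads an `isPublic` flag through to the request layer. A usage sketch from a shared web-app context, assuming the flag routes the call through the public API with the web app's credentials rather than the console's; the URL is a placeholder:

import { uploadRemoteFileInfo } from '@/service/common'

// Pass `true` from a public (shared web app) page so the upload request
// is issued against the public API.
uploadRemoteFileInfo('https://example.com/report.pdf', true)
  .then(file => console.log(file.id, file.name, file.mime_type))
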
diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts
new file mode 100644
index 0000000000..8bd2215041
--- /dev/null
+++ b/web/service/refresh-token.ts
@@ -0,0 +1,75 @@
+import { apiPrefix } from '@/config'
+import { fetchWithRetry } from '@/utils'
+
+let isRefreshing = false
+function waitUntilTokenRefreshed() {
+  return new Promise<void>((resolve, reject) => {
+    function _check() {
+      const isRefreshingSign = localStorage.getItem('is_refreshing')
+      if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) {
+        setTimeout(() => {
+          _check()
+        }, 1000)
+      }
+      else {
+        resolve()
+      }
+    }
+    _check()
+  })
+}
+
+// Only one refresh request may be in flight at a time.
+async function getNewAccessToken(): Promise<void> {
+  try {
+    const isRefreshingSign = localStorage.getItem('is_refreshing')
+    if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) {
+      await waitUntilTokenRefreshed()
+    }
+    else {
+      globalThis.localStorage.setItem('is_refreshing', '1')
+      isRefreshing = true
+      const refresh_token = globalThis.localStorage.getItem('refresh_token')
+
+      // Do not use baseFetch to refresh tokens.
+      // If a 401 response occurs and baseFetch itself attempts to refresh the token,
+      // it can lead to an infinite loop if the refresh attempt also returns 401.
+      // To avoid this, handle token refresh separately in a dedicated function
+      // that does not call baseFetch and uses a single retry mechanism.
+      const [error, ret] = await fetchWithRetry(globalThis.fetch(`${apiPrefix}/refresh-token`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json; charset=utf-8',
+        },
+        body: JSON.stringify({ refresh_token }),
+      }))
+      if (error) {
+        return Promise.reject(error)
+      }
+      else {
+        if (ret.status === 401)
+          return Promise.reject(ret)
+
+        const { data } = await ret.json()
+        globalThis.localStorage.setItem('console_token', data.access_token)
+        globalThis.localStorage.setItem('refresh_token', data.refresh_token)
+      }
+    }
+  }
+  catch (error) {
+    console.error(error)
+    return Promise.reject(error)
+  }
+  finally {
+    isRefreshing = false
+    globalThis.localStorage.removeItem('is_refreshing')
+  }
+}
+
+export async function refreshAccessTokenOrRelogin(timeout: number) {
+  return Promise.race([new Promise<void>((resolve, reject) => setTimeout(() => {
+    isRefreshing = false
+    globalThis.localStorage.removeItem('is_refreshing')
+    reject(new Error('request timeout'))
+  }, timeout)), getNewAccessToken()])
+}
diff --git a/web/types/workflow.ts b/web/types/workflow.ts
index 810026b084..3c0675b605 100644
--- a/web/types/workflow.ts
+++ b/web/types/workflow.ts
@@ -19,6 +19,7 @@ export type NodeTracing = {
   process_data: any
   outputs?: any
   status: string
+  parallel_run_id?: string
   error?: string
   elapsed_time: number
   execution_metadata: {
@@ -31,6 +32,7 @@ export type NodeTracing = {
     parallel_start_node_id?: string
     parent_parallel_id?: string
     parent_parallel_start_node_id?: string
+    parallel_mode_run_id?: string
   }
   metadata: {
     iterator_length: number
@@ -121,6 +123,7 @@ export type NodeStartedResponse = {
     id: string
     node_id: string
     iteration_id?: string
+    parallel_run_id?: string
     node_type: string
     index: number
     predecessor_node_id?: string
@@ -166,6 +169,7 @@ export type NodeFinishedResponse = {
     parallel_start_node_id?: string
     iteration_index?: number
     iteration_id?: string
+    parallel_mode_run_id: string
   }
   created_at: number
   files?: FileResponse[]
@@ -200,6 +204,7 @@ export type IterationNextResponse = {
   output: any
   extras?: any
   created_at: number
+  parallel_mode_run_id: string
   execution_metadata: {
     parallel_id?: string
   }
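
The `parallel_mode_run_id` fields above let the UI distinguish tracing entries emitted by different parallel iteration runs. A minimal grouping sketch; the helper and its 'sequential' fallback key are hypothetical, not part of this changeset:

import type { NodeTracing } from '@/types/workflow'

// Hypothetical helper: bucket iteration traces by the parallel run that
// produced them so each parallel-mode run can be rendered as its own lane;
// entries without a parallel_mode_run_id fall into a 'sequential' bucket.
function groupByParallelRun(traces: NodeTracing[]): Map<string, NodeTracing[]> {
  const groups = new Map<string, NodeTracing[]>()
  for (const trace of traces) {
    const key = trace.execution_metadata?.parallel_mode_run_id ?? 'sequential'
    const bucket = groups.get(key) ?? []
    bucket.push(trace)
    groups.set(key, bucket)
  }
  return groups
}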