diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index f08befefb8..76e5c04deb 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -83,9 +83,15 @@ jobs: compose-file: | docker/docker-compose.middleware.yaml services: | + db + redis sandbox ssrf_proxy + - name: setup test config + run: | + cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env + - name: Run Workflow run: uv run --project api bash dev/pytest/pytest_workflow.sh diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 7d0a873ebd..912267094b 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -84,6 +84,12 @@ jobs: elasticsearch oceanbase + - name: setup test config + run: | + echo $(pwd) + ls -lah . + cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env + - name: Check VDB Ready (TiDB) run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py diff --git a/README.md b/README.md index ec399e49ee..1dc7e2dd98 100644 --- a/README.md +++ b/README.md @@ -230,6 +230,10 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Using Alibaba Cloud Data Management + +One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Contributing diff --git a/README_AR.md b/README_AR.md index 5214da4894..d93bca8646 100644 --- a/README_AR.md +++ b/README_AR.md @@ -211,6 +211,11 @@ docker compose up -d #### استخدام Alibaba Cloud للنشر [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### استخدام Alibaba Cloud Data Management للنشر + +انشر ​​Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## المساهمة diff --git a/README_BN.md b/README_BN.md index 1911f186d7..3efee3684d 100644 --- a/README_BN.md +++ b/README_BN.md @@ -229,6 +229,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয় + + [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Contributing diff --git a/README_CN.md b/README_CN.md index a194b01937..21e27429ec 100644 --- a/README_CN.md +++ b/README_CN.md @@ -225,6 +225,10 @@ docker compose up -d 使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云 +#### 使用 阿里云数据管理DMS 部署 + +使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云 + ## Star History diff --git a/README_DE.md b/README_DE.md index fd550a5b96..20c313035e 100644 --- a/README_DE.md +++ b/README_DE.md @@ -225,6 +225,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 
+#### Alibaba Cloud Data Management + +Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Contributing diff --git a/README_ES.md b/README_ES.md index 38dea09be1..e4b7df6686 100644 --- a/README_ES.md +++ b/README_ES.md @@ -225,6 +225,11 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contribuir Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_FR.md b/README_FR.md index 925918e47e..8fd17fb7c3 100644 --- a/README_FR.md +++ b/README_FR.md @@ -223,6 +223,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Contribuer diff --git a/README_JA.md b/README_JA.md index 3f8a5b859d..a3ee81e1f2 100644 --- a/README_JA.md +++ b/README_JA.md @@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ [こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。 - **Dify Community Editionのセルフホスティング
** -この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。 +この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。 詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 - **企業/組織向けのDify
** @@ -223,6 +223,9 @@ docker compose up -d #### Alibaba Cloud [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます + ## 貢献 diff --git a/README_KL.md b/README_KL.md index 9e562a4d73..3e5ab1a74f 100644 --- a/README_KL.md +++ b/README_KL.md @@ -223,6 +223,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Contributing diff --git a/README_KR.md b/README_KR.md index 683b3a86f4..3c504900e1 100644 --- a/README_KR.md +++ b/README_KR.md @@ -217,6 +217,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다 + ## 기여 diff --git a/README_PT.md b/README_PT.md index b81127b70b..fb5f3662ae 100644 --- a/README_PT.md +++ b/README_PT.md @@ -222,6 +222,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Contribuindo diff --git a/README_SI.md b/README_SI.md index 7034233233..647069a220 100644 --- a/README_SI.md +++ b/README_SI.md @@ -223,6 +223,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Prispevam diff --git a/README_TR.md b/README_TR.md index 51156933d4..f52335646a 100644 --- a/README_TR.md +++ b/README_TR.md @@ -216,6 +216,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın + ## Katkıda Bulunma diff --git a/README_TW.md b/README_TW.md index 291da28825..71082ff893 100644 --- a/README_TW.md +++ b/README_TW.md @@ -228,6 +228,10 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify [阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### 使用 阿里雲數據管理DMS 進行部署 + +透過 
[阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲 + ## 貢獻 diff --git a/README_VI.md b/README_VI.md index 51a2e9e9e6..58d8434fff 100644 --- a/README_VI.md +++ b/README_VI.md @@ -219,6 +219,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +#### Alibaba Cloud Data Management + +Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + ## Đóng góp diff --git a/api/.ruff.toml b/api/.ruff.toml index facb0d5419..0169613bf8 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,6 +1,4 @@ -exclude = [ - "migrations/*", -] +exclude = ["migrations/*"] line-length = 120 [format] @@ -9,14 +7,14 @@ quote-style = "double" [lint] preview = false select = [ - "B", # flake8-bugbear rules - "C4", # flake8-comprehensions - "E", # pycodestyle E rules - "F", # pyflakes rules - "FURB", # refurb rules - "I", # isort rules - "N", # pep8-naming - "PT", # flake8-pytest-style rules + "B", # flake8-bugbear rules + "C4", # flake8-comprehensions + "E", # pycodestyle E rules + "F", # pyflakes rules + "FURB", # refurb rules + "I", # isort rules + "N", # pep8-naming + "PT", # flake8-pytest-style rules "PLC0208", # iteration-over-set "PLC0414", # useless-import-alias "PLE0604", # invalid-all-object @@ -24,19 +22,19 @@ select = [ "PLR0402", # manual-from-import "PLR1711", # useless-return "PLR1714", # repeated-equality-comparison - "RUF013", # implicit-optional - "RUF019", # unnecessary-key-check - "RUF100", # unused-noqa - "RUF101", # redirected-noqa - "RUF200", # invalid-pyproject-toml - "RUF022", # unsorted-dunder-all - "S506", # unsafe-yaml-load - "SIM", # flake8-simplify rules - "TRY400", # error-instead-of-exception - "TRY401", # verbose-log-message - "UP", # pyupgrade rules - "W191", # tab-indentation - "W605", # invalid-escape-sequence + "RUF013", # implicit-optional + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "RUF200", # invalid-pyproject-toml + "RUF022", # unsorted-dunder-all + "S506", # unsafe-yaml-load + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "TRY401", # verbose-log-message + "UP", # pyupgrade rules + "W191", # tab-indentation + "W605", # invalid-escape-sequence # security related linting rules # RCE proctection (sort of) "S102", # exec-builtin, disallow use of `exec` @@ -47,36 +45,37 @@ select = [ ] ignore = [ - "E402", # module-import-not-at-top-of-file - "E711", # none-comparison - "E712", # true-false-comparison - "E721", # type-comparison - "E722", # bare-except - "F821", # undefined-name - "F841", # unused-variable + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "F821", # undefined-name + "F841", # unused-variable "FURB113", # repeated-append "FURB152", # math-constant - "UP007", # non-pep604-annotation - "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B026", # star-arg-unpacking-after-keyword-arg - "B903", # class-as-data-structure - "B904", # raise-without-from-inside-except - "B905", # zip-without-explicit-strict - "N806", # 
non-lowercase-variable-in-function - "N815", # mixed-case-variable-in-class-scope - "PT011", # pytest-raises-too-broad - "SIM102", # collapsible-if - "SIM103", # needless-bool - "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally - "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # enumerate-for-loop - "SIM117", # multiple-with-statements - "SIM210", # if-expr-with-true-false + "UP007", # non-pep604-annotation + "UP032", # f-string + "UP045", # non-pep604-annotation-optional + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B903", # class-as-data-structure + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 60ba272ec9..427602676f 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -223,6 +223,10 @@ class CeleryConfig(DatabaseConfig): default=None, ) + CELERY_SENTINEL_PASSWORD: Optional[str] = Field( + description="Password of the Redis Sentinel master.", + default=None, + ) CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( description="Timeout for Redis Sentinel socket operations in seconds.", default=0.1, diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 0107df22c5..dddf71c094 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="1.4.3", + default="1.5.0", ) COMMIT_SHA: str = Field( diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a974c63e35..dbdcdc46ce 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -63,6 +63,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_draft_variable, workflow_run, workflow_statistic, ) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index cbbdd324ba..a9f088a276 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Sequence from typing import cast from flask import abort, request @@ -18,10 +19,12 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File from 
extensions.ext_database import db -from factories import variable_factory +from factories import file_factory, variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper @@ -30,6 +33,7 @@ from libs.login import current_user, login_required from models import App from models.account import Account from models.model import AppMode +from models.workflow import Workflow from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError @@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE logger = logging.getLogger(__name__) +# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing +# at the controller level rather than in the workflow logic. This would improve separation +# of concerns and make the code more maintainable. +def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]: + files = files or [] + + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + file_objs: Sequence[File] = [] + if file_extra_config is None: + return file_objs + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=workflow.tenant_id, + config=file_extra_config, + ) + return file_objs + + class DraftWorkflowApi(Resource): @setup_required @login_required @@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("query", type=str, required=False, location="json", default="") + parser.add_argument("files", type=list, location="json", default=[]) args = parser.parse_args() - inputs = args.get("inputs") - if inputs == None: + user_inputs = args.get("inputs") + if user_inputs is None: raise ValueError("missing inputs") + workflow_srv = WorkflowService() + # fetch draft workflow by app_model + draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + files = _parse_file(draft_workflow, args.get("files")) workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user + app_model=app_model, + draft_workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + account=current_user, + query=args.get("query", ""), + files=files, ) return workflow_node_execution @@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource): return None, 204 +class DraftWorkflowNodeLastRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_fields) + def get(self, app_model: App, node_id: str): + srv = WorkflowService() + workflow = srv.get_draft_workflow(app_model) + if not workflow: + raise NotFound("Workflow not found") + node_exec = srv.get_node_last_run( + app_model=app_model, + workflow=workflow, + node_id=node_id, + ) + if node_exec is None: + raise NotFound("last run not found") + return node_exec + + api.add_resource( DraftWorkflowApi, "/apps//workflows/draft", @@ -795,3 +853,7 @@ api.add_resource( WorkflowByIdApi, "/apps//workflows/", ) 
+api.add_resource( + DraftWorkflowNodeLastRunApi, + "/apps//workflows/draft/nodes//last-run", +) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py new file mode 100644 index 0000000000..00d6fa3cbf --- /dev/null +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -0,0 +1,421 @@ +import logging +from typing import Any, NoReturn + +from flask import Response +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.app.error import ( + DraftWorkflowNotExist, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.variables.segment_group import SegmentGroup +from core.variables.segments import ArrayFileSegment, FileSegment, Segment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.file_factory import build_from_mapping, build_from_mappings +from factories.variable_factory import build_segment_with_type +from libs.login import current_user, login_required +from models import App, AppMode, db +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + + +def _convert_values_to_json_serializable_object(value: Segment) -> Any: + if isinstance(value, FileSegment): + return value.value.model_dump() + elif isinstance(value, ArrayFileSegment): + return [i.model_dump() for i in value.value] + elif isinstance(value, SegmentGroup): + return [_convert_values_to_json_serializable_object(i) for i in value.value] + else: + return value.value + + +def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: + value = variable.get_value() + # create a copy of the value to avoid affecting the model cache. + value = value.model_copy(deep=True) + # Refresh the url signature before returning it to client. 
+ if isinstance(value, FileSegment): + file = value.value + file.remote_url = file.generate_url() + elif isinstance(value, ArrayFileSegment): + files = value.value + for file in files: + file.remote_url = file.generate_url() + return _convert_values_to_json_serializable_object(value) + + +def _create_pagination_parser(): + parser = reqparse.RequestParser() + parser.add_argument( + "page", + type=inputs.int_range(1, 100_000), + required=False, + default=1, + location="args", + help="the page of data requested", + ) + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + return parser + + +_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda model: model.get_variable_type()), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, + value=fields.Raw(attribute=_serialize_var_value), +) + +_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda _: "env"), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), +} + + +def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: + return var_list.variables + + +_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), + "total": fields.Raw(), +} + +_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), +} + + +def _api_prerequisite(f): + """Common prerequisites for all draft workflow variable APIs. + + It ensures the following conditions are satisfied: + + - Dify has been properly set up. + - The request user has logged in and initialized. + - The requested app is a workflow or a chat flow. + - The request user has the edit permission for the app. 
+ """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def wrapper(*args, **kwargs): + if not current_user.is_editor: + raise Forbidden() + return f(*args, **kwargs) + + return wrapper + + +class WorkflowVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) + def get(self, app_model: App): + """ + Get draft workflow + """ + parser = _create_pagination_parser() + args = parser.parse_args() + + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow_exist = workflow_service.is_workflow_exist(app_model=app_model) + if not workflow_exist: + raise DraftWorkflowNotExist() + + # fetch draft workflow by app_model + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=app_model.id, + page=args.page, + limit=args.limit, + ) + + return workflow_vars + + @_api_prerequisite + def delete(self, app_model: App): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + draft_var_srv.delete_workflow_variables(app_model.id) + db.session.commit() + return Response("", 204) + + +def validate_node_id(node_id: str) -> NoReturn | None: + if node_id in [ + CONVERSATION_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, + ]: + # NOTE(QuantumGhost): While we store the system and conversation variables as node variables + # with specific `node_id` in database, we still want to make the API separated. By disallowing + # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`, + # we mitigate the risk that user of the API depending on the implementation detail of the API. 
+ # + # ref: [Hyrum's Law](https://www.hyrumslaw.com/) + + raise InvalidArgumentError( + f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", + ) + return None + + +class NodeVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App, node_id: str): + validate_node_id(node_id) + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + node_vars = draft_var_srv.list_node_variables(app_model.id, node_id) + + return node_vars + + @_api_prerequisite + def delete(self, app_model: App, node_id: str): + validate_node_id(node_id) + srv = WorkflowDraftVariableService(db.session()) + srv.delete_node_variables(app_model.id, node_id) + db.session.commit() + return Response("", 204) + + +class VariableApi(Resource): + _PATCH_NAME_FIELD = "name" + _PATCH_VALUE_FIELD = "value" + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def get(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + return variable + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def patch(self, app_model: App, variable_id: str): + # Request payload for file types: + # + # Local File: + # + # { + # "type": "image", + # "transfer_method": "local_file", + # "url": "", + # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190" + # } + # + # Remote File: + # + # + # { + # "type": "image", + # "transfer_method": "remote_url", + # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", + # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" + # } + + parser = reqparse.RequestParser() + parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") + # Parse 'value' field as-is to maintain its original data structure + parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") + + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + args = parser.parse_args(strict=True) + + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + new_name = args.get(self._PATCH_NAME_FIELD, None) + raw_value = args.get(self._PATCH_VALUE_FIELD, None) + if new_name is None and raw_value is None: + return variable + + new_value = None + if raw_value is not None: + if variable.value_type == SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + elif variable.value_type == SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + 
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + new_value = build_segment_with_type(variable.value_type, raw_value) + draft_var_srv.update_variable(variable, name=new_name, value=new_value) + db.session.commit() + return variable + + @_api_prerequisite + def delete(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + draft_var_srv.delete_variable(variable) + db.session.commit() + return Response("", 204) + + +class VariableResetApi(Resource): + @_api_prerequisite + def put(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + + workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError( + f"Draft workflow not found, app_id={app_model.id}", + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + resetted = draft_var_srv.reset_variable(draft_workflow, variable) + db.session.commit() + if resetted is None: + return Response("", 204) + else: + return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) + + +def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + if node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_conversation_variables(app_model.id) + elif node_id == SYSTEM_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_system_variables(app_model.id) + else: + draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id) + return draft_vars + + +class ConversationVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App): + # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table + # so their IDs can be returned to the caller. 
+ workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError(description=f"draft workflow not found, id={app_model.id}") + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + db.session.commit() + return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + + +class SystemVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App): + return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) + + +class EnvironmentVariableCollectionApi(Resource): + @_api_prerequisite + def get(self, app_model: App): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + if workflow is None: + raise DraftWorkflowNotExist() + + env_vars = workflow.environment_variables + env_vars_list = [] + for v in env_vars: + env_vars_list.append( + { + "id": v.id, + "type": "env", + "name": v.name, + "description": v.description, + "selector": v.selector, + "value_type": v.value_type.value, + "value": v.value, + # Do not track edited for env vars. + "edited": False, + "visible": True, + "editable": True, + } + ) + + return {"items": env_vars_list} + + +api.add_resource( + WorkflowVariableCollectionApi, + "/apps//workflows/draft/variables", +) +api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") +api.add_resource(VariableApi, "/apps//workflows/draft/variables/") +api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") + +api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") +api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") +api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 9ad8c15847..03b60610aa 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,6 +8,15 @@ from libs.login import current_user from models import App, AppMode +def _load_app_model(app_id: str) -> Optional[App]: + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) + return app_model + + def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func): @wraps(view_func) @@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ del kwargs["app_id"] - app_model = ( - db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") - .first() - ) + app_model = _load_app_model(app_id) if not app_model: raise AppNotFoundError() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 7ac60a0dc2..b2fcf3ce7b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,7 +5,7 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from 
flask_restful import Resource, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -239,12 +239,10 @@ class DatasetDocumentListApi(Resource): return response - documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} - @setup_required @login_required @account_initialization_required - @marshal_with(documents_and_batch_fields) + @marshal_with(dataset_and_document_fields) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): @@ -290,6 +288,8 @@ class DatasetDocumentListApi(Resource): try: documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) + dataset = DatasetService.get_dataset(dataset_id) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -297,7 +297,7 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"documents": documents, "batch": batch} + return {"dataset": dataset, "documents": documents, "batch": batch} @setup_required @login_required diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index db49da7840..48225ac90d 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource): return { "result": "success", "invitation_results": invitation_results, + "tenant_id": str(current_user.current_tenant.id), }, 201 @@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource): except Exception as e: raise ValueError(str(e)) - return {"result": "success"}, 204 + return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 class MemberUpdateRoleApi(Resource): diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 9bddbb4b4b..c0a4734828 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -13,6 +13,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError from libs.login import login_required from models.account import TenantPluginPermission +from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService @@ -497,6 +498,42 @@ class PluginFetchPermissionApi(Resource): ) +class PluginFetchDynamicSelectOptionsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + # check if the user is admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + + tenant_id = current_user.current_tenant_id + user_id = current_user.id + + parser = reqparse.RequestParser() + parser.add_argument("plugin_id", type=str, required=True, location="args") + parser.add_argument("provider", type=str, required=True, location="args") + parser.add_argument("action", type=str, required=True, location="args") + parser.add_argument("parameter", type=str, required=True, location="args") + parser.add_argument("provider_type", type=str, required=True, location="args") + args = parser.parse_args() + + try: + options = 
PluginParameterService.get_dynamic_select_options( + tenant_id, + user_id, + args["plugin_id"], + args["provider"], + args["action"], + args["parameter"], + args["provider_type"], + ) + except PluginDaemonClientSideError as e: + raise ValueError(e) + + return jsonable_encoder({"options": options}) + + api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") api.add_resource(PluginListApi, "/workspaces/current/plugin/list") api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions") @@ -521,3 +558,5 @@ api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marke api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") + +api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 41063b35a5..327e9ce834 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -17,6 +17,7 @@ from core.plugin.entities.request import ( RequestInvokeApp, RequestInvokeEncrypt, RequestInvokeLLM, + RequestInvokeLLMWithStructuredOutput, RequestInvokeModeration, RequestInvokeParameterExtractorNode, RequestInvokeQuestionClassifierNode, @@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource): return length_prefixed_response(0xF, generator()) +class PluginInvokeLLMWithStructuredOutputApi(Resource): + @setup_required + @plugin_inner_api_only + @get_user_tenant + @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput): + def generator(): + response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output( + user_model.id, tenant_model, payload + ) + return PluginModelBackwardsInvocation.convert_to_event_stream(response) + + return length_prefixed_response(0xF, generator()) + + class PluginInvokeTextEmbeddingApi(Resource): @setup_required @plugin_inner_api_only @@ -291,6 +307,7 @@ class PluginFetchAppInfoApi(Resource): api.add_resource(PluginInvokeLLMApi, "/invoke/llm") +api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") api.add_resource(PluginInvokeTTSApi, "/invoke/tts") diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index a2fc2d4675..77568b75f1 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource): tenant_was_created.send(tenant) - return {"message": "enterprise workspace created."} + resp = { + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None, + "updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None, + } + + return { + "message": "enterprise workspace created.", + "tenant": resp, + } class EnterpriseWorkspaceNoOwnerEmail(Resource): diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 839afdb9fd..a499719fc3 100644 --- 
a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -133,6 +133,22 @@ class DatasetListApi(DatasetApiResource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() + + if args.get("embedding_model_provider"): + DatasetService.check_embedding_model_setting( + tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") + ) + if ( + args.get("retrieval_model") + and args.get("retrieval_model").get("reranking_model") + and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) + try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, @@ -265,10 +281,20 @@ class DatasetApi(DatasetApiResource): data = request.get_json() # check embedding model setting - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") ) + if ( + data.get("retrieval_model") + and data.get("retrieval_model").get("reranking_model") + and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + dataset.tenant_id, + data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index e4779f3bdf..6213fad173 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -29,7 +29,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment -from services.dataset_service import DocumentService +from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -59,6 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() + dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -74,6 +75,21 @@ class DocumentAddByTextApi(DatasetApiResource): if text is None or name is None: raise ValueError("Both 'text' and 'name' must be non-null values.") + if args.get("embedding_model_provider"): + DatasetService.check_embedding_model_setting( + tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") + ) + if ( + args.get("retrieval_model") + and args.get("retrieval_model").get("reranking_model") + and 
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", @@ -124,6 +140,17 @@ class DocumentUpdateByTextApi(DatasetApiResource): if not dataset: raise ValueError("Dataset does not exist.") + if ( + args.get("retrieval_model") + and args.get("retrieval_model").get("reranking_model") + and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) + # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique @@ -188,6 +215,21 @@ class DocumentAddByFileApi(DatasetApiResource): raise ValueError("indexing_technique is required.") args["indexing_technique"] = indexing_technique + if "embedding_model_provider" in args: + DatasetService.check_embedding_model_setting( + tenant_id, args["embedding_model_provider"], args["embedding_model"] + ) + if ( + "retrieval_model" in args + and args["retrieval_model"].get("reranking_model") + and args["retrieval_model"].get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args["retrieval_model"].get("reranking_model").get("reranking_provider_name"), + args["retrieval_model"].get("reranking_model").get("reranking_model_name"), + ) + # save file info file = request.files["file"] # check file diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 4371e679db..036e11d5c5 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException): error_code = "rate_limit_error" description = "Rate Limit Error" code = 429 + + +class NotFoundError(BaseHTTPException): + error_code = "not_found" + code = 404 + + +class InvalidArgumentError(BaseHTTPException): + error_code = "invalid_param" + code = 400 diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 3f31b1c3d5..75bd2f677a 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -104,6 +104,7 @@ class VariableEntity(BaseModel): Variable Entity. """ + # `variable` records the name of the variable in user inputs. 
variable: str label: str description: str = "" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a8848b9534..61de9ec670 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,13 +29,14 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService -from services.errors.message import MessageNotExistsError +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -116,6 +117,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: @@ -261,6 +267,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -271,6 +284,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -336,6 +350,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -346,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def _generate( @@ -359,6 +381,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository: 
WorkflowNodeExecutionRepository, conversation: Optional[Conversation] = None, stream: bool = True, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. @@ -410,6 +433,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "conversation_id": conversation.id, "message_id": message.id, "context": context, + "variable_loader": variable_loader, }, ) @@ -438,6 +462,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id: str, message_id: str, context: contextvars.Context, + variable_loader: VariableLoader, ) -> None: """ Generate worker in a new thread. @@ -454,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - if message is None: - raise MessageNotExistsError("Message not exists") # chatbot app runner = AdvancedChatAppRunner( @@ -464,6 +487,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, dialogue_count=self._dialogue_count, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d9b3833862..840a3c9d3b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,6 +19,7 @@ from core.moderation.base import ModerationError from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation: Conversation, message: Message, dialogue_count: int, + variable_loader: VariableLoader, ) -> None: - super().__init__(queue_manager) - + super().__init__(queue_manager, variable_loader) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message self._dialogue_count = dialogue_count + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + def run(self) -> None: app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index a448bf8a94..edea6199d3 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -26,7 +26,6 @@ from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService -from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -124,6 +123,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. 
files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: @@ -233,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - if message is None: - raise MessageNotExistsError("Message not exists") # chatbot app runner = AgentChatAppRunner() diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a1329cb938..a28c106ce9 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -25,7 +25,6 @@ from factories import file_factory from models.account import Account from models.model import App, EndUser from services.conversation_service import ConversationService -from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -115,6 +114,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: @@ -219,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator): # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - if message is None: - raise MessageNotExistsError("Message not exists") # chatbot app runner = ChatAppRunner() diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6f524a5872..cd1d298ca2 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -125,7 +126,7 @@ class WorkflowResponseConverter: id=workflow_execution.id_, workflow_id=workflow_execution.workflow_id, status=workflow_execution.status, - outputs=workflow_execution.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), error=workflow_execution.error_message, elapsed_time=workflow_execution.elapsed_time, total_tokens=workflow_execution.total_tokens, @@ -202,6 +203,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -214,7 +217,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + 
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -245,6 +248,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeRetryStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -257,7 +262,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -376,6 +381,7 @@ class WorkflowResponseConverter: workflow_execution_id: str, event: QueueIterationCompletedEvent, ) -> IterationNodeCompletedStreamResponse: + json_converter = WorkflowRuntimeTypeConverter() return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_execution_id, @@ -384,7 +390,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=json_converter.to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -463,7 +469,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index adcbaad3ec..966a6f1d66 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. 
files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: @@ -196,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): try: # get message message = self._get_message(message_id) - if message is None: - raise MessageNotExistsError() # chatbot app runner = CompletionAppRunner() diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 58b94f4d43..e84d59209d 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -29,6 +29,7 @@ from models.enums import CreatorUserRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): return introduction or "" - def _get_conversation(self, conversation_id: str): + def _get_conversation(self, conversation_id: str) -> Conversation: """ Get conversation by conversation id :param conversation_id: conversation id @@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: - raise ConversationNotExistsError() + raise ConversationNotExistsError("Conversation not exists") return conversation - def _get_message(self, message_id: str) -> Optional[Message]: + def _get_message(self, message_id: str) -> Message: """ Get message by message id :param message_id: message id @@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): """ message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + raise MessageNotExistsError("Message not exists") + return message diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index fd15bd9f50..369fa0e48c 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator): files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. 
+ # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) system_files = file_factory.build_from_mappings( mappings=files, @@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -219,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator): "queue_manager": queue_manager, "context": context, "workflow_thread_pool_id": workflow_thread_pool_id, + "variable_loader": variable_loader, }, ) @@ -303,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) return self._generate( app_model=app_model, @@ -313,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -379,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) return self._generate( app_model=app_model, workflow=workflow, @@ -389,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def _generate_worker( @@ -397,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, context: contextvars.Context, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -415,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b59e34e222..07aeb57fa3 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import ( from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from 
core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): :param queue_manager: application queue manager :param workflow_thread_pool_id: workflow thread pool id """ + super().__init__(queue_manager, variable_loader) self.application_generate_entity = application_generate_entity - self.queue_manager = queue_manager self.workflow_thread_pool_id = workflow_thread_pool_id + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + def run(self) -> None: """ Run application diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index facc24b4ca..dc6c381e86 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,6 +1,8 @@ from collections.abc import Mapping from typing import Any, Optional, cast +from sqlalchemy.orm import Session + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.queue_entities import ( @@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.event import ( AgentLogEvent, + BaseNodeEvent, GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App from models.workflow import Workflow +from services.workflow_draft_variable_service import ( + DraftVariableSaver, +) class WorkflowBasedAppRunner(AppRunner): - def __init__(self, queue_manager: AppQueueManager): + def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: self.queue_manager = queue_manager + self._variable_loader = variable_loader + + def _get_app_id(self) -> str: + raise NotImplementedError("not implemented") def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: """ @@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner): except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner): ) except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) 
WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) + self._save_draft_var_for_event(event) + elif isinstance(event, NodeRunFailedEvent): self._publish_event( QueueNodeFailedEvent( @@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) + self._save_draft_var_for_event(event) + elif isinstance(event, NodeInIterationFailedEvent): self._publish_event( QueueNodeInIterationFailedEvent( @@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner): def _publish_event(self, event: AppQueueEvent) -> None: self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) + + def _save_draft_var_for_event(self, event: BaseNodeEvent): + run_result = event.route_node_state.node_run_result + if run_result is None: + return + process_data = run_result.process_data + outputs = run_result.outputs + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=self._get_app_id(), + node_id=event.node_id, + node_type=event.node_type, + # FIXME(QuantumGhost): rely on private state of queue_manager is not ideal. + invoke_from=self.queue_manager._invoke_from, + node_execution_id=event.id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id or None, + ) + draft_var_saver.save(process_data=process_data, outputs=outputs) + + +def _remove_first_element_from_variable_string(key: str) -> str: + """ + Remove the first element from the prefix. + """ + prefix, remaining = key.split(".", maxsplit=1) + return remaining diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index c0d99693b0..65ed267959 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -17,9 +17,24 @@ class InvokeFrom(Enum): Invoke From. """ + # SERVICE_API indicates that this invocation is from an API call to Dify app. + # + # Description of service api in Dify docs: + # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis SERVICE_API = "service-api" + + # WEB_APP indicates that this invocation is from + # the web app of the workflow (or chatflow). + # + # Description of web app in Dify docs: + # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README WEB_APP = "web-app" + + # EXPLORE indicates that this invocation is from + # the workflow (or chatflow) explore page. EXPLORE = "explore" + # DEBUGGER indicates that this invocation is from + # the workflow (or chatflow) edit page. DEBUGGER = "debugger" @classmethod diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 36800bc263..b071bfa5b1 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -15,6 +15,11 @@ class CommonParameterType(StrEnum): MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" + # Dynamic select parameter + # Once you are not sure about the available options until authorization is done + # eg: Select a Slack channel from a Slack workspace + DYNAMIC_SELECT = "dynamic-select" + # TOOL_SELECTOR = "tool-selector" diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ce1d238e93..0665ed7e0d 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -1 +1,11 @@ +from typing import Any + +# TODO(QuantumGhost): Refactor variable type identification. 
Instead of directly +# comparing `dify_model_identity` with constants throughout the codebase, extract +# this logic into a dedicated function. This would encapsulate the implementation +# details of how different variable types are identified. FILE_MODEL_IDENTITY = "__dify__file__" + + +def maybe_file_object(o: Any) -> bool: + return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 848d897779..f2fe306179 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -534,7 +534,7 @@ class IndexingRunner: # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, @@ -572,7 +572,7 @@ class IndexingRunner: for future in futures: tokens += future.result() - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": create_keyword_thread.join() indexing_end_at = time.perf_counter() diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py new file mode 100644 index 0000000000..0aaf5abef0 --- /dev/null +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -0,0 +1,374 @@ +import json +from collections.abc import Generator, Mapping, Sequence +from copy import deepcopy +from enum import StrEnum +from typing import Any, Literal, Optional, cast, overload + +import json_repair +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule + + +class ResponseFormat(StrEnum): + """Constants for model response formats""" + + JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode. + JSON = "JSON" # model's json mode. some model like claude support this mode. + JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias. + + +class SpecialModelType(StrEnum): + """Constants for identifying model types""" + + GEMINI = "gemini" + OLLAMA = "ollama" + + +@overload +def invoke_llm_with_structured_output( + provider: str, + model_schema: AIModelEntity, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Optional[Mapping] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: Literal[True] = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, +) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... 
+ + +@overload +def invoke_llm_with_structured_output( + provider: str, + model_schema: AIModelEntity, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Optional[Mapping] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: Literal[False] = False, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, +) -> LLMResultWithStructuredOutput: ... + + +@overload +def invoke_llm_with_structured_output( + provider: str, + model_schema: AIModelEntity, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Optional[Mapping] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, +) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + +def invoke_llm_with_structured_output( + provider: str, + model_schema: AIModelEntity, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Optional[Mapping] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, +) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + """ + Invoke large language model with structured output + 1. This method invokes model_instance.invoke_llm with json_schema + 2. Try to parse the result as structured output + + :param prompt_messages: prompt messages + :param json_schema: json schema + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :param callbacks: callbacks + :return: full response or stream response chunk generator result + """ + + # handle native json schema + model_parameters_with_json_schema: dict[str, Any] = { + **(model_parameters or {}), + } + + if model_schema.support_structure_output: + model_parameters = _handle_native_json_schema( + provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules + ) + else: + # Set appropriate response format based on model capabilities + _set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules) + + # handle prompt based schema + prompt_messages = _handle_prompt_based_schema( + prompt_messages=prompt_messages, + structured_output_schema=json_schema, + ) + + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=model_parameters_with_json_schema, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + if isinstance(llm_result, LLMResult): + if not isinstance(llm_result.message.content, str): + raise OutputParserError( + f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}" + ) + + return LLMResultWithStructuredOutput( + structured_output=_parse_structured_output(llm_result.message.content), + model=llm_result.model, + message=llm_result.message, + usage=llm_result.usage, + system_fingerprint=llm_result.system_fingerprint, + prompt_messages=llm_result.prompt_messages, + ) + else: + + def generator() -> 
Generator[LLMResultChunkWithStructuredOutput, None, None]: + result_text: str = "" + prompt_messages: Sequence[PromptMessage] = [] + system_fingerprint: Optional[str] = None + for event in llm_result: + if isinstance(event, LLMResultChunk): + if isinstance(event.delta.message.content, str): + result_text += event.delta.message.content + prompt_messages = event.prompt_messages + system_fingerprint = event.system_fingerprint + + yield LLMResultChunkWithStructuredOutput( + model=model_schema.model, + prompt_messages=prompt_messages, + system_fingerprint=system_fingerprint, + delta=event.delta, + ) + + yield LLMResultChunkWithStructuredOutput( + structured_output=_parse_structured_output(result_text), + model=model_schema.model, + prompt_messages=prompt_messages, + system_fingerprint=system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + usage=None, + finish_reason=None, + ), + ) + + return generator() + + +def _handle_native_json_schema( + provider: str, + model_schema: AIModelEntity, + structured_output_schema: Mapping, + model_parameters: dict, + rules: list[ParameterRule], +) -> dict: + """ + Handle structured output for models with native JSON schema support. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + :return: Updated model parameters with JSON schema configuration + """ + # Process schema according to model requirements + schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema) + + # Set JSON schema in parameters + model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) + + # Set appropriate response format if required by the model + for rule in rules: + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value + + return model_parameters + + +def _set_response_format(model_parameters: dict, rules: list) -> None: + """ + Set the appropriate response format parameter based on model rules. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + """ + for rule in rules: + if rule.name == "response_format": + if ResponseFormat.JSON.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON.value + elif ResponseFormat.JSON_OBJECT.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + + +def _handle_prompt_based_schema( + prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping +) -> list[PromptMessage]: + """ + Handle structured output for models without native JSON schema support. + This function modifies the prompt messages to include schema-based output requirements. 
+ + Args: + prompt_messages: Original sequence of prompt messages + + Returns: + list[PromptMessage]: Updated prompt messages with structured output requirements + """ + # Convert schema to string format + schema_str = json.dumps(structured_output_schema, ensure_ascii=False) + + # Find existing system prompt with schema placeholder + system_prompt = next( + (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)), + None, + ) + structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str) + # Prepare system prompt content + system_prompt_content = ( + structured_output_prompt + "\n\n" + system_prompt.content + if system_prompt and isinstance(system_prompt.content, str) + else structured_output_prompt + ) + system_prompt = SystemPromptMessage(content=system_prompt_content) + + # Extract content from the last user message + + filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)] + updated_prompt = [system_prompt] + filtered_prompts + + return updated_prompt + + +def _parse_structured_output(result_text: str) -> Mapping[str, Any]: + structured_output: Mapping[str, Any] = {} + parsed: Mapping[str, Any] = {} + try: + parsed = TypeAdapter(Mapping).validate_json(result_text) + if not isinstance(parsed, dict): + raise OutputParserError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + except ValidationError: + # if the result_text is not a valid json, try to repair it + temp_parsed = json_repair.loads(result_text) + if not isinstance(temp_parsed, dict): + # handle reasoning model like deepseek-r1 got '\n\n\n' prefix + if isinstance(temp_parsed, list): + temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {}) + else: + raise OutputParserError(f"Failed to parse structured output: {result_text}") + structured_output = cast(dict, temp_parsed) + return structured_output + + +def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: + """ + Prepare JSON schema based on model requirements. + + Different models have different requirements for JSON schema formatting. + This function handles these differences. + + :param schema: The original JSON schema + :return: Processed schema compatible with the current model + """ + + # Deep copy to avoid modifying the original schema + processed_schema = dict(deepcopy(schema)) + + # Convert boolean types to string types (common requirement) + convert_boolean_to_string(processed_schema) + + # Apply model-specific transformations + if SpecialModelType.GEMINI in model_schema.model: + remove_additional_properties(processed_schema) + return processed_schema + elif SpecialModelType.OLLAMA in provider: + return processed_schema + else: + # Default format with name field + return {"schema": processed_schema, "name": "llm_response"} + + +def remove_additional_properties(schema: dict) -> None: + """ + Remove additionalProperties fields from JSON schema. + Used for models like Gemini that don't support this property. 
+ + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Remove additionalProperties at current level + schema.pop("additionalProperties", None) + + # Process nested structures recursively + for value in schema.values(): + if isinstance(value, dict): + remove_additional_properties(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + remove_additional_properties(item) + + +def convert_boolean_to_string(schema: dict) -> None: + """ + Convert boolean type specifications to string in JSON schema. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Check for boolean type at current level + if schema.get("type") == "boolean": + schema["type"] = "string" + + # Process nested dictionaries and lists recursively + for value in schema.values(): + if isinstance(value, dict): + convert_boolean_to_string(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + convert_boolean_to_string(item) diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ddfa1e7a66..ef81e38dc5 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -291,3 +291,21 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc Now, generate a JSON Schema based on my description """ # noqa: E501 + +STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format. +constraints: + - You must output in JSON format. + - Do not output boolean value, use string type instead. + - Do not output integer or float value, use number type instead. +eg: + Here is the JSON schema: + {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"} + + Here is the user's question: + My name is John Doe and I am 30 years old. + + output: + {"name": "John Doe", "age": 30} +Here is the JSON schema: +{{schema}} +""" # noqa: E501 diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index de5a748d4f..e52b0eba55 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -1,7 +1,7 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -101,6 +101,20 @@ class LLMResult(BaseModel): system_fingerprint: Optional[str] = None +class LLMStructuredOutput(BaseModel): + """ + Model class for llm structured output. + """ + + structured_output: Optional[Mapping[str, Any]] = None + + +class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): + """ + Model class for llm result with structured output. + """ + + class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. @@ -123,6 +137,12 @@ class LLMResultChunk(BaseModel): delta: LLMResultChunkDelta +class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput): + """ + Model class for llm result chunk with structured output. + """ + + class NumTokensResult(PriceInfo): """ Model class for number of tokens result. 
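For orientation, here is a minimal sketch (not part of the change itself) of how the structured-output helper introduced above might be called in non-streaming mode. The tenant, provider, model name, and schema below are hypothetical placeholders; only the call shape follows the `invoke_llm_with_structured_output` signature and the `ModelManager` usage shown elsewhere in this change:

```python
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType

# Hypothetical tenant / provider / model values, for illustration only.
model_instance = ModelManager().get_model_instance(
    tenant_id="tenant-id",
    provider="openai",
    model_type=ModelType.LLM,
    model="gpt-4o",
)
model_schema = model_instance.model_type_instance.get_model_schema("gpt-4o", model_instance.credentials)
if model_schema is None:
    raise ValueError("Model schema not found")

result = invoke_llm_with_structured_output(
    provider="openai",
    model_schema=model_schema,
    model_instance=model_instance,
    prompt_messages=[
        SystemPromptMessage(content="Extract the requested fields."),
        UserPromptMessage(content="My name is John Doe and I am 30 years old."),
    ],
    json_schema={
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
        "required": ["name", "age"],
        "additionalProperties": False,
    },
    stream=False,  # per the overloads above, this returns LLMResultWithStructuredOutput
)
print(result.structured_output)  # e.g. {"name": "John Doe", "age": 30}
```

With `stream=True` (the default) the helper instead yields `LLMResultChunkWithStructuredOutput` chunks, attaching the parsed `structured_output` to a final chunk, as the generator in the new module does.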
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 072644e53b..d07ab3d0c4 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,8 +2,15 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) from core.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, @@ -12,6 +19,7 @@ from core.model_runtime.entities.message_entities import ( from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.entities.request import ( RequestInvokeLLM, + RequestInvokeLLMWithStructuredOutput, RequestInvokeModeration, RequestInvokeRerank, RequestInvokeSpeech2Text, @@ -81,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): return handle_non_streaming(response) + @classmethod + def invoke_llm_with_structured_output( + cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput + ): + """ + invoke llm with structured output + """ + model_instance = ModelManager().get_model_instance( + tenant_id=tenant.id, + provider=payload.provider, + model_type=payload.model_type, + model=payload.model, + ) + + model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials) + + if not model_schema: + raise ValueError(f"Model schema not found for {payload.model}") + + response = invoke_llm_with_structured_output( + provider=payload.provider, + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=payload.prompt_messages, + json_schema=payload.structured_output_schema, + tools=payload.tools, + stop=payload.stop, + stream=True if payload.stream is None else payload.stream, + user=user_id, + model_parameters=payload.completion_params, + ) + + if isinstance(response, Generator): + + def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]: + for chunk in response: + if chunk.delta.usage: + llm_utils.deduct_llm_quota( + tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage + ) + chunk.prompt_messages = [] + yield chunk + + return handle() + else: + if response.usage: + llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) + + def handle_non_streaming( + response: LLMResultWithStructuredOutput, + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: + yield LLMResultChunkWithStructuredOutput( + model=response.model, + prompt_messages=[], + system_fingerprint=response.system_fingerprint, + structured_output=response.structured_output, + delta=LLMResultChunkDelta( + index=0, + message=response.message, + usage=response.usage, + finish_reason="", + ), + ) + + return handle_non_streaming(response) + @classmethod def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding): """ diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 895dd0d0fc..2b438a3c33 100644 --- a/api/core/plugin/entities/parameters.py +++ 
b/api/core/plugin/entities/parameters.py @@ -10,6 +10,9 @@ from core.tools.entities.common_entities import I18nObject class PluginParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + icon: Optional[str] = Field( + default=None, description="The icon of the option, can be a url or a base64 encoded image" + ) @field_validator("value", mode="before") @classmethod @@ -35,6 +38,7 @@ class PluginParameterType(enum.StrEnum): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value # deprecated, should not use. SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index e0d2857e97..592b42c0da 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum from typing import Any, Generic, Optional, TypeVar @@ -9,6 +9,7 @@ from core.agent.plugin_entities import AgentProviderEntityWithPlugin from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity +from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin @@ -186,3 +187,7 @@ class PluginOAuthCredentialsResponse(BaseModel): class PluginListResponse(BaseModel): list: list[PluginEntity] total: int + + +class PluginDynamicSelectOptionsResponse(BaseModel): + options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.") diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 1692020ec8..f9c81ed4d5 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -82,6 +82,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel): return v +class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM): + """ + Request to invoke LLM with structured output + """ + + structured_output_schema: dict[str, Any] = Field( + default_factory=dict, description="The schema of the structured output in JSON schema format" + ) + + class RequestInvokeTextEmbedding(BaseRequestInvokeModel): """ Request to invoke text embedding diff --git a/api/core/plugin/impl/dynamic_select.py b/api/core/plugin/impl/dynamic_select.py new file mode 100644 index 0000000000..f4fb051ee1 --- /dev/null +++ b/api/core/plugin/impl/dynamic_select.py @@ -0,0 +1,45 @@ +from collections.abc import Mapping +from typing import Any + +from core.plugin.entities.plugin import GenericProviderID +from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse +from core.plugin.impl.base import BasePluginClient + + +class DynamicSelectClient(BasePluginClient): + def fetch_dynamic_select_options( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + action: str, + credentials: Mapping[str, Any], + parameter: str, + ) -> PluginDynamicSelectOptionsResponse: + """ + 
Fetch dynamic select options for a plugin parameter. + """ + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options", + PluginDynamicSelectOptionsResponse, + data={ + "user_id": user_id, + "data": { + "provider": GenericProviderID(provider).provider_name, + "credentials": credentials, + "provider_action": action, + "parameter": parameter, + }, + }, + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + + for options in response: + return options + + raise ValueError("Plugin service returned no options") diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c8..b006bf1d4b 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") def get_credentials( self, @@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index dca84b9041..9b90bd2bb3 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + with_keywords = False if with_keywords: keywords_list = kwargs.get("keywords_list") keyword = Keyword(dataset) @@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): vector.delete_by_ids(node_ids) else: vector.delete() + with_keywords = False if with_keywords: keyword = 
Keyword(dataset) if node_ids: diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e30538742a..cdec92aee7 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import ( WorkflowType, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -152,7 +153,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.outputs = ( + json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs)) + if domain_model.outputs + else None + ) db_model.status = domain_model.status db_model.error = domain_model.error_message if domain_model.error_message else None db_model.total_tokens = domain_model.total_tokens diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 2f27442616..797cce9354 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if not self._creator_user_role: raise ValueError("created_by_role is required in repository constructor") + json_converter = WorkflowRuntimeTypeConverter() db_model = WorkflowNodeExecutionModel() db_model.id = domain_model.id db_model.tenant_id = self._tenant_id @@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) db_model.node_id = domain_model.node_id db_model.node_type = domain_model.node_type db_model.title = domain_model.title - db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.inputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None + ) + db_model.process_data = ( + json.dumps(json_converter.to_json_encodable(domain_model.process_data)) + if domain_model.process_data + else None + ) + db_model.outputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None + ) db_model.status = domain_model.status db_model.error = domain_model.error db_model.elapsed_time = domain_model.elapsed_time diff --git a/api/core/tools/entities/tool_entities.py 
b/api/core/tools/entities/tool_entities.py index 03047c0545..d2c28076ae 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -240,6 +240,7 @@ class ToolParameter(PluginParameter): FILES = PluginParameterType.FILES.value APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value # deprecated, should not use. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 6a5fba65bd..1f23e90351 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -86,6 +86,7 @@ class ProviderConfigEncrypter(BaseModel): cached_credentials = cache.get() if cached_credentials: return cached_credentials + data = self._deep_copy(data) # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 64ba16c367..6cf09e0372 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -75,6 +75,20 @@ class StringSegment(Segment): class FloatSegment(Segment): value_type: SegmentType = SegmentType.NUMBER value: float + # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. + # The following tests cannot pass. + # + # def test_float_segment_and_nan(): + # nan = float("nan") + # assert nan != nan + # + # f1 = FloatSegment(value=float("nan")) + # f2 = FloatSegment(value=float("nan")) + # assert f1 != f2 + # + # f3 = FloatSegment(value=nan) + # f4 = FloatSegment(value=nan) + # assert f3 != f4 class IntegerSegment(Segment): diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 4387e9693e..68d3d82883 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -18,3 +18,17 @@ class SegmentType(StrEnum): NONE = "none" GROUP = "group" + + def is_array_type(self): + return self in _ARRAY_TYPES + + +_ARRAY_TYPES = frozenset( + [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] +) diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index e5d222af7d..692db3502e 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -1,8 +1,26 @@ +import json from collections.abc import Iterable, Sequence +from .segment_group import SegmentGroup +from .segments import ArrayFileSegment, FileSegment, Segment + def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: selectors = [node_id, name] if paths: selectors.extend(paths) return selectors + + +class SegmentJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ArrayFileSegment): + return [v.model_dump() for v in o.value] + elif isinstance(o, FileSegment): + return o.value.model_dump() + elif isinstance(o, SegmentGroup): + return [self.default(seg) for seg in o.value] + elif isinstance(o, Segment): + return o.value + else: + super().default(o) diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py new file mode 100644 index 0000000000..84e99bb582 --- /dev/null +++ b/api/core/workflow/conversation_variable_updater.py @@ -0,0 +1,39 @@ +import abc +from typing import Protocol + +from core.variables import Variable + + +class ConversationVariableUpdater(Protocol): + """ + 
ConversationVariableUpdater defines an abstraction for updating conversation variable values. + + It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating + conversation variables. + + Implementations may choose to batch updates. If batching is used, the `flush` method + should be implemented to persist buffered changes, and `update` + should handle buffering accordingly. + + Note: Since implementations may buffer updates, instances of ConversationVariableUpdater + are not thread-safe. Each VariableAssignerNode should create its own instance during execution. + """ + + @abc.abstractmethod + def update(self, conversation_id: str, variable: "Variable") -> None: + """ + Updates the value of the specified conversation variable in the underlying storage. + + :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. + :param variable: The `Variable` instance containing the updated value. + """ + pass + + @abc.abstractmethod + def flush(self): + """ + Flushes all pending updates to the underlying storage system. + + If the implementation does not buffer updates, this method can be a no-op. + """ + pass diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index af26864c01..80dda2632d 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -7,12 +7,12 @@ from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey from factories import variable_factory -from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from ..enums import SystemVariableKey - VariableValue = Union[str, int, float, dict, list, File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") @@ -30,9 +30,11 @@ class VariablePool(BaseModel): # TODO: This user inputs is not used for pool. 
user_inputs: Mapping[str, Any] = Field( description="User inputs", + default_factory=dict, ) system_variables: Mapping[SystemVariableKey, Any] = Field( description="System variables", + default_factory=dict, ) environment_variables: Sequence[Variable] = Field( description="Environment variables.", @@ -43,28 +45,7 @@ class VariablePool(BaseModel): default_factory=list, ) - def __init__( - self, - *, - system_variables: Mapping[SystemVariableKey, Any] | None = None, - user_inputs: Mapping[str, Any] | None = None, - environment_variables: Sequence[Variable] | None = None, - conversation_variables: Sequence[Variable] | None = None, - **kwargs, - ): - environment_variables = environment_variables or [] - conversation_variables = conversation_variables or [] - user_inputs = user_inputs or {} - system_variables = system_variables or {} - - super().__init__( - system_variables=system_variables, - user_inputs=user_inputs, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - **kwargs, - ) - + def model_post_init(self, context: Any, /) -> None: for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool @@ -91,12 +72,12 @@ class VariablePool(BaseModel): Returns: None """ - if len(selector) < 2: + if len(selector) < MIN_SELECTORS_LENGTH: raise ValueError("Invalid selector") if isinstance(value, Variable): variable = value - if isinstance(value, Segment): + elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) else: segment = variable_factory.build_segment(value) @@ -118,7 +99,7 @@ class VariablePool(BaseModel): Raises: ValueError: If the selector is invalid. """ - if len(selector) < 2: + if len(selector) < MIN_SELECTORS_LENGTH: return None hash_key = hash(tuple(selector[1:])) diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 9a4939502e..e57e9e4d64 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent): """iteration id if node is in iteration""" in_loop_id: Optional[str] = None """loop id if node is in loop""" + # The version of the node, or "1" if not specified. 
+ node_version: str = "1" class NodeRunStartedEvent(BaseNodeEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ee2164f22f..61a7a26652 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -314,6 +315,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) raise e @@ -627,6 +629,7 @@ class GraphEngine: parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, + node_version=node_instance.version(), ) max_retries = node_instance.node_data.retry_config.max_retries @@ -677,6 +680,7 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, + node_version=node_instance.version(), ) time.sleep(retry_interval) break @@ -712,6 +716,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False else: @@ -726,6 +731,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -786,6 +792,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False @@ -803,6 +810,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -817,6 +825,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) except GenerateTaskStoppedError: # trigger node run failed event @@ -833,6 +842,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) return except Exception as e: @@ -847,16 +857,12 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, 
dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - node_id=node_id, variable_key_list=new_key_list, variable_value=value - ) + variable_utils.append_variables_recursively( + self.graph_runtime_state.variable_pool, + node_id, + variable_key_list, + variable_value, + ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 22c564c1fc..2f28363955 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -39,6 +39,10 @@ class AgentNode(ToolNode): _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the agent node diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index aa030870e2..38c2bcbdf5 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser class AnswerNode(BaseNode[AnswerNodeData]): _node_data_cls = AnswerNodeData - _node_type: NodeType = NodeType.ANSWER + _node_type = NodeType.ANSWER + + @classmethod + def version(cls) -> str: + return "1" def _run(self) -> NodeRunResult: """ @@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"answer": answer, "files": ArrayFileSegment(value=files)}, + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ba6ba16e36..f3e4a62ade 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor): parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, from_variable_selector=[answer_node_id, "answer"], + node_version=event.node_version, ) else: route_chunk = cast(VarGenerateRouteChunk, route_chunk) @@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + node_version=event.node_version, ) self.route_position[answer_node_id] += 1 diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 7da0c19740..6973401429 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", 
bound=BaseNodeData) class BaseNode(Generic[GenericNodeData]): _node_data_cls: type[GenericNodeData] - _node_type: NodeType + _node_type: ClassVar[NodeType] def __init__( self, @@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]): graph_config: Mapping[str, Any], config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping + """Extracts the referenced variable selectors from the node configuration. + + The `config` parameter represents the configuration for a specific node type and corresponds + to the `data` field in the node definition object. + + The returned mapping has the following structure: + + {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} + + For loop and iteration nodes, the mapping may look like this: + + { + "1748332301644.input_selector": ["1748332363630", "result"], + "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], + } + + where `1748332301644` is the ID of the loop / iteration node, + and `1748332325079` is the ID of the node inside the loop or iteration node. + + Here, the key consists of two parts: the current node ID (provided as the `node_id` + parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, + enclosed in `#` symbols. These two parts are separated by a dot (`.`). + + The value is a list of strings representing the variable selector, where the first element is the node ID + of the referenced variable, and the second element is the variable name within that node. + + The meaning of the first mapping above is: + + The node with ID `1747829548239` references the variable `result` from the node with + ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a + reference to the `result` output variable of node `1747829667553`. + :param graph_config: graph config :param config: node config :return: @@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]): raise ValueError("Node ID is required when extracting variable selector to variable mapping.") node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping( + data = cls._extract_variable_selector_to_variable_mapping( graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) ) + return data @classmethod def _extract_variable_selector_to_variable_mapping( @@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]): """ return self._node_type + @classmethod + @abstractmethod + def version(cls) -> str: + """`version` returns the version of the current node type.""" + # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. + # + # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` + # in `api/core/workflow/nodes/__init__.py`.
+ raise NotImplementedError("subclasses of BaseNode must implement `version` method.") + @property def should_continue_on_error(self) -> bool: """judge if should continue on error diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 61c08a7d71..22ed9e2651 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]): return code_provider.get_default_config() + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get code language code_language = self.node_data.code_language @@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]): prefix: str = "", depth: int = 1, ): + # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. + # Note that `_transform_result` may produce lists containing `None` values, + # which don't conform to the type requirements of `Array*Segment` classes. if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 429fed2d04..8e6150f9cc 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -24,7 +24,7 @@ from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment -from core.variables.segments import FileSegment +from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self): variable_selector = self.node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) @@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"text": extracted_text_list}, + outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): extracted_text = _extract_text_from_file(value) @@ -447,7 +451,7 @@ def _extract_text_from_excel(file_content: bytes) -> str: df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns]) + df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) # Manually construct the Markdown table markdown_table += _construct_markdown_table(df) + "\n\n" diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 0e9756b243..17a0b3adeb 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]): _node_data_cls = EndNodeData _node_type = NodeType.END + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ 
Run node diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 3ae5af7137..a6fb2ffc18 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + node_version=event.node_version, ) self.route_position[end_node_id] += 1 diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 6b1ac57c06..971e0f73e7 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -6,6 +6,7 @@ from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): }, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: process_data = {} try: @@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ "status_code": response.status_code, - "body": response.text if not files else "", + "body": response.text if not files.value else "", "headers": response.headers, "files": files, }, @@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): return mapping - def extract_files(self, url: str, response: Response) -> list[File]: + def extract_files(self, url: str, response: Response) -> ArrayFileSegment: """ Extract files from response by checking both Content-Type header and URL """ @@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): content_disposition_type = None if not is_file: - return files + return ArrayFileSegment(value=[]) if parsed_content_disposition: content_disposition_filename = parsed_content_disposition.get_filename() @@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): ) files.append(file) - return files + return ArrayFileSegment(value=files) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 976922f75d..22b748030c 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,4 +1,5 @@ -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal from typing_extensions import deprecated @@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run node @@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]): return data + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, list[str]] = {} + for case in node_data.cases or []: + for condition in case.conditions: + key = 
"{}.#{}#".format(node_id, ".".join(condition.variable_selector)) + var_mapping[key] = condition.variable_selector + + return var_mapping + @deprecated("This function is deprecated. You should use the new cases structure.") def _should_not_use_old_function( diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 42b6795fb0..151efc28ec 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,6 +11,7 @@ from flask import Flask, current_app from configs import dify_config from core.variables import ArrayVariable, IntegerVariable, NoneVariable +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import ( NodeRunResult, ) @@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from factories.variable_factory import build_segment from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]): }, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. @@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") if isinstance(variable, NoneVariable) or len(variable.value) == 0: + # Try our best to preserve the type informat. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": []}, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? + outputs={"output": output}, ) ) return @@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]): # Flatten the list of lists if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): outputs = [item for sublist in outputs for item in sublist] + output_segment = build_segment(outputs) yield IterationRunSucceededEvent( iteration_id=self.id, @@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": output_segment}, metadata={ WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index bee481ebdb..9900aa225d 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]): _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. 
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5cf5848d54..0b9e98f28a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment +from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType @@ -70,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode): _node_data_cls = KnowledgeRetrievalNodeData # type: ignore _node_type = NodeType.KNOWLEDGE_RETRIEVAL + @classmethod + def version(cls): + return "1" + def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables @@ -115,9 +120,12 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: results = self._fetch_dataset_retriever(node_data=node_data, query=query) - outputs = {"result": results} + outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs, # type: ignore ) except KnowledgeRetrievalNodeError as e: diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index e698d3f5d8..3c9ba44cf1 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_data_cls = ListOperatorNodeData _node_type = NodeType.LIST_OPERATOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self): inputs: dict[str, list] = {} process_data: dict[str, list] = {} @@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): if not variable.value: inputs = {"variable": []} process_data = {"variable": []} - outputs = {"result": [], "first_record": None, "last_record": None} + if isinstance(variable, ArraySegment): + result = variable.model_copy(update={"value": []}) + else: + result = ArrayAnySegment(value=[]) + outputs = {"result": result, "first_record": None, "last_record": None} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): variable = self._apply_slice(variable) outputs = { - "result": variable.value, + "result": variable, "first_record": variable.value[0] if variable.value else None, "last_record": variable.value[-1] if variable.value else None, } diff --git 
a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py index c85baade03..a4b45ce652 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver): size=len(data), related_id=tool_file.id, url=url, - # TODO(QuantumGhost): how should I set the following key? - # What's the difference between `remote_url` and `url`? - # What's the purpose of `storage_key` and `dify_model_identity`? storage_key=tool_file.file_key, ) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d27124d62c..b5225ce548 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -5,11 +5,11 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast -import json_repair - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( @@ -18,7 +18,13 @@ from core.model_runtime.entities import ( PromptMessageContentType, TextPromptMessageContent, ) -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMStructuredOutput, + LLMUsage, +) from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageContentUnionTypes, @@ -31,7 +37,6 @@ from core.model_runtime.entities.model_entities import ( ModelFeature, ModelPropertyKey, ModelType, - ParameterRule, ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder @@ -62,11 +67,6 @@ from core.workflow.nodes.event import ( RunRetrieverResourceEvent, RunStreamChunkEvent, ) -from core.workflow.utils.structured_output.entities import ( - ResponseFormat, - SpecialModelType, -) -from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser from . 
import llm_utils @@ -138,13 +138,11 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver - def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: - def process_structured_output(text: str) -> Optional[dict[str, Any]]: - """Process structured output if enabled""" - if not self.node_data.structured_output_enabled or not self.node_data.structured_output: - return None - return self._parse_structured_output(text) + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: node_inputs: Optional[dict[str, Any]] = None process_data = None result_text = "" @@ -240,6 +238,8 @@ class LLMNode(BaseNode[LLMNodeData]): stop=stop, ) + structured_output: LLMStructuredOutput | None = None + for event in generator: if isinstance(event, RunStreamChunkEvent): yield event @@ -250,12 +250,14 @@ class LLMNode(BaseNode[LLMNodeData]): # deduct quota llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break + elif isinstance(event, LLMStructuredOutput): + structured_output = event + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} - structured_output = process_structured_output(result_text) if structured_output: - outputs["structured_output"] = structured_output + outputs["structured_output"] = structured_output.structured_output if self._file_outputs is not None: - outputs["files"] = self._file_outputs + outputs["files"] = ArrayFileSegment(value=self._file_outputs) yield RunCompletedEvent( run_result=NodeRunResult( @@ -298,20 +300,40 @@ class LLMNode(BaseNode[LLMNodeData]): model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, - ) -> Generator[NodeEvent, None, None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=node_data_model.completion_params, - stop=list(stop or []), - stream=True, - user=self.user_id, + ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: + model_schema = model_instance.model_type_instance.get_model_schema( + node_data_model.name, model_instance.credentials ) + if not model_schema: + raise ValueError(f"Model schema not found for {node_data_model.name}") + + if self.node_data.structured_output_enabled: + output_schema = self._fetch_structured_output_schema() + invoke_result = invoke_llm_with_structured_output( + provider=model_instance.provider, + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=prompt_messages, + json_schema=output_schema, + model_parameters=node_data_model.completion_params, + stop=list(stop or []), + stream=True, + user=self.user_id, + ) + else: + invoke_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=node_data_model.completion_params, + stop=list(stop or []), + stream=True, + user=self.user_id, + ) return self._handle_invoke_result(invoke_result=invoke_result) def _handle_invoke_result( - self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None] - ) -> Generator[NodeEvent, None, None]: + self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] + ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: # For blocking mode if isinstance(invoke_result, LLMResult): event = self._handle_blocking_result(invoke_result=invoke_result) @@ -325,23 +347,32 @@ class LLMNode(BaseNode[LLMNodeData]): usage = LLMUsage.empty_usage() finish_reason = 
None full_text_buffer = io.StringIO() - for result in invoke_result: - contents = result.delta.message.content - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): - full_text_buffer.write(text_part) - yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"]) + # Consume the invoke result and handle generator exception + try: + for result in invoke_result: + if isinstance(result, LLMResultChunkWithStructuredOutput): + yield result + if isinstance(result, LLMResultChunk): + contents = result.delta.message.content + for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + full_text_buffer.write(text_part) + yield RunStreamChunkEvent( + chunk_content=text_part, from_variable_selector=[self.node_id, "text"] + ) - # Update the whole metadata - if not model and result.model: - model = result.model - if len(prompt_messages) == 0: - # TODO(QuantumGhost): it seems that this update has no visable effect. - # What's the purpose of the line below? - prompt_messages = list(result.prompt_messages) - if usage.prompt_tokens == 0 and result.delta.usage: - usage = result.delta.usage - if finish_reason is None and result.delta.finish_reason: - finish_reason = result.delta.finish_reason + # Update the whole metadata + if not model and result.model: + model = result.model + if len(prompt_messages) == 0: + # TODO(QuantumGhost): it seems that this update has no visible effect. + # What's the purpose of the line below? + prompt_messages = list(result.prompt_messages) + if usage.prompt_tokens == 0 and result.delta.usage: + usage = result.delta.usage + if finish_reason is None and result.delta.finish_reason: + finish_reason = result.delta.finish_reason + except OutputParserError as e: + raise LLMNodeError(f"Failed to parse structured output: {e}") yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) @@ -518,12 +549,6 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - if self.node_data.structured_output_enabled: - if model_schema.support_structure_output: - completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) - else: - # Set appropriate response format based on model capabilities - self._set_response_format(completion_params, model_schema.parameter_rules) model_config_with_cred.parameters = completion_params # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
node_data_model.completion_params = completion_params @@ -715,32 +740,8 @@ class LLMNode(BaseNode[LLMNodeData]): ) if not model_schema: raise ModelNotExistError(f"Model {model_config.model} not exist.") - if self.node_data.structured_output_enabled: - if not model_schema.support_structure_output: - filtered_prompt_messages = self._handle_prompt_based_schema( - prompt_messages=filtered_prompt_messages, - ) return filtered_prompt_messages, model_config.stop - def _parse_structured_output(self, result_text: str) -> dict[str, Any]: - structured_output: dict[str, Any] = {} - try: - parsed = json.loads(result_text) - if not isinstance(parsed, dict): - raise LLMNodeError(f"Failed to parse structured output: {result_text}") - structured_output = parsed - except json.JSONDecodeError as e: - # if the result_text is not a valid json, try to repair it - parsed = json_repair.loads(result_text) - if not isinstance(parsed, dict): - # handle reasoning model like deepseek-r1 got '\n\n\n' prefix - if isinstance(parsed, list): - parsed = next((item for item in parsed if isinstance(item, dict)), {}) - else: - raise LLMNodeError(f"Failed to parse structured output: {result_text}") - structured_output = parsed - return structured_output - @classmethod def _extract_variable_selector_to_variable_mapping( cls, @@ -930,104 +931,6 @@ class LLMNode(BaseNode[LLMNodeData]): self._file_outputs.append(saved_file) return saved_file - def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: - """ - Handle structured output for models with native JSON schema support. - - :param model_parameters: Model parameters to update - :param rules: Model parameter rules - :return: Updated model parameters with JSON schema configuration - """ - # Process schema according to model requirements - schema = self._fetch_structured_output_schema() - schema_json = self._prepare_schema_for_model(schema) - - # Set JSON schema in parameters - model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) - - # Set appropriate response format if required by the model - for rule in rules: - if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value - - return model_parameters - - def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]: - """ - Handle structured output for models without native JSON schema support. - This function modifies the prompt messages to include schema-based output requirements. 
- - Args: - prompt_messages: Original sequence of prompt messages - - Returns: - list[PromptMessage]: Updated prompt messages with structured output requirements - """ - # Convert schema to string format - schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False) - - # Find existing system prompt with schema placeholder - system_prompt = next( - (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)), - None, - ) - structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str) - # Prepare system prompt content - system_prompt_content = ( - structured_output_prompt + "\n\n" + system_prompt.content - if system_prompt and isinstance(system_prompt.content, str) - else structured_output_prompt - ) - system_prompt = SystemPromptMessage(content=system_prompt_content) - - # Extract content from the last user message - - filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)] - updated_prompt = [system_prompt] + filtered_prompts - - return updated_prompt - - def _set_response_format(self, model_parameters: dict, rules: list) -> None: - """ - Set the appropriate response format parameter based on model rules. - - :param model_parameters: Model parameters to update - :param rules: Model parameter rules - """ - for rule in rules: - if rule.name == "response_format": - if ResponseFormat.JSON.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON.value - elif ResponseFormat.JSON_OBJECT.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value - - def _prepare_schema_for_model(self, schema: dict) -> dict: - """ - Prepare JSON schema based on model requirements. - - Different models have different requirements for JSON schema formatting. - This function handles these differences. - - :param schema: The original JSON schema - :return: Processed schema compatible with the current model - """ - - # Deep copy to avoid modifying the original schema - processed_schema = schema.copy() - - # Convert boolean types to string types (common requirement) - convert_boolean_to_string(processed_schema) - - # Apply model-specific transformations - if SpecialModelType.GEMINI in self.node_data.model.name: - remove_additional_properties(processed_schema) - return processed_schema - elif SpecialModelType.OLLAMA in self.node_data.model.provider: - return processed_schema - else: - # Default format with name field - return {"schema": processed_schema, "name": "llm_response"} - def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: """ Fetch model schema @@ -1239,49 +1142,3 @@ def _handle_completion_template( ) prompt_messages.append(prompt_message) return prompt_messages - - -def remove_additional_properties(schema: dict) -> None: - """ - Remove additionalProperties fields from JSON schema. - Used for models like Gemini that don't support this property. - - :param schema: JSON schema to modify in-place - """ - if not isinstance(schema, dict): - return - - # Remove additionalProperties at current level - schema.pop("additionalProperties", None) - - # Process nested structures recursively - for value in schema.values(): - if isinstance(value, dict): - remove_additional_properties(value) - elif isinstance(value, list): - for item in value: - if isinstance(item, dict): - remove_additional_properties(item) - - -def convert_boolean_to_string(schema: dict) -> None: - """ - Convert boolean type specifications to string in JSON schema. 
- - :param schema: JSON schema to modify in-place - """ - if not isinstance(schema, dict): - return - - # Check for boolean type at current level - if schema.get("type") == "boolean": - schema["type"] = "string" - - # Process nested dictionaries and lists recursively - for value in schema.values(): - if isinstance(value, dict): - convert_boolean_to_string(value) - elif isinstance(value, list): - for item in value: - if isinstance(item, dict): - convert_boolean_to_string(item) diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 327b9e234b..b144021bab 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]): _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index fafa205386..368d662a75 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs @@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) + for loop_variable in node_data.loop_variables or []: + if loop_variable.value_type == "variable": + assert loop_variable.value is not None, "Loop variable value must be provided for variable type" + # add loop variable to variable mapping + selector = loop_variable.value + variable_mapping[f"{node_id}.{loop_variable.label}"] = selector + # remove variable out from loop variable_mapping = { key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 5a15f36044..f5e38b7516 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]): _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 1f1be59542..67cc884f20 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var LATEST_VERSION = "latest" +# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. +# Specifically, if you have introduced new node types, you should add them here. +# +# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__` +# hook. Try to avoid duplication of node information. 
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { NodeType.START: { LATEST_VERSION: StartNode, diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 369eb13b04..916778d167 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.llm import ModelConfig, VisionConfig +class _ParameterConfigError(Exception): + pass + + class ParameterConfig(BaseModel): """ Parameter Config. @@ -27,6 +31,19 @@ class ParameterConfig(BaseModel): raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return str(value) + def is_array_type(self) -> bool: + return self.type in ("array[string]", "array[number]", "array[object]") + + def element_type(self) -> Literal["string", "number", "object"]: + if self.type == "array[number]": + return "number" + elif self.type == "array[string]": + return "string" + elif self.type == "array[object]": + return "object" + else: + raise _ParameterConfigError(f"{self.type} is not array type.") + class ParameterExtractorNodeData(BaseNodeData): """ diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2552784762..8d6c2d0a5c 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser +from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( @@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode): } } + @classmethod + def version(cls) -> str: + return "1" + def _run(self): """ Run the node. 
@@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith("array"): + elif parameter.is_array_type(): if isinstance(result[parameter.name], list): - nested_type = parameter.type[6:-1] - transformed_result[parameter.name] = [] + nested_type = parameter.element_type() + assert nested_type is not None + segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) + transformed_result[parameter.name] = segment_value for item in result[parameter.name]: if nested_type == "number": if isinstance(item, int | float): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif isinstance(item, str): try: if "." in item: - transformed_result[parameter.name].append(float(item)) + segment_value.value.append(float(item)) else: - transformed_result[parameter.name].append(int(item)) + segment_value.value.append(int(item)) except ValueError: pass elif nested_type == "string": if isinstance(item, str): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif nested_type == "object": if isinstance(item, dict): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) if parameter.name not in transformed_result: if parameter.type == "number": @@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): - transformed_result[parameter.name] = [] + transformed_result[parameter.name] = build_segment_with_type( + segment_type=SegmentType(parameter.type), value=[] + ) return transformed_result diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 1f50700c7e..a518167cc6 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER + @classmethod + def version(cls): + return "1" + def _run(self): node_data = cast(QuestionClassifierNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 8839aec9d6..5ee9bc331f 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]): _node_data_cls = StartNodeData _node_type = NodeType.START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) system_inputs = self.graph_runtime_state.variable_pool.system_variables @@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]): # Set system variables as node outputs. for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] + outputs = dict(node_inputs) - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 476cf7eee4..ba573074c3 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get variables variables = {} diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index aaecc7b989..aa15d69931 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the tool node @@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]): variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None + assert isinstance(message.meta, File) files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index db3e25b015..96bb3e793a 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping + +from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = 
NodeType.VARIABLE_AGGREGATOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get variables - outputs = {} + outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = {"output": variable.to_object()} + outputs = {"output": variable} inputs = {".".join(selector[1:]): variable.to_object()} break @@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = {"output": variable.to_object()} + outputs[group.group_name] = {"output": variable} inputs[".".join(selector[1:])] = variable.to_object() break diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 8031b57fa8..0d2822233e 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -1,19 +1,55 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, TypeVar -from core.variables import Variable -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from extensions.ext_database import db -from models import ConversationVariable +from pydantic import BaseModel + +from core.variables import Segment +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.types import SegmentType + +# Use double underscore (`__`) prefix for internal variables +# to minimize risk of collision with user-defined variable names. 
+_UPDATED_VARIABLES_KEY = "__updated_variables" -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id +class UpdatedVariable(BaseModel): + name: str + selector: Sequence[str] + value_type: SegmentType + new_value: Any + + +_T = TypeVar("_T", bound=MutableMapping[str, Any]) + + +def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: + if len(selector) < MIN_SELECTORS_LENGTH: + raise Exception("selector too short") + node_id, var_name = selector[:2] + return UpdatedVariable( + name=var_name, + selector=list(selector[:2]), + value_type=seg.value_type, + new_value=seg.value, ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() + + +def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: + m[_UPDATED_VARIABLES_KEY] = updates + return m + + +def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: + updated_values = m.get(_UPDATED_VARIABLES_KEY, None) + if updated_values is None: + return None + result = [] + for items in updated_values: + if isinstance(items, UpdatedVariable): + result.append(items) + elif isinstance(items, dict): + items = UpdatedVariable.model_validate(items) + result.append(items) + else: + raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") + return result diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py new file mode 100644 index 0000000000..8f7a44bb62 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/common/impl.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session + +from core.variables.variables import Variable +from models.engine import db +from models.workflow import ConversationVariable + +from .exc import VariableOperatorNodeError + + +class ConversationVariableUpdaterImpl: + _engine: Engine | None + + def __init__(self, engine: Engine | None = None) -> None: + self._engine = engine + + def _get_engine(self) -> Engine: + if self._engine: + return self._engine + return db.engine + + def update(self, conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(self._get_engine()) as session: + row = session.scalar(stmt) + if not row: + raise VariableOperatorNodeError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self): + pass + + +def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: + return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 835e1d77b5..be5083c9c1 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,4 +1,9 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, TypeAlias + from core.variables import SegmentType, Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from 
core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory +from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode +if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + + +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls = VariableAssignerData _node_type = NodeType.VARIABLE_ASSIGNER + _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + self._conv_var_updater_factory = conv_var_updater_factory + + @classmethod + def version(cls) -> str: + return "1" + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerData, + ) -> Mapping[str, Sequence[str]]: + mapping = {} + assigned_variable_node_id = node_data.assigned_variable_selector[0] + if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector + + selector_key = ".".join(node_data.input_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.input_variable_selector + return mapping def _run(self) -> NodeRunResult: + assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") @@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") # Over write the variable. - self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. 
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: raise VariableOperatorNodeError("conversation_id not found") - common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater = self._conv_var_updater_factory() + conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater.flush() + updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ "value": income_value.to_object(), }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still set `output_variables` as a list to ensure the schema of output is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, ) diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/core/workflow/nodes/variable_assigner/v2/entities.py index 01df33b6d4..d93affcd15 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/entities.py +++ b/api/core/workflow/nodes/variable_assigner/v2/entities.py @@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel): variable_selector: Sequence[str] input_type: InputType operation: Operation + # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: + # + # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. + # 2. For VARIABLE input_type: Initially contains the selector of the source variable. + # 3. During the variable updating procedure: The `value` field is reassigned to hold + # the resolved actual value that will be applied to the target variable. 
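+    # A hypothetical illustration of the cases above:
+    #   - CONSTANT: value = 42 is used directly by the operation.
+    #   - VARIABLE: value = ["node_id", "var_name"] (a selector) when the node data
+    #     is parsed, and is later reassigned to the resolved value of that selector
+    #     before the operation is applied to the target variable.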
value: Any | None = None diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/core/workflow/nodes/variable_assigner/v2/exc.py index b67af6d73c..fd6c304a9a 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/exc.py +++ b/api/core/workflow/nodes/variable_assigner/v2/exc.py @@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError): class ConversationIDNotFoundError(VariableOperatorNodeError): def __init__(self): super().__init__("conversation_id not found") + + +class InvalidDataError(VariableOperatorNodeError): + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 8759a55b34..9292da6f1c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,34 +1,84 @@ import json -from collections.abc import Sequence -from typing import Any, cast +from collections.abc import Callable, Mapping, MutableMapping, Sequence +from typing import Any, TypeAlias, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . import helpers from .constants import EMPTY_VALUE_MAPPING -from .entities import VariableAssignerNodeData +from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( ConversationIDNotFoundError, InputTypeNotSupportedError, + InvalidDataError, InvalidInputValueError, OperationNotSupportedError, VariableNotFoundError, ) +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + + +def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): + selector_node_id = item.variable_selector[0] + if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: + return + selector_str = ".".join(item.variable_selector) + key = f"{node_id}.#{selector_str}#" + mapping[key] = item.variable_selector + + +def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): + # Keep this in sync with the logic in _run methods... 
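+    # For VARIABLE inputs, `item.value` still holds the source selector at this
+    # point (see the NOTE on `VariableOperationItem.value`), so it is exposed in
+    # the mapping under a "<node_id>.#<joined selector>#" key, mirroring the
+    # target mapping built above.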
+ if item.input_type != InputType.VARIABLE: + return + selector = item.value + if not isinstance(selector, list): + raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") + if len(selector) < MIN_SELECTORS_LENGTH: + raise InvalidDataError(f"selector too short, {node_id=}, {item=}") + selector_str = ".".join(selector) + key = f"{node_id}.#{selector_str}#" + mapping[key] = selector + class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_ASSIGNER + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: + return conversation_variable_updater_factory() + + @classmethod + def version(cls) -> str: + return "2" + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, Sequence[str]] = {} + for item in node_data.items: + _target_mapping_from_item(var_mapping, node_id, item) + _source_mapping_from_item(var_mapping, node_id, item) + return var_mapping + def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} @@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): # remove the duplicated items first. updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + conv_var_updater = self._conv_var_updater_factory() # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) @@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): raise ConversationIDNotFoundError else: conversation_id = conversation_id.value - common_helpers.update_conversation_variable( + conv_var_updater.update( conversation_id=cast(str, conversation_id), variable=variable, ) + conv_var_updater.flush() + updated_variables = [ + common_helpers.variable_to_processed_data(selector, seg) + for selector in updated_variable_selectors + if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + ] + process_data = common_helpers.set_updated_variables(process_data, updated_variables) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, + outputs={}, ) def _handle_item( diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py deleted file mode 100644 index 6491042bfe..0000000000 --- a/api/core/workflow/utils/structured_output/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from enum import StrEnum - - -class ResponseFormat(StrEnum): - """Constants for model response formats""" - - JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode. - JSON = "JSON" # model's json mode. some model like claude support this mode. - JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias. - - -class SpecialModelType(StrEnum): - """Constants for identifying model types""" - - GEMINI = "gemini" - OLLAMA = "ollama" diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py deleted file mode 100644 index 06d9b2056e..0000000000 --- a/api/core/workflow/utils/structured_output/prompt.py +++ /dev/null @@ -1,17 +0,0 @@ -STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. 
You could answer questions and output in JSON format. -constraints: - - You must output in JSON format. - - Do not output boolean value, use string type instead. - - Do not output integer or float value, use number type instead. -eg: - Here is the JSON schema: - {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"} - - Here is the user's question: - My name is John Doe and I am 30 years old. - - output: - {"name": "John Doe", "age": 30} -Here is the JSON schema: -{{schema}} -""" # noqa: E501 diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py new file mode 100644 index 0000000000..868868315b --- /dev/null +++ b/api/core/workflow/utils/variable_utils.py @@ -0,0 +1,29 @@ +from core.variables.segments import ObjectSegment, Segment +from core.workflow.entities.variable_pool import VariablePool, VariableValue + + +def append_variables_recursively( + pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment +): + """ + Append variables recursively + :param pool: variable pool to append variables to + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, ObjectSegment): + variable_dict = variable_value.value + elif isinstance(variable_value, dict): + variable_dict = variable_value + else: + return + + for key, value in variable_dict.items(): + # construct new key list + new_key_list = variable_key_list + [key] + append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py new file mode 100644 index 0000000000..1e13871d0a --- /dev/null +++ b/api/core/workflow/variable_loader.py @@ -0,0 +1,84 @@ +import abc +from collections.abc import Mapping, Sequence +from typing import Any, Protocol + +from core.variables import Variable +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils import variable_utils + + +class VariableLoader(Protocol): + """Interface for loading variables based on selectors. + + A `VariableLoader` is responsible for retrieving additional variables required during the execution + of a single node, which are not provided as user inputs. + + NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same + application and share the same `app_id`. However, this interface does not enforce that constraint, + and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of + concern and allow for flexible implementations. + + Implementations of `VariableLoader` should almost always have an `app_id` parameter in + their constructor. + + TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into + `WorkflowService.single_step_run`, we may get rid of this interface. + """ + + @abc.abstractmethod + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + """Load variables based on the provided selectors. If the selectors are empty, + this method should return an empty list. 
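+
+        A hypothetical call: `load_variables([["node_1", "text"], ["sys", "query"]])`
+        returns whichever of those variables the backing store can resolve.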
+ + The order of the returned variables is not guaranteed. If the caller wants to ensure + a specific order, they should sort the returned list themselves. + + :param: selectors: a list of string list, each inner list should have at least two elements: + - the first element is the node ID, + - the second element is the variable name. + :return: a list of Variable objects that match the provided selectors. + """ + pass + + +class _DummyVariableLoader(VariableLoader): + """A dummy implementation of VariableLoader that does not load any variables. + Serves as a placeholder when no variable loading is needed. + """ + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + return [] + + +DUMMY_VARIABLE_LOADER = _DummyVariableLoader() + + +def load_into_variable_pool( + variable_loader: VariableLoader, + variable_pool: VariablePool, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: Mapping[str, Any], +): + # Loading missing variable from draft var here, and set it into + # variable_pool. + variables_to_load: list[list[str]] = [] + for key, selector in variable_mapping.items(): + # NOTE(QuantumGhost): this logic needs to be in sync with + # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. + node_variable_list = key.split(".") + if len(node_variable_list) < 1: + raise ValueError(f"Invalid variable key: {key}. It should have at least one element.") + if key in user_inputs: + continue + node_variable_key = ".".join(node_variable_list[1:]) + if node_variable_key in user_inputs: + continue + if variable_pool.get(selector) is None: + variables_to_load.append(list(selector)) + loaded = variable_loader.load_variables(variables_to_load) + for var in loaded: + assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" + variable_utils.append_variables_recursively( + variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var + ) diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index b88f9edd03..6ee562fc8d 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -92,7 +92,7 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(outputs) + # outputs = WorkflowEntry.handle_special_values(outputs) workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED workflow_execution.outputs = outputs or {} @@ -125,7 +125,7 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) + # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED execution.outputs = outputs or {} @@ -242,9 +242,9 @@ class WorkflowCycleManager: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + inputs = event.inputs + process_data = event.process_data + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -289,7 +289,7 @@ class WorkflowCycleManager: # 
Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -326,7 +326,7 @@ class WorkflowCycleManager: finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings origin_metadata = { diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7648947fca..c0e98db3db 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom from models.workflow import ( @@ -119,7 +120,9 @@ class WorkflowEntry: workflow: Workflow, node_id: str, user_id: str, - user_inputs: dict, + user_inputs: Mapping[str, Any], + variable_pool: VariablePool, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -129,29 +132,14 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - # fetch node info from workflow graph - workflow_graph = workflow.graph_dict - if not workflow_graph: - raise ValueError("workflow graph not found") - - nodes = workflow_graph.get("nodes") - if not nodes: - raise ValueError("nodes not found in workflow graph") - - # fetch node config from node id - try: - node_config = next(filter(lambda node: node["id"] == node_id, nodes)) - except StopIteration: - raise ValueError("node id not found in workflow graph") + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config.get("data", {}) # Get node class - node_type = NodeType(node_config.get("data", {}).get("type")) - node_version = node_config.get("data", {}).get("version", "1") + node_type = NodeType(node_config_data.get("type")) + node_version = node_config_data.get("version", "1") node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - # init variable pool - variable_pool = VariablePool(environment_variables=workflow.environment_variables) - # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -182,16 +170,33 @@ class WorkflowEntry: except NotImplementedError: variable_mapping = {} + # Loading missing variable from draft var here, and set it into + # variable_pool. 
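+        # Only selectors that are neither supplied via `user_inputs` nor already
+        # present in `variable_pool` are fetched through `variable_loader`.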
+ load_into_variable_pool( + variable_loader=variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, variable_pool=variable_pool, tenant_id=workflow.tenant_id, ) + try: # run node generator = node_instance.run() except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + workflow.id, + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) return node_instance, generator @@ -294,10 +299,20 @@ class WorkflowEntry: return node_instance, generator except Exception as e: + logger.exception( + "error while running node_instance, node_id=%s, type=%s, version=%s", + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + # NOTE(QuantumGhost): Avoid using this function in new code. + # Keep values structured as long as possible and only convert to dict + # immediately before serialization (e.g., JSON serialization) to maintain + # data integrity and type information. result = WorkflowEntry._handle_special_values(value) return result if isinstance(result, Mapping) or result is None else dict(result) @@ -324,10 +339,17 @@ class WorkflowEntry: cls, *, variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, + user_inputs: Mapping[str, Any], variable_pool: VariablePool, tenant_id: str, ) -> None: + # NOTE(QuantumGhost): This logic should remain synchronized with + # the implementation of `load_into_variable_pool`, specifically the logic about + # variable existence checking. + + # WARNING(QuantumGhost): The semantics of this method are not clearly defined, + # and multiple parts of the codebase depend on its current behavior. + # Modify with caution. 
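+        # As in `load_into_variable_pool`, each mapping key has the form
+        # "<node_id>.<variable_key>"; only the first dot separates the node id,
+        # and the remainder is treated as the variable key.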
for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable node_variable_list = node_variable.split(".") diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py new file mode 100644 index 0000000000..0123fdac18 --- /dev/null +++ b/api/core/workflow/workflow_type_encoder.py @@ -0,0 +1,49 @@ +import json +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel + +from core.file.models import File +from core.variables import Segment + + +class WorkflowRuntimeTypeEncoder(json.JSONEncoder): + def default(self, o: Any): + if isinstance(o, Segment): + return o.value + elif isinstance(o, File): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump(mode="json") + else: + return super().default(o) + + +class WorkflowRuntimeTypeConverter: + def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + result = self._to_json_encodable_recursive(value) + return result if isinstance(result, Mapping) or result is None else dict(result) + + def _to_json_encodable_recursive(self, value: Any) -> Any: + if value is None: + return value + if isinstance(value, (bool, int, str, float)): + return value + if isinstance(value, Segment): + return self._to_json_encodable_recursive(value.value) + if isinstance(value, File): + return value.to_dict() + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = self._to_json_encodable_recursive(v) + return res + if isinstance(value, list): + res_list = [] + for item in value: + res_list.append(self._to_json_encodable_recursive(item)) + return res_list + return value diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 1d6ad35333..ebc55d5ef8 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle from .create_document_index import handle from .create_installed_app_when_app_created import handle from .create_site_record_when_app_created import handle -from .deduct_quota_when_message_created import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle -from .update_provider_last_used_at_when_message_created import handle + +# Consolidated handler replaces both deduct_quota_when_message_created and +# update_provider_last_used_at_when_message_created +from .update_provider_when_message_created import handle diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py deleted file mode 100644 index b8e7019446..0000000000 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ /dev/null @@ -1,65 +0,0 @@ -from datetime import UTC, datetime - -from configs import dify_config -from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from core.entities.provider_entities import QuotaUnit -from core.plugin.entities.plugin import ModelProviderID -from events.message_event import message_was_created -from extensions.ext_database import db -from models.provider import Provider, ProviderType - - -@message_was_created.connect -def 
handle(sender, **kwargs): - message = sender - application_generate_entity = kwargs.get("application_generate_entity") - - if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): - return - - model_config = application_generate_entity.model_conf - provider_model_bundle = model_config.provider_model_bundle - provider_configuration = provider_model_bundle.configuration - - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - - if not system_configuration.current_quota_type: - return - - quota_unit = None - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.quota_type == system_configuration.current_quota_type: - quota_unit = quota_configuration.quota_unit - - if quota_configuration.quota_limit == -1: - return - - break - - used_quota = None - if quota_unit: - if quota_unit == QuotaUnit.TOKENS: - used_quota = message.message_tokens + message.answer_tokens - elif quota_unit == QuotaUnit.CREDITS: - used_quota = dify_config.get_model_credits(model_config.model) - else: - used_quota = 1 - - if used_quota is not None and system_configuration.current_quota_type is not None: - db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.app_config.tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_config.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ).update( - { - "quota_used": Provider.quota_used + used_quota, - "last_used": datetime.now(tz=UTC).replace(tzinfo=None), - } - ) - db.session.commit() diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py deleted file mode 100644 index 59412cf87c..0000000000 --- a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py +++ /dev/null @@ -1,20 +0,0 @@ -from datetime import UTC, datetime - -from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from events.message_event import message_was_created -from extensions.ext_database import db -from models.provider import Provider - - -@message_was_created.connect -def handle(sender, **kwargs): - application_generate_entity = kwargs.get("application_generate_entity") - - if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): - return - - db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.app_config.tenant_id, - Provider.provider_name == application_generate_entity.model_conf.provider, - ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) - db.session.commit() diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py new file mode 100644 index 0000000000..d3943f2eda --- /dev/null +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -0,0 +1,234 @@ +import logging +import time as time_module +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel +from sqlalchemy import update +from sqlalchemy.orm import Session + +from configs import dify_config +from 
core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity +from core.entities.provider_entities import QuotaUnit, SystemConfiguration +from core.plugin.entities.plugin import ModelProviderID +from events.message_event import message_was_created +from extensions.ext_database import db +from libs import datetime_utils +from models.model import Message +from models.provider import Provider, ProviderType + +logger = logging.getLogger(__name__) + + +class _ProviderUpdateFilters(BaseModel): + """Filters for identifying Provider records to update.""" + + tenant_id: str + provider_name: str + provider_type: Optional[str] = None + quota_type: Optional[str] = None + + +class _ProviderUpdateAdditionalFilters(BaseModel): + """Additional filters for Provider updates.""" + + quota_limit_check: bool = False + + +class _ProviderUpdateValues(BaseModel): + """Values to update in Provider records.""" + + last_used: Optional[datetime] = None + quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression + + +class _ProviderUpdateOperation(BaseModel): + """A single Provider update operation.""" + + filters: _ProviderUpdateFilters + values: _ProviderUpdateValues + additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters() + description: str = "unknown" + + +@message_was_created.connect +def handle(sender: Message, **kwargs): + """ + Consolidated handler for Provider updates when a message is created. + + This handler replaces both: + - update_provider_last_used_at_when_message_created + - deduct_quota_when_message_created + + By performing all Provider updates in a single transaction, we ensure + consistency and efficiency when updating Provider records. + """ + message = sender + application_generate_entity = kwargs.get("application_generate_entity") + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return + + tenant_id = application_generate_entity.app_config.tenant_id + provider_name = application_generate_entity.model_conf.provider + current_time = datetime_utils.naive_utc_now() + + # Prepare updates for both scenarios + updates_to_perform: list[_ProviderUpdateOperation] = [] + + # 1. Always update last_used for the provider + basic_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=provider_name, + ), + values=_ProviderUpdateValues(last_used=current_time), + description="basic_last_used_update", + ) + updates_to_perform.append(basic_update) + + # 2. 
Check if we need to deduct quota (system provider only) + model_config = application_generate_entity.model_conf + provider_model_bundle = model_config.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if ( + provider_configuration.using_provider_type == ProviderType.SYSTEM + and provider_configuration.system_configuration + and provider_configuration.system_configuration.current_quota_type is not None + ): + system_configuration = provider_configuration.system_configuration + + # Calculate quota usage + used_quota = _calculate_quota_usage( + message=message, + system_configuration=system_configuration, + model_name=model_config.model, + ) + + if used_quota is not None: + quota_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=ModelProviderID(model_config.provider).provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=provider_configuration.system_configuration.current_quota_type.value, + ), + values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), + additional_filters=_ProviderUpdateAdditionalFilters( + quota_limit_check=True # Provider.quota_limit > Provider.quota_used + ), + description="quota_deduction_update", + ) + updates_to_perform.append(quota_update) + + # Execute all updates + start_time = time_module.perf_counter() + try: + _execute_provider_updates(updates_to_perform) + + # Log successful completion with timing + duration = time_module.perf_counter() - start_time + + logger.info( + f"Provider updates completed successfully. " + f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, " + f"Tenant: {tenant_id}, Provider: {provider_name}" + ) + + except Exception as e: + # Log failure with timing and context + duration = time_module.perf_counter() - start_time + + logger.exception( + f"Provider updates failed after {duration:.3f}s. 
" + f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, " + f"Provider: {provider_name}" + ) + raise + + +def _calculate_quota_usage( + *, message: Message, system_configuration: SystemConfiguration, model_name: str +) -> Optional[int]: + """Calculate quota usage based on message tokens and quota type.""" + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + if quota_configuration.quota_limit == -1: + return None + break + if quota_unit is None: + return None + + try: + if quota_unit == QuotaUnit.TOKENS: + tokens = message.message_tokens + message.answer_tokens + return tokens + if quota_unit == QuotaUnit.CREDITS: + tokens = dify_config.get_model_credits(model_name) + return tokens + elif quota_unit == QuotaUnit.TIMES: + return 1 + return None + except Exception as e: + logger.exception("Failed to calculate quota usage") + return None + + +def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]): + """Execute all Provider updates in a single transaction.""" + if not updates_to_perform: + return + + # Use SQLAlchemy's context manager for transaction management + # This automatically handles commit/rollback + with Session(db.engine) as session: + # Use a single transaction for all updates + for update_operation in updates_to_perform: + filters = update_operation.filters + values = update_operation.values + additional_filters = update_operation.additional_filters + description = update_operation.description + + # Build the where conditions + where_conditions = [ + Provider.tenant_id == filters.tenant_id, + Provider.provider_name == filters.provider_name, + ] + + # Add additional filters if specified + if filters.provider_type is not None: + where_conditions.append(Provider.provider_type == filters.provider_type) + if filters.quota_type is not None: + where_conditions.append(Provider.quota_type == filters.quota_type) + if additional_filters.quota_limit_check: + where_conditions.append(Provider.quota_limit > Provider.quota_used) + + # Prepare values dict for SQLAlchemy update + update_values = {} + if values.last_used is not None: + update_values["last_used"] = values.last_used + if values.quota_used is not None: + update_values["quota_used"] = values.quota_used + + # Build and execute the update statement + stmt = update(Provider).where(*where_conditions).values(**update_values) + result = session.execute(stmt) + rows_affected = result.rowcount + + logger.debug( + f"Provider update ({description}): {rows_affected} rows affected. " + f"Filters: {filters.model_dump()}, Values: {update_values}" + ) + + # If no rows were affected for quota updates, log a warning + if rows_affected == 0 and description == "quota_deduction_update": + logger.warning( + f"No Provider rows updated for quota deduction. " + f"This may indicate quota limit exceeded or provider not found. 
" + f"Filters: {filters.model_dump()}" + ) + + logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index a837552007..6279b1ad36 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -21,6 +21,7 @@ def init_app(app: DifyApp) -> Celery: "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, "sentinel_kwargs": { "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + "password": dify_config.CELERY_SENTINEL_PASSWORD, }, } diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 52f119936f..25d1390492 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -5,6 +5,7 @@ from typing import Any, cast import httpx from sqlalchemy import select +from sqlalchemy.orm import Session from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers @@ -91,6 +92,8 @@ def build_from_mappings( tenant_id: str, strict_type_validation: bool = False, ) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. files = [ build_from_mapping( mapping=mapping, @@ -377,3 +380,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: def get_file_type_by_mime_type(mime_type: str) -> FileType: return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM + + +class StorageKeyLoader: + """FileKeyLoader load the storage key from database for a list of files. + This loader is batched, the database query count is constant regardless of the input size. + """ + + def __init__(self, session: Session, tenant_id: str) -> None: + self._session = session + self._tenant_id = tenant_id + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + + return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} + + def load_storage_keys(self, files: Sequence[File]): + """Loads storage keys for a sequence of files by retrieving the corresponding + `UploadFile` or `ToolFile` records from the database based on their transfer method. + + This method doesn't modify the input sequence structure but updates the `_storage_key` + property of each file object by extracting the relevant key from its database record. + + Performance note: This is a batched operation where database query count remains constant + regardless of input size. However, for optimal performance, input sequences should contain + fewer than 1000 files. For larger collections, split into smaller batches and process each + batch separately. 
+ """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + related_model_id = file.related_id + if file.related_id is None: + raise ValueError("file id should not be None.") + if file.tenant_id != self._tenant_id: + err_msg = ( + f"invalid file, expected tenant_id={self._tenant_id}, " + f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" + ) + raise ValueError(err_msg) + model_id = uuid.UUID(related_model_id) + + if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + model_id = uuid.UUID(file.related_id) + if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"Upload file not found for id: {model_id}") + file._storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"Tool file not found for id: {model_id}") + file._storage_key = tool_file_row.file_key diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index a41ef4ae4e..250ee4695e 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -43,6 +43,10 @@ class UnsupportedSegmentTypeError(Exception): pass +class TypeMismatchError(Exception): + pass + + # Define the constant SEGMENT_TO_VARIABLE_MAP = { StringSegment: StringVariable, @@ -110,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen return cast(Variable, result) +def infer_segment_type_from_value(value: Any, /) -> SegmentType: + return build_segment(value).value_type + + def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() @@ -140,10 +148,80 @@ def build_segment(value: Any, /) -> Segment: case SegmentType.NONE: return ArrayAnySegment(value=value) case _: + # This should be unreachable. raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") +def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: + """ + Build a segment with explicit type checking. + + This function creates a segment from a value while enforcing type compatibility + with the specified segment_type. It provides stricter type validation compared + to the standard build_segment function. 
+ + Args: + segment_type: The expected SegmentType for the resulting segment + value: The value to be converted into a segment + + Returns: + Segment: A segment instance of the appropriate type + + Raises: + TypeMismatchError: If the value type doesn't match the expected segment_type + + Special Cases: + - For empty list [] values, if segment_type is array[*], returns the corresponding array type + - Type validation is performed before segment creation + + Examples: + >>> build_segment_with_type(SegmentType.STRING, "hello") + StringSegment(value="hello") + + >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) + ArrayStringSegment(value=[]) + + >>> build_segment_with_type(SegmentType.STRING, 123) + # Raises TypeMismatchError + """ + # Handle None values + if value is None: + if segment_type == SegmentType.NONE: + return NoneSegment() + else: + raise TypeMismatchError(f"Expected {segment_type}, but got None") + + # Handle empty list special case for array types + if isinstance(value, list) and len(value) == 0: + if segment_type == SegmentType.ARRAY_ANY: + return ArrayAnySegment(value=value) + elif segment_type == SegmentType.ARRAY_STRING: + return ArrayStringSegment(value=value) + elif segment_type == SegmentType.ARRAY_NUMBER: + return ArrayNumberSegment(value=value) + elif segment_type == SegmentType.ARRAY_OBJECT: + return ArrayObjectSegment(value=value) + elif segment_type == SegmentType.ARRAY_FILE: + return ArrayFileSegment(value=value) + else: + raise TypeMismatchError(f"Expected {segment_type}, but got empty list") + + # Build segment using existing logic to infer actual type + inferred_segment = build_segment(value) + inferred_type = inferred_segment.value_type + + # Type compatibility checking + if inferred_type == segment_type: + return inferred_segment + + # Type mismatch - raise error with descriptive message + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but value '{value}' " + f"(type: {type(value).__name__}) corresponds to {inferred_type}" + ) + + def segment_to_variable( *, segment: Segment, diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py new file mode 100644 index 0000000000..e576a34629 --- /dev/null +++ b/api/libs/datetime_utils.py @@ -0,0 +1,22 @@ +import abc +import datetime +from typing import Protocol + + +class _NowFunction(Protocol): + @abc.abstractmethod + def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: + pass + + +# _now_func is a callable with the _NowFunction signature. +# Its sole purpose is to abstract time retrieval, enabling +# developers to mock this behavior in tests and time-dependent scenarios. +_now_func: _NowFunction = datetime.datetime.now + + +def naive_utc_now() -> datetime.datetime: + """Return a naive datetime object (without timezone information) + representing current UTC time. 
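+
+    Illustrative property (assuming `_now_func` has not been mocked)::
+
+        >>> naive_utc_now().tzinfo is None
+        True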
+ """ + return _now_func(datetime.UTC).replace(tzinfo=None) diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py new file mode 100644 index 0000000000..fa29671034 --- /dev/null +++ b/api/libs/jsonutil.py @@ -0,0 +1,11 @@ +import json + +from pydantic import BaseModel + + +class PydanticModelEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, BaseModel): + return o.model_dump() + else: + super().default(o) diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 35561f071c..b94386660e 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -22,7 +22,11 @@ class SMTPClient: if self.use_tls: if self.opportunistic_tls: smtp = smtplib.SMTP(self.server, self.port, timeout=10) + # Send EHLO command with the HELO domain name as the server address + smtp.ehlo(self.server) smtp.starttls() + # Resend EHLO command to identify the TLS session + smtp.ehlo(self.server) else: smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: diff --git a/api/models/_workflow_exc.py b/api/models/_workflow_exc.py new file mode 100644 index 0000000000..f6271bda47 --- /dev/null +++ b/api/models/_workflow_exc.py @@ -0,0 +1,20 @@ +"""All these exceptions are not meant to be caught by callers.""" + + +class WorkflowDataError(Exception): + """Base class for all workflow data related exceptions. + + This should be used to indicate issues with workflow data integrity, such as + no `graph` configuration, missing `nodes` field in `graph` configuration, or + similar issues. + """ + + pass + + +class NodeNotFoundError(WorkflowDataError): + """Raised when a node with the specified ID is not found in the workflow.""" + + def __init__(self, node_id: str): + super().__init__(f"Node with ID '{node_id}' not found in the workflow.") + self.node_id = node_id diff --git a/api/models/model.py b/api/models/model.py index fa83baa9cf..ce5f449f87 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -611,6 +611,14 @@ class InstalledApp(Base): return tenant +class ConversationSource(StrEnum): + """This enumeration is designed for use with `Conversation.from_source`.""" + + # NOTE(QuantumGhost): The enumeration members may not cover all possible cases. + API = "api" + CONSOLE = "console" + + class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( @@ -632,7 +640,14 @@ class Conversation(Base): system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) + + # The `invoke_from` records how the conversation is created. + # + # Its value corresponds to the members of `InvokeFrom`. + # (api/core/app/entities/app_invoke_entities.py) invoke_from = db.Column(db.String(255), nullable=True) + + # ref: ConversationSource. 
from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) @@ -703,7 +718,6 @@ class Conversation(Base): if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) - assert app_model_config is not None, "app model config not found" model_config = app_model_config.to_dict() else: model_config["configs"] = override_model_configs @@ -817,7 +831,12 @@ class Conversation(Base): @property def first_message(self): - return db.session.query(Message).filter(Message.conversation_id == self.id).first() + return ( + db.session.query(Message) + .filter(Message.conversation_id == self.id) + .order_by(Message.created_at.asc()) + .first() + ) @property def app(self): @@ -894,11 +913,11 @@ class Message(Base): _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) query: Mapped[str] = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) - message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) answer: Mapped[str] = db.Column(db.Text, nullable=False) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) parent_message_id = db.Column(StringUUID, nullable=True) diff --git a/api/models/workflow.py b/api/models/workflow.py index 1733dec0fc..7f01135af3 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,10 +7,16 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 from flask_login import current_user +from sqlalchemy import orm +from core.file.constants import maybe_file_object +from core.file.models import File from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment +from core.workflow.nodes.enums import NodeType +from factories.variable_factory import TypeMismatchError, build_segment_with_type + +from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode @@ -72,6 +78,10 @@ class WorkflowType(Enum): return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT +class _InvalidGraphDefinitionError(Exception): + pass + + class Workflow(Base): """ Workflow, for `Workflow App` and `Chat App workflow mode`. @@ -136,6 +146,8 @@ class Workflow(Base): "conversation_variables", db.Text, nullable=False, server_default="{}" ) + VERSION_DRAFT = "draft" + @classmethod def new( cls, @@ -179,8 +191,72 @@ class Workflow(Base): @property def graph_dict(self) -> Mapping[str, Any]: + # TODO(QuantumGhost): Consider caching `graph_dict` to avoid repeated JSON decoding. + # + # Using `functools.cached_property` could help, but some code in the codebase may + # modify the returned dict, which can cause issues elsewhere. 
+ # + # For example, changing this property to a cached property led to errors like the + # following when single stepping an `Iteration` node: + # + # Root node id 1748401971780start not found in the graph + # + # There is currently no standard way to make a dict deeply immutable in Python, + # and tracking modifications to the returned dict is difficult. For now, we leave + # the code as-is to avoid these issues. + # + # Currently, the following functions / methods would mutate the returned dict: + # + # - `_get_graph_and_variable_pool_of_single_iteration`. + # - `_get_graph_and_variable_pool_of_single_loop`. return json.loads(self.graph) if self.graph else {} + def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + """Extract a node configuration from the workflow graph by node ID. + A node configuration is a dictionary containing the node's properties, including + the node's id, title, and its data as a dict. + """ + workflow_graph = self.graph_dict + + if not workflow_graph: + raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}") + + nodes = workflow_graph.get("nodes") + if not nodes: + raise WorkflowDataError("nodes not found in workflow graph") + + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: + raise NodeNotFoundError(node_id) + assert isinstance(node_config, dict) + return node_config + + @staticmethod + def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: + """Extract type of a node from the node configuration returned by `get_node_config_by_id`.""" + node_config_data = node_config.get("data", {}) + # Get node class + node_type = NodeType(node_config_data.get("type")) + return node_type + + @staticmethod + def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None: + in_loop = node_config.get("isInLoop", False) + in_iteration = node_config.get("isInIteration", False) + if in_loop: + loop_id = node_config.get("loop_id") + if loop_id is None: + raise _InvalidGraphDefinitionError("invalid graph") + return NodeType.LOOP, loop_id + elif in_iteration: + iteration_id = node_config.get("iteration_id") + if iteration_id is None: + raise _InvalidGraphDefinitionError("invalid graph") + return NodeType.ITERATION, iteration_id + else: + return None + @property def features(self) -> str: """ @@ -376,6 +452,10 @@ class Workflow(Base): ensure_ascii=False, ) + @staticmethod + def version_from_datetime(d: datetime) -> str: + return str(d) + class WorkflowRun(Base): """ @@ -835,8 +915,18 @@ def _naive_utc_datetime(): class WorkflowDraftVariable(Base): + """`WorkflowDraftVariable` record variables and outputs generated during + debugging worfklow or chatflow. + + IMPORTANT: This model maintains multiple invariant rules that must be preserved. + Do not instantiate this class directly with the constructor. + + Instead, use the factory methods (`new_conversation_variable`, `new_sys_variable`, + `new_node_variable`) defined below to ensure all invariants are properly maintained. + """ + @staticmethod - def unique_columns() -> list[str]: + def unique_app_id_node_id_name() -> list[str]: return [ "app_id", "node_id", @@ -844,7 +934,9 @@ class WorkflowDraftVariable(Base): ] __tablename__ = "workflow_draft_variables" - __table_args__ = (UniqueConstraint(*unique_columns()),) + __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),) + # Required for instance variable annotation. 
+ __allow_unmapped__ = True # id is the unique identifier of a draft variable. id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) @@ -925,6 +1017,36 @@ class WorkflowDraftVariable(Base): default=None, ) + # Cache for deserialized value + # + # NOTE(QuantumGhost): This field serves two purposes: + # + # 1. Caches deserialized values to reduce repeated parsing costs + # 2. Allows modification of the deserialized value after retrieval, + # particularly important for `File`` variables which require database + # lookups to obtain storage_key and other metadata + # + # Use double underscore prefix for better encapsulation, + # making this attribute harder to access from outside the class. + __value: Segment | None + + def __init__(self, *args, **kwargs): + """ + The constructor of `WorkflowDraftVariable` is not intended for + direct use outside this file. Its solo purpose is setup private state + used by the model instance. + + Please use the factory methods + (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`) + defined below to create instances of this class. + """ + super().__init__(*args, **kwargs) + self.__value = None + + @orm.reconstructor + def _init_on_load(self): + self.__value = None + def get_selector(self) -> list[str]: selector = json.loads(self.selector) if not isinstance(selector, list): @@ -939,15 +1061,92 @@ class WorkflowDraftVariable(Base): def _set_selector(self, value: list[str]): self.selector = json.dumps(value) - def get_value(self) -> Segment | None: - return build_segment(json.loads(self.value)) + def _loads_value(self) -> Segment: + value = json.loads(self.value) + return self.build_segment_with_type(self.value_type, value) + + @staticmethod + def rebuild_file_types(value: Any) -> Any: + # NOTE(QuantumGhost): Temporary workaround for structured data handling. + # By this point, `output` has been converted to dict by + # `WorkflowEntry.handle_special_values`, so we need to + # reconstruct File objects from their serialized form + # to maintain proper variable saving behavior. + # + # Ideally, we should work with structured data objects directly + # rather than their serialized forms. + # However, multiple components in the codebase depend on + # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. + if isinstance(value, dict): + if not maybe_file_object(value): + return value + return File.model_validate(value) + elif isinstance(value, list) and value: + first = value[0] + if not maybe_file_object(first): + return value + return [File.model_validate(i) for i in value] + else: + return value + + @classmethod + def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: + # Extends `variable_factory.build_segment_with_type` functionality by + # reconstructing `FileSegment`` or `ArrayFileSegment`` objects from + # their serialized dictionary or list representations, respectively. 
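+        # For example, a FILE value may arrive either as a `File` instance or as its
+        # serialized dict form (see `rebuild_file_types` above); both are accepted.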
+ if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + def get_value(self) -> Segment: + """Decode the serialized value into its corresponding `Segment` object. + + This method caches the result, so repeated calls will return the same + object instance without re-parsing the serialized data. + + If you need to modify the returned `Segment`, use `value.model_copy()` + to create a copy first to avoid affecting the cached instance. + + For more information about the caching mechanism, see the documentation + of the `__value` field. + + Returns: + Segment: The deserialized value as a Segment object. + """ + + if self.__value is not None: + return self.__value + value = self._loads_value() + self.__value = value + return value def set_name(self, name: str): self.name = name self._set_selector([self.node_id, name]) def set_value(self, value: Segment): - self.value = json.dumps(value.value) + """Updates the `value` and corresponding `value_type` fields in the database model. + + This method also stores the provided Segment object in the deserialized cache + without creating a copy, allowing for efficient value access. + + Args: + value: The Segment object to store as the variable's value. 
+ """ + self.__value = value + self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) self.value_type = value.value_type def get_node_id(self) -> str | None: @@ -973,6 +1172,7 @@ class WorkflowDraftVariable(Base): node_id: str, name: str, value: Segment, + node_execution_id: str | None, description: str = "", ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() @@ -984,6 +1184,7 @@ class WorkflowDraftVariable(Base): variable.name = name variable.set_value(value) variable._set_selector(list(variable_utils.to_selector(node_id, name))) + variable.node_execution_id = node_execution_id return variable @classmethod @@ -993,13 +1194,17 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + description: str = "", ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, value=value, + description=description, + node_execution_id=None, ) + variable.editable = True return variable @classmethod @@ -1009,9 +1214,16 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + node_execution_id: str, editable: bool = False, ) -> "WorkflowDraftVariable": - variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) + variable = cls._new( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=name, + node_execution_id=node_execution_id, + value=value, + ) variable.editable = editable return variable @@ -1023,11 +1235,19 @@ class WorkflowDraftVariable(Base): node_id: str, name: str, value: Segment, + node_execution_id: str, visible: bool = True, + editable: bool = True, ) -> "WorkflowDraftVariable": - variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) + variable = cls._new( + app_id=app_id, + node_id=node_id, + name=name, + node_execution_id=node_execution_id, + value=value, + ) variable.visible = visible - variable.editable = True + variable.editable = editable return variable @property diff --git a/api/pyproject.toml b/api/pyproject.toml index 38cc9ae75d..4251008053 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -149,11 +149,13 @@ dev = [ "types-ujson>=5.10.0", "boto3-stubs>=1.38.20", "types-jmespath>=1.0.2.20240106", + "hypothesis>=6.131.15", "types_pyOpenSSL>=24.1.0", "types_cffi>=1.17.0", "types_setuptools>=80.9.0", "pandas-stubs~=2.2.3", "scipy-stubs>=1.15.3.0", + "types-python-http-client>=3.3.7.20240910", ] ############################################################ @@ -196,7 +198,7 @@ vdb = [ "pymochow==1.3.1", "pyobvector~=0.1.6", "qdrant-client==1.9.0", - "tablestore==6.1.0", + "tablestore==6.2.0", "tcvectordb~=1.6.4", "tidb-vector==0.0.9", "upstash-vector==0.6.0", diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 1b026acfd6..20257fa345 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,6 +32,7 @@ from models import Account, App, AppMode from models.model import AppModelConfig from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.workflow_draft_variable_service import WorkflowDraftVariableService from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -292,6 +293,8 @@ class AppDslService: dependencies=check_dependencies_pending_data, ) + draft_var_srv = WorkflowDraftVariableService(session=self._session) + draft_var_srv.delete_workflow_variables(app_id=app.id) return Import( id=import_id, status=status, diff 
--git a/api/services/dataset_service.py b/api/services/dataset_service.py index 49ca98624a..e42b5ace75 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -278,176 +278,351 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model_provider, + model_type=ModelType.RERANK, + model=reranking_model, + ) + except LLMBadRequestError: + raise ValueError( + "No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + @staticmethod def update_dataset(dataset_id, data, user): + """ + Update dataset configuration and settings. + + Args: + dataset_id: The unique identifier of the dataset to update + data: Dictionary containing the update data + user: The user performing the update operation + + Returns: + Dataset: The updated dataset object + + Raises: + ValueError: If dataset not found or validation fails + NoPermissionError: If user lacks permission to update the dataset + """ + # Retrieve and validate dataset existence dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") + # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) + + # Handle external dataset updates if dataset.provider == "external": - external_retrieval_model = data.get("external_retrieval_model", None) - if external_retrieval_model: - dataset.retrieval_model = external_retrieval_model - dataset.name = data.get("name", dataset.name) - dataset.description = data.get("description", "") - permission = data.get("permission") - if permission: - dataset.permission = permission - external_knowledge_id = data.get("external_knowledge_id", None) - db.session.add(dataset) - if not external_knowledge_id: - raise ValueError("External knowledge id is required.") - external_knowledge_api_id = data.get("external_knowledge_api_id", None) - if not external_knowledge_api_id: - raise ValueError("External knowledge api id is required.") - - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() - ) - - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") - - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id - external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) - db.session.commit() + return DatasetService._update_external_dataset(dataset, data, user) else: - data.pop("partial_member_list", None) - data.pop("external_knowledge_api_id", None) - data.pop("external_knowledge_id", None) - data.pop("external_retrieval_model", None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} - action = None - if dataset.indexing_technique != data["indexing_technique"]: - # if update indexing_technique - if data["indexing_technique"] == "economy": - action = "remove" - 
filtered_data["embedding_model"] = None - filtered_data["embedding_model_provider"] = None - filtered_data["collection_binding_id"] = None - elif data["indexing_technique"] == "high_quality": - action = "add" - # get embedding model setting - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - else: - # add default plugin id to both setting sets, to make sure the plugin model provider is consistent - # Skip embedding model checks if not provided in the update request - if ( - "embedding_model_provider" not in data - or "embedding_model" not in data - or not data.get("embedding_model_provider") - or not data.get("embedding_model") - ): - # If the dataset already has embedding model settings, use those - if dataset.embedding_model_provider and dataset.embedding_model: - # Keep existing values - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - # If collection_binding_id exists, keep it too - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Otherwise, don't try to update embedding model settings at all - # Remove these fields from filtered_data if they exist but are None/empty - if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: - del filtered_data["embedding_model_provider"] - if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: - del filtered_data["embedding_model"] - else: - skip_embedding_update = False - try: - # Handle existing model provider - plugin_model_provider = dataset.embedding_model_provider - plugin_model_provider_str = None - if plugin_model_provider: - plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + return DatasetService._update_internal_dataset(dataset, data, user) - # Handle new model provider from request - new_plugin_model_provider = data["embedding_model_provider"] - new_plugin_model_provider_str = None - if new_plugin_model_provider: - new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + @staticmethod + def _update_external_dataset(dataset, data, user): + """ + Update external dataset configuration. 
- # Only update embedding model if both values are provided and different from current - if ( - plugin_model_provider_str != new_plugin_model_provider_str - or data["embedding_model"] != dataset.embedding_model - ): - action = "update" - model_manager = ModelManager() - try: - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - except ProviderTokenNotInitError: - # If we can't get the embedding model, skip updating it - # and keep the existing settings if available - if dataset.embedding_model_provider and dataset.embedding_model: - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Skip the rest of the embedding model update - skip_embedding_update = True - if not skip_embedding_update: - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update - filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + Returns: + Dataset: Updated dataset object + """ + # Update retrieval model if provided + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model - # update Retrieval model - filtered_data["retrieval_model"] = data["retrieval_model"] + # Update basic dataset properties + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", dataset.description) - db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) + # Update permission if provided + permission = data.get("permission") + if permission: + dataset.permission = permission + + # Validate and update external knowledge configuration + external_knowledge_id = data.get("external_knowledge_id", None) + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + # Update metadata fields + dataset.updated_by = user.id if user else None + dataset.updated_at = datetime.datetime.utcnow() + db.session.add(dataset) + + # Update external knowledge binding + DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) + + # Commit changes to database + db.session.commit() - db.session.commit() - if action: - deal_dataset_vector_index_task.delay(dataset_id, action) return dataset + @staticmethod + def _update_external_knowledge_binding(dataset_id, external_knowledge_id, 
external_knowledge_api_id): + """ + Update external knowledge binding configuration. + + Args: + dataset_id: Dataset identifier + external_knowledge_id: External knowledge identifier + external_knowledge_api_id: External knowledge API identifier + """ + with Session(db.engine) as session: + external_knowledge_binding = ( + session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + ) + + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") + + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + + @staticmethod + def _update_internal_dataset(dataset, data, user): + """ + Update internal dataset configuration. + + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update + + Returns: + Dataset: Updated dataset object + """ + # Remove external-specific fields from update data + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + + # Filter out None values except for description field + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + + # Handle indexing technique changes and embedding model updates + action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) + + # Add metadata fields + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] + + # Update dataset in database + db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.commit() + + # Trigger vector index task if indexing technique changed + if action: + deal_dataset_vector_index_task.delay(dataset.id, action) + + return dataset + + @staticmethod + def _handle_indexing_technique_change(dataset, data, filtered_data): + """ + Handle changes in indexing technique and configure embedding models accordingly. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data + + Returns: + str: Action to perform ('add', 'remove', 'update', or None) + """ + if dataset.indexing_technique != data["indexing_technique"]: + if data["indexing_technique"] == "economy": + # Remove embedding model configuration for economy mode + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + return "remove" + elif data["indexing_technique"] == "high_quality": + # Configure embedding model for high quality mode + DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) + return "add" + else: + # Handle embedding model updates when indexing technique remains the same + return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data) + return None + + @staticmethod + def _configure_embedding_model_for_high_quality(data, filtered_data): + """ + Configure embedding model settings for high quality indexing. 
+ + Args: + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + + @staticmethod + def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): + """ + Handle embedding model updates when indexing technique remains the same. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + # Skip embedding model checks if not provided in the update request + if ( + "embedding_model_provider" not in data + or "embedding_model" not in data + or not data.get("embedding_model_provider") + or not data.get("embedding_model") + ): + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + return None + else: + return DatasetService._update_embedding_model_settings(dataset, data, filtered_data) + + @staticmethod + def _preserve_existing_embedding_settings(dataset, filtered_data): + """ + Preserve existing embedding model settings when not provided in update. + + Args: + dataset: Current dataset object + filtered_data: Filtered update data to modify + """ + # If the dataset already has embedding model settings, use those + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + # If collection_binding_id exists, keep it too + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Otherwise, don't try to update embedding model settings at all + # Remove these fields from filtered_data if they exist but are None/empty + if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: + del filtered_data["embedding_model_provider"] + if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: + del filtered_data["embedding_model"] + + @staticmethod + def _update_embedding_model_settings(dataset, data, filtered_data): + """ + Update embedding model settings with new values. 
+ + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + try: + # Compare current and new model provider settings + current_provider_str = ( + str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None + ) + new_provider_str = ( + str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None + ) + + # Only update if values are different + if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: + DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) + return "update" + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + return None + + @staticmethod + def _apply_new_embedding_settings(dataset, data, filtered_data): + """ + Apply new embedding model settings to the dataset. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, preserve existing settings + logging.warning( + f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, " + f"preserving existing settings" + ) + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Skip the rest of the embedding model update + return + + # Apply new embedding model settings + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) @@ -2049,6 +2224,7 @@ class SegmentService: # calc embedding use tokens if document.doc_form == "qa_model": + segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] else: tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 87e9e9247d..5d348c61be 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -4,3 +4,7 @@ class MoreLikeThisDisabledError(Exception): class WorkflowHashNotEqualError(Exception): pass + + +class IsDraftWorkflowError(Exception): + pass diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 26d6d4ce18..cfcb121153 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -19,6 +19,10 @@ from 
services.entities.knowledge_entities.knowledge_entities import ( class MetadataService: @staticmethod def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: + # check if metadata name is too long + if len(metadata_args.name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + # check if metadata name already exists if ( db.session.query(DatasetMetadata) @@ -42,6 +46,10 @@ class MetadataService: @staticmethod def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore + # check if metadata name is too long + if len(name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists if ( diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py deleted file mode 100644 index 082afeed89..0000000000 --- a/api/services/moderation_service.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Optional - -from core.moderation.factory import ModerationFactory, ModerationOutputsResult -from extensions.ext_database import db -from models.model import App, AppModelConfig - - -class ModerationService: - def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: - app_model_config: Optional[AppModelConfig] = None - - app_model_config = ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() - ) - - if not app_model_config: - raise ValueError("app model config not found") - - name = app_model_config.sensitive_word_avoidance_dict["type"] - config = app_model_config.sensitive_word_avoidance_dict["config"] - - moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) - return moderation.moderation_for_outputs(text) diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419b..b84dd0afc5 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,53 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + __KEY_PREFIX__ = "oauth_proxy_context:" + + @staticmethod + def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `use_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. 
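+
+ A rough usage sketch (illustrative only, not a complete handler):
+
+     context_id = OAuthProxyService.create_proxy_context(user_id, tenant_id, plugin_id, provider)
+     # ...redirect the user to the provider's authorization URL, carrying context_id...
+     data = OAuthProxyService.use_proxy_context(context_id)  # single use; expires after __MAX_AGE__ seconds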
+ """ + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + } + redis_client.setex( + f"{OAuthProxyService.__KEY_PREFIX__}{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id + + @staticmethod + def use_proxy_context(context_id: str): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. + """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}") + if not data: + raise ValueError("context_id is invalid") + return json.loads(data) diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py new file mode 100644 index 0000000000..393213c0e2 --- /dev/null +++ b/api/services/plugin/plugin_parameter_service.py @@ -0,0 +1,74 @@ +from collections.abc import Mapping, Sequence +from typing import Any, Literal + +from sqlalchemy.orm import Session + +from core.plugin.entities.parameters import PluginParameterOption +from core.plugin.impl.dynamic_select import DynamicSelectClient +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ProviderConfigEncrypter +from extensions.ext_database import db +from models.tools import BuiltinToolProvider + + +class PluginParameterService: + @staticmethod + def get_dynamic_select_options( + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + action: str, + parameter: str, + provider_type: Literal["tool"], + ) -> Sequence[PluginParameterOption]: + """ + Get dynamic select options for a plugin parameter. + + Args: + tenant_id: The tenant ID. + plugin_id: The plugin ID. + provider: The provider name. + action: The action name. + parameter: The parameter name. 
+ """ + credentials: Mapping[str, Any] = {} + + match provider_type: + case "tool": + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + # init tool configuration + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # check if credentials are required + if not provider_controller.need_credentials: + credentials = {} + else: + # fetch credentials from db + with Session(db.engine) as session: + db_record = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) + + if db_record is None: + raise ValueError(f"Builtin provider {provider} not found when fetching credentials") + + credentials = tool_configuration.decrypt(db_record.credentials) + case _: + raise ValueError(f"Invalid provider type: {provider_type}") + + return ( + DynamicSelectClient() + .fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter) + .options + ) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 19e37f4ee3..9165139193 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -97,16 +97,16 @@ class VectorService: vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) vector.add_texts([document], duplicate_check=True) - - # update keyword index - keyword = Keyword(dataset) - keyword.delete_by_ids([segment.index_node_id]) - - # save keyword index - if keywords and len(keywords) > 0: - keyword.add_texts([document], keywords_list=[keywords]) else: - keyword.add_texts([document]) + # update keyword index + keyword = Keyword(dataset) + keyword.delete_by_ids([segment.index_node_id]) + + # save keyword index + if keywords and len(keywords) > 0: + keyword.add_texts([document], keywords_list=[keywords]) + else: + keyword.add_texts([document]) @classmethod def generate_child_chunks( diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py new file mode 100644 index 0000000000..164693c2e1 --- /dev/null +++ b/api/services/workflow_draft_variable_service.py @@ -0,0 +1,722 @@ +import dataclasses +import datetime +import logging +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, ClassVar + +from sqlalchemy import Engine, orm, select +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import and_, or_ + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File +from core.variables import Segment, StringSegment, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.segments import ArrayFileSegment, FileSegment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes import NodeType +from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables +from core.workflow.variable_loader import VariableLoader +from factories.file_factory import StorageKeyLoader +from factories.variable_factory import build_segment, 
segment_to_variable +from models import App, Conversation +from models.enums import DraftVariableType +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable + +_logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class WorkflowDraftVariableList: + variables: list[WorkflowDraftVariable] + total: int | None = None + + +class WorkflowDraftVariableError(Exception): + pass + + +class VariableResetError(WorkflowDraftVariableError): + pass + + +class UpdateNotSupportedError(WorkflowDraftVariableError): + pass + + +class DraftVarLoader(VariableLoader): + # This implements the VariableLoader interface for loading draft variables. + # + # ref: core.workflow.variable_loader.VariableLoader + + # Database engine used for loading variables. + _engine: Engine + # Application ID for which variables are being loaded. + _app_id: str + _tenant_id: str + _fallback_variables: Sequence[Variable] + + def __init__( + self, + engine: Engine, + app_id: str, + tenant_id: str, + fallback_variables: Sequence[Variable] | None = None, + ) -> None: + self._engine = engine + self._app_id = app_id + self._tenant_id = tenant_id + self._fallback_variables = fallback_variables or [] + + def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: + return (selector[0], selector[1]) + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + if not selectors: + return [] + + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. + variable_by_selector: dict[tuple[str, str], Variable] = {} + + with Session(bind=self._engine, expire_on_commit=False) as session: + srv = WorkflowDraftVariableService(session) + draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) + + for draft_var in draft_vars: + segment = draft_var.get_value() + variable = segment_to_variable( + segment=segment, + selector=draft_var.get_selector(), + id=draft_var.id, + name=draft_var.name, + description=draft_var.description, + ) + selector_tuple = self._selector_to_tuple(variable.selector) + variable_by_selector[selector_tuple] = variable + + # Important: + files: list[File] = [] + for draft_var in draft_vars: + value = draft_var.get_value() + if isinstance(value, FileSegment): + files.append(value.value) + elif isinstance(value, ArrayFileSegment): + files.extend(value.value) + with Session(bind=self._engine) as session: + storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader.load_storage_keys(files) + + return list(variable_by_selector.values()) + + +class WorkflowDraftVariableService: + _session: Session + + def __init__(self, session: Session) -> None: + self._session = session + + def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: + return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + + def get_draft_variables_by_selectors( + self, + app_id: str, + selectors: Sequence[list[str]], + ) -> list[WorkflowDraftVariable]: + ors = [] + for selector in selectors: + assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + node_id, name = selector[:2] + ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) + + # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as + # each expression includes conditions on both `node_id` and `name` (which are covered by the unique 
index), + # PostgreSQL can efficiently retrieve the results using a bitmap index scan. + # + # Alternatively, a `SELECT` statement could be constructed for each selector and + # combined using `UNION` to fetch all rows. + # Benchmarking indicates that both approaches yield comparable performance. + variables = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all() + ) + return variables + + def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: + criteria = WorkflowDraftVariable.app_id == app_id + total = None + query = self._session.query(WorkflowDraftVariable).filter(criteria) + if page == 1: + total = query.count() + variables = ( + # Do not load the `value` field. + query.options(orm.defer(WorkflowDraftVariable.value)) + .order_by(WorkflowDraftVariable.id.desc()) + .limit(limit) + .offset((page - 1) * limit) + .all() + ) + + return WorkflowDraftVariableList(variables=variables, total=total) + + def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: + criteria = ( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + ) + query = self._session.query(WorkflowDraftVariable).filter(*criteria) + variables = query.order_by(WorkflowDraftVariable.id.desc()).all() + return WorkflowDraftVariableList(variables=variables) + + def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, node_id) + + def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID) + + def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID) + + def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name) + + def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name) + + def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id, node_id, name) + + def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: + variable = ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + WorkflowDraftVariable.name == name, + ) + .first() + ) + return variable + + def update_variable( + self, + variable: WorkflowDraftVariable, + name: str | None = None, + value: Segment | None = None, + ) -> WorkflowDraftVariable: + if not variable.editable: + raise UpdateNotSupportedError(f"variable not support updating, id={variable.id}") + if name is not None: + variable.set_name(name) + if value is not None: + variable.set_value(value) + variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + self._session.flush() + return variable + + def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + conv_var_by_name = {i.name: i for i in workflow.conversation_variables} + conv_var = conv_var_by_name.get(variable.name) + + if conv_var is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning( + 
"Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name + ) + return None + + variable.set_value(conv_var) + variable.last_edited_at = None + self._session.add(variable) + self._session.flush() + return variable + + def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + # If a variable does not allow updating, it makes no sence to resetting it. + if not variable.editable: + return variable + # No execution record for this variable, delete the variable instead. + if variable.node_execution_id is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) + return None + + query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) + node_exec = self._session.scalars(query).first() + if node_exec is None: + _logger.warning( + "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", + variable.id, + variable.name, + variable.node_execution_id, + ) + self._session.delete(instance=variable) + self._session.flush() + return None + + # Get node type for proper value extraction + node_config = workflow.get_node_config_by_id(variable.node_id) + node_type = workflow.get_node_type_from_node_config(node_config) + + outputs_dict = node_exec.outputs_dict or {} + + # Note: Based on the implementation in `_build_from_variable_assigner_mapping`, + # VariableAssignerNode (both v1 and v2) can only create conversation draft variables. + # For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes. + # + # This implementation must remain synchronized with the `_build_from_variable_assigner_mapping` + # and `save` methods. 
+ if node_type == NodeType.VARIABLE_ASSIGNER: + return variable + + if variable.name not in outputs_dict: + # If variable not found in execution data, delete the variable + self._session.delete(instance=variable) + self._session.flush() + return None + value = outputs_dict[variable.name] + value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, value) + # Extract variable value using unified logic + variable.set_value(value_seg) + variable.last_edited_at = None # Reset to indicate this is a reset operation + self._session.flush() + return variable + + def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.CONVERSATION: + return self._reset_conv_var(workflow, variable) + elif variable_type == DraftVariableType.NODE: + return self._reset_node_var(workflow, variable) + else: + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") + + def delete_variable(self, variable: WorkflowDraftVariable): + self._session.delete(variable) + + def delete_workflow_variables(self, app_id: str): + ( + self._session.query(WorkflowDraftVariable) + .filter(WorkflowDraftVariable.app_id == app_id) + .delete(synchronize_session=False) + ) + + def delete_node_variables(self, app_id: str, node_id: str): + return self._delete_node_variables(app_id, node_id) + + def _delete_node_variables(self, app_id: str, node_id: str): + self._session.query(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + ).delete() + + def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None: + draft_var = self._get_variable( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=str(SystemVariableKey.CONVERSATION_ID), + ) + if draft_var is None: + return None + segment = draft_var.get_value() + if not isinstance(segment, StringSegment): + _logger.warning( + "sys.conversation_id variable is not a string: app_id=%s, id=%s", + app_id, + draft_var.id, + ) + return None + return segment.value + + def get_or_create_conversation( + self, + account_id: str, + app: App, + workflow: Workflow, + ) -> str: + """ + get_or_create_conversation creates and returns the ID of a conversation for debugging. + + If a conversation already exists, as determined by the following criteria, its ID is returned: + - The system variable `sys.conversation_id` exists in the draft variable table, and + - A corresponding conversation record is found in the database. + + If no such conversation exists, a new conversation is created and its ID is returned. + """ + conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) + + if conv_id is not None: + conversation = ( + self._session.query(Conversation) + .filter( + Conversation.id == conv_id, + Conversation.app_id == workflow.app_id, + ) + .first() + ) + # Only return the conversation ID if it exists and is valid (has a correspond conversation record in DB). 
+ if conversation is not None: + return conv_id + conversation = Conversation( + app_id=workflow.app_id, + app_model_config_id=app.app_model_config_id, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name="Draft Debugging Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.DEBUGGER.value, + from_source="console", + from_end_user_id=None, + from_account_id=account_id, + ) + + self._session.add(conversation) + self._session.flush() + return conversation.id + + def prefill_conversation_variable_default_values(self, workflow: Workflow): + """""" + draft_conv_vars: list[WorkflowDraftVariable] = [] + for conv_var in workflow.conversation_variables: + draft_var = WorkflowDraftVariable.new_conversation_variable( + app_id=workflow.app_id, + name=conv_var.name, + value=conv_var, + description=conv_var.description, + ) + draft_conv_vars.append(draft_var) + _batch_upsert_draft_varaible( + self._session, + draft_conv_vars, + policy=_UpsertPolicy.IGNORE, + ) + + +class _UpsertPolicy(StrEnum): + IGNORE = "ignore" + OVERWRITE = "overwrite" + + +def _batch_upsert_draft_varaible( + session: Session, + draft_vars: Sequence[WorkflowDraftVariable], + policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE, +) -> None: + if not draft_vars: + return None + # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: + # + # 1. The variable saving process involves writing multiple rows to the + # `workflow_draft_variables` table. Batch insertion significantly improves performance. + # 2. Using the ORM would require either: + # + # a. Checking for the existence of each variable before insertion, + # resulting in 2n SQL statements for n variables and potential concurrency issues. + # b. Attempting insertion first, then updating if a unique index violation occurs, + # which still results in n to 2n SQL statements. + # + # Both approaches are inefficient and suboptimal. + # 3. We do not need to retrieve the results of the SQL execution or populate ORM + # model instances with the returned values. + # 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all + # variables in a single SQL statement, avoiding the issues above. + # + # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific + # insert operations instead of the ORM layer. 
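+ #
+ # Roughly, the statement built below is (PostgreSQL sketch, assuming the unique
+ # index covers (app_id, node_id, name)):
+ #
+ #     INSERT INTO workflow_draft_variables (...) VALUES (...), (...)
+ #     ON CONFLICT (app_id, node_id, name) DO UPDATE SET value = EXCLUDED.value, ...
+ #     -- or ON CONFLICT ... DO NOTHING when the policy is IGNORE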
+ stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
+ if policy == _UpsertPolicy.OVERWRITE:
+ stmt = stmt.on_conflict_do_update(
+ index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
+ set_={
+ "updated_at": stmt.excluded.updated_at,
+ "last_edited_at": stmt.excluded.last_edited_at,
+ "description": stmt.excluded.description,
+ "value_type": stmt.excluded.value_type,
+ "value": stmt.excluded.value,
+ "visible": stmt.excluded.visible,
+ "editable": stmt.excluded.editable,
+ "node_execution_id": stmt.excluded.node_execution_id,
+ },
+ )
+ elif policy == _UpsertPolicy.IGNORE:
+ stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
+ else:
+ raise Exception("Invalid value for update policy.")
+ session.execute(stmt)
+
+
+def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
+ d: dict[str, Any] = {
+ "app_id": model.app_id,
+ "last_edited_at": None,
+ "node_id": model.node_id,
+ "name": model.name,
+ "selector": model.selector,
+ "value_type": model.value_type,
+ "value": model.value,
+ "node_execution_id": model.node_execution_id,
+ }
+ if model.visible is not None:
+ d["visible"] = model.visible
+ if model.editable is not None:
+ d["editable"] = model.editable
+ if model.created_at is not None:
+ d["created_at"] = model.created_at
+ if model.updated_at is not None:
+ d["updated_at"] = model.updated_at
+ if model.description is not None:
+ d["description"] = model.description
+ return d
+
+
+def _build_segment_for_serialized_values(v: Any) -> Segment:
+ """
+ Reconstructs Segment objects from serialized values, with special handling
+ for FileSegment and ArrayFileSegment types.
+
+ This function should only be used when:
+ 1. No explicit type information is available
+ 2. The input value is in serialized form (dict or list)
+
+ It detects potential file objects in the serialized data and properly rebuilds the
+ appropriate segment type.
+ """
+ return build_segment(WorkflowDraftVariable.rebuild_file_types(v))
+
+
+class DraftVariableSaver:
+ # _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
+ # Its sole possible value is `None`.
+ #
+ # This is used to signal the execution of a workflow node when it has no other outputs.
+ _DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__"
+ _DUMMY_OUTPUT_VALUE: ClassVar[None] = None
+
+ # _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that
+ # should be excluded when saving draft variables. This prevents certain internal or
+ # technical variables from being exposed in the draft environment, particularly those
+ # that aren't meant to be directly edited or viewed by users.
+ _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
+ NodeType.LLM: frozenset(["finish_reason"]),
+ NodeType.LOOP: frozenset(["loop_round"]),
+ }
+
+ # Database session used for persisting draft variables.
+ _session: Session
+
+ # The application ID associated with the draft variables.
+ # This should match the `Workflow.app_id` of the workflow to which the current node belongs.
+ _app_id: str
+
+ # The ID of the node for which DraftVariableSaver is saving output variables.
+ _node_id: str
+
+ # The type of the current node (see NodeType).
+ _node_type: NodeType
+
+ # Indicates how the workflow execution was triggered (see InvokeFrom).
+ _invoke_from: InvokeFrom
+
+ # The ID of the `WorkflowNodeExecutionModel` record produced by the current node run.
+ _node_execution_id: str
+
+ # _enclosing_node_id identifies the container node that the current node belongs to.
+ # For example, if the current node is an LLM node inside an Iteration node + # or Loop node, then `_enclosing_node_id` refers to the ID of + # the containing Iteration or Loop node. + # + # If the current node is not nested within another node, `_enclosing_node_id` is + # `None`. + _enclosing_node_id: str | None + + def __init__( + self, + session: Session, + app_id: str, + node_id: str, + node_type: NodeType, + invoke_from: InvokeFrom, + node_execution_id: str, + enclosing_node_id: str | None = None, + ): + self._session = session + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._invoke_from = invoke_from + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def _create_dummy_output_variable(self): + return WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=self._DUMMY_OUTPUT_IDENTITY, + node_execution_id=self._node_execution_id, + value=build_segment(self._DUMMY_OUTPUT_VALUE), + visible=False, + editable=False, + ) + + def _should_save_output_variables_for_draft(self) -> bool: + # Only save output variables for debugging execution of workflow. + if self._invoke_from != InvokeFrom.DEBUGGER: + return False + if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER: + # Currently we do not save output variables for nodes inside loop or iteration. + return False + return True + + def _build_from_variable_assigner_mapping(self, process_data: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars: list[WorkflowDraftVariable] = [] + updated_variables = get_updated_variables(process_data) or [] + + for item in updated_variables: + selector = item.selector + if len(selector) < MIN_SELECTORS_LENGTH: + raise Exception("selector too short") + # NOTE(QuantumGhost): only the following two kinds of variable could be updated by + # VariableAssigner: ConversationVariable and iteration variable. + # We only save conversation variable here. + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + segment = WorkflowDraftVariable.build_segment_with_type(segment_type=item.value_type, value=item.new_value) + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + name=item.name, + value=segment, + ) + ) + # Add a dummy output variable to indicate that this node is executed. + draft_vars.append(self._create_dummy_output_variable()) + return draft_vars + + def _build_variables_from_start_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + has_non_sys_variables = False + for name, value in output.items(): + value_seg = _build_segment_for_serialized_values(value) + node_id, name = self._normalize_variable_for_start_node(name) + # If node_id is not `sys`, it means that the variable is a user-defined input field + # in `Start` node. + if node_id != SYSTEM_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=True, + editable=True, + ) + ) + has_non_sys_variables = True + else: + if name == SystemVariableKey.FILES: + # Here we know the type of variable must be `array[file]`, we + # just build files from the value. 
+ files = [File.model_validate(v) for v in value] + if files: + value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) + else: + value_seg = ArrayFileSegment(value=[]) + + draft_vars.append( + WorkflowDraftVariable.new_sys_variable( + app_id=self._app_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + editable=self._should_variable_be_editable(node_id, name), + ) + ) + if not has_non_sys_variables: + draft_vars.append(self._create_dummy_output_variable()) + return draft_vars + + def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: + if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): + return self._node_id, name + _, name_ = name.split(".", maxsplit=1) + return SYSTEM_VARIABLE_NODE_ID, name_ + + def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + for name, value in output.items(): + if not self._should_variable_be_saved(name): + _logger.debug( + "Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s", + name, + self._node_type, + ) + continue + if isinstance(value, Segment): + value_seg = value + else: + value_seg = _build_segment_for_serialized_values(value) + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(self._node_id, self._node_type, name), + ) + ) + return draft_vars + + def save( + self, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + ): + draft_vars: list[WorkflowDraftVariable] = [] + if outputs is None: + outputs = {} + if process_data is None: + process_data = {} + if not self._should_save_output_variables_for_draft(): + return + if self._node_type == NodeType.VARIABLE_ASSIGNER: + draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data) + elif self._node_type == NodeType.START: + draft_vars = self._build_variables_from_start_mapping(outputs) + else: + draft_vars = self._build_variables_from_mapping(outputs) + _batch_upsert_draft_varaible(self._session, draft_vars) + + @staticmethod + def _should_variable_be_editable(node_id: str, name: str) -> bool: + if node_id in (CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID): + return False + if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): + return False + return True + + @staticmethod + def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool: + if node_type in NodeType.IF_ELSE: + return False + if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): + return False + return True + + def _should_variable_be_saved(self, name: str) -> bool: + exclude_var_names = self._EXCLUDE_VARIABLE_NAMES_MAPPING.get(self._node_type) + if exclude_var_names is None: + return True + return name not in exclude_var_names diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index bc213ccce6..0fd94ac86e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,6 +1,7 @@ import json import time -from collections.abc import Callable, Generator, Sequence +import uuid +from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional from uuid import uuid4 @@ -8,12 +9,17 @@ from uuid import uuid4 from 
sqlalchemy import select from sqlalchemy.orm import Session +from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -22,9 +28,11 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.nodes.start.entities import StartNodeData from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from factories.file_factory import build_from_mapping, build_from_mappings from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -34,10 +42,15 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError +from .workflow_draft_variable_service import ( + DraftVariableSaver, + DraftVarLoader, + WorkflowDraftVariableService, +) class WorkflowService: @@ -45,6 +58,33 @@ class WorkflowService: Workflow Service """ + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: + # TODO(QuantumGhost): This query is not fully covered by index. 
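+        # Editorial sketch: a composite index on
+        # (tenant_id, app_id, workflow_id, node_id, created_at) would cover this
+        # lookup; adding it is a possible follow-up and is not part of this change.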
+ criteria = ( + WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, + WorkflowNodeExecutionModel.app_id == app_model.id, + WorkflowNodeExecutionModel.workflow_id == workflow.id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + node_exec = ( + db.session.query(WorkflowNodeExecutionModel) + .filter(*criteria) + .order_by(WorkflowNodeExecutionModel.created_at.desc()) + .first() + ) + return node_exec + + def is_workflow_exist(self, app_model: App) -> bool: + return ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) + .count() + ) > 0 + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ Get draft workflow @@ -61,6 +101,23 @@ class WorkflowService: # return draft workflow return workflow + def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + # fetch published workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id, + ) + .first() + ) + if not workflow: + return None + if workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}") + return workflow + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ Get published workflow @@ -199,7 +256,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=str(datetime.now(UTC).replace(tzinfo=None)), + version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, @@ -253,26 +310,85 @@ class WorkflowService: return default_config def run_draft_workflow_node( - self, app_model: App, node_id: str, user_inputs: dict, account: Account + self, + app_model: App, + draft_workflow: Workflow, + node_id: str, + user_inputs: Mapping[str, Any], + account: Account, + query: str = "", + files: Sequence[File] | None = None, ) -> WorkflowNodeExecutionModel: """ Run draft workflow node """ - # fetch draft workflow by app_model - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") + files = files or [] + + with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = Workflow.get_node_type_from_node_config(node_config) + node_data = node_config.get("data", {}) + if node_type == NodeType.START: + with Session(bind=db.engine) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + conversation_id = draft_var_srv.get_or_create_conversation( + account_id=account.id, + app=app_model, + workflow=draft_workflow, + ) + start_data = StartNodeData.model_validate(node_data) + user_inputs = _rebuild_file_for_user_inputs_in_start_node( + tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs + ) + # init variable pool + variable_pool = _setup_variable_pool( + query=query, + files=files or [], + user_id=account.id, + user_inputs=user_inputs, + workflow=draft_workflow, + # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation 
variables. + conversation_variables=[], + node_type=node_type, + conversation_id=conversation_id, + ) + + else: + variable_pool = VariablePool( + system_variables={}, + user_inputs=user_inputs, + environment_variables=draft_workflow.environment_variables, + conversation_variables=[], + ) + + variable_loader = DraftVarLoader( + engine=db.engine, + app_id=app_model.id, + tenant_id=app_model.tenant_id, + ) + + eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + if eclosing_node_type_and_id: + _, enclosing_node_id = eclosing_node_type_and_id + else: + enclosing_node_id = None + + run = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + variable_pool=variable_pool, + variable_loader=variable_loader, + ) # run draft workflow node start_at = time.perf_counter() - node_execution = self._handle_node_run_result( - invoke_node_fn=lambda: WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - ), + invoke_node_fn=lambda: run, start_at=start_at, node_id=node_id, ) @@ -292,6 +408,18 @@ class WorkflowService: # Convert node_execution to WorkflowNodeExecution after save workflow_node_execution = repository.to_db_model(node_execution) + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=app_model.id, + node_id=workflow_node_execution.node_id, + node_type=NodeType(workflow_node_execution.node_type), + invoke_from=InvokeFrom.DEBUGGER, + enclosing_node_id=enclosing_node_id, + node_execution_id=node_execution.id, + ) + draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) + session.commit() return workflow_node_execution def run_free_workflow_node( @@ -332,7 +460,7 @@ class WorkflowService: node_run_result = event.run_result # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) break if not node_run_result: @@ -394,7 +522,7 @@ class WorkflowService: if node_run_result.process_data else None ) - outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + outputs = node_run_result.outputs node_execution.inputs = inputs node_execution.process_data = process_data @@ -531,3 +659,83 @@ class WorkflowService: session.delete(workflow) return True + + +def _setup_variable_pool( + query: str, + files: Sequence[File], + user_id: str, + user_inputs: Mapping[str, Any], + workflow: Workflow, + node_type: NodeType, + conversation_id: str, + conversation_variables: list[Variable], +): + # Only inject system variables for START node type. + if node_type == NodeType.START: + # Create a variable pool. + system_inputs: dict[SystemVariableKey, Any] = { + # From inputs: + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + # From workflow model + SystemVariableKey.APP_ID: workflow.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + # Randomly generated. 
+ SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), + } + + # Only add chatflow-specific variables for non-workflow types + if workflow.type != WorkflowType.WORKFLOW.value: + system_inputs.update( + { + SystemVariableKey.QUERY: query, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.DIALOGUE_COUNT: 0, + } + ) + else: + system_inputs = {} + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + return variable_pool + + +def _rebuild_file_for_user_inputs_in_start_node( + tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any] +) -> Mapping[str, Any]: + inputs_copy = dict(user_inputs) + + for variable in start_node_data.variables: + if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST): + continue + if variable.variable not in user_inputs: + continue + value = user_inputs[variable.variable] + file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type) + inputs_copy[variable.variable] = file + return inputs_copy + + +def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]: + if variable_entity_type == VariableEntityType.FILE: + if not isinstance(value, dict): + raise ValueError(f"expected dict for file object, got {type(value)}") + return build_from_mapping(mapping=value, tenant_id=tenant_id) + elif variable_entity_type == VariableEntityType.FILE_LIST: + if not isinstance(value, list): + raise ValueError(f"expected list for file list object, got {type(value)}") + if len(value) == 0: + return [] + if not isinstance(value[0], dict): + raise ValueError(f"expected dict for first element in the file list, got {type(value)}") + return build_from_mappings(mappings=value, tenant_id=tenant_id) + else: + raise Exception("unreachable") diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 9e40a8494d..4046096c27 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -1,107 +1,217 @@ -# OpenAI API Key -OPENAI_API_KEY= +FLASK_APP=app.py +FLASK_DEBUG=0 +SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI' -# Azure OpenAI API Base Endpoint & API Key -AZURE_OPENAI_API_BASE= -AZURE_OPENAI_API_KEY= +CONSOLE_API_URL=http://127.0.0.1:5001 +CONSOLE_WEB_URL=http://127.0.0.1:3000 -# Anthropic API Key -ANTHROPIC_API_KEY= +# Service API base URL +SERVICE_API_URL=http://127.0.0.1:5001 -# Replicate API Key -REPLICATE_API_KEY= +# Web APP base URL +APP_WEB_URL=http://127.0.0.1:3000 -# Hugging Face API Key -HUGGINGFACE_API_KEY= -HUGGINGFACE_TEXT_GEN_ENDPOINT_URL= -HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL= -HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL= +# Files URL +FILES_URL=http://127.0.0.1:5001 -# Minimax Credentials -MINIMAX_API_KEY= -MINIMAX_GROUP_ID= +# The time in seconds after the signature is rejected +FILES_ACCESS_TIMEOUT=300 -# Spark Credentials -SPARK_APP_ID= -SPARK_API_KEY= -SPARK_API_SECRET= +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 -# Tongyi Credentials -TONGYI_DASHSCOPE_API_KEY= +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 -# Wenxin Credentials -WENXIN_API_KEY= -WENXIN_SECRET_KEY= +# celery configuration +CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 -# ZhipuAI Credentials 
-ZHIPUAI_API_KEY= +# redis configuration +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_USERNAME= +REDIS_PASSWORD=difyai123456 +REDIS_USE_SSL=false +REDIS_DB=0 -# Baichuan Credentials -BAICHUAN_API_KEY= -BAICHUAN_SECRET_KEY= +# PostgreSQL database configuration +DB_USERNAME=postgres +DB_PASSWORD=difyai123456 +DB_HOST=localhost +DB_PORT=5432 +DB_DATABASE=dify -# ChatGLM Credentials -CHATGLM_API_BASE= +# Storage configuration +# use for store upload files, private keys... +# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase +STORAGE_TYPE=opendal -# Xinference Credentials -XINFERENCE_SERVER_URL= -XINFERENCE_GENERATION_MODEL_UID= -XINFERENCE_CHAT_MODEL_UID= -XINFERENCE_EMBEDDINGS_MODEL_UID= -XINFERENCE_RERANK_MODEL_UID= +# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal +OPENDAL_SCHEME=fs +OPENDAL_FS_ROOT=storage -# OpenLLM Credentials -OPENLLM_SERVER_URL= +# CORS configuration +WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* +CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# LocalAI Credentials -LOCALAI_SERVER_URL= +# Vector database configuration +# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase +VECTOR_STORE=weaviate +# Weaviate configuration +WEAVIATE_ENDPOINT=http://localhost:8080 +WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih +WEAVIATE_GRPC_ENABLED=false +WEAVIATE_BATCH_SIZE=100 -# Cohere Credentials -COHERE_API_KEY= -# Jina Credentials -JINA_API_KEY= +# Upload configuration +UPLOAD_FILE_SIZE_LIMIT=15 +UPLOAD_FILE_BATCH_LIMIT=5 +UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 -# Ollama Credentials -OLLAMA_BASE_URL= +# Model configuration +MULTIMODAL_SEND_FORMAT=base64 +PROMPT_GENERATION_MAX_TOKENS=4096 +CODE_GENERATION_MAX_TOKENS=1024 -# Together API Key -TOGETHER_API_KEY= +# Mail configuration, support: resend, smtp +MAIL_TYPE= +MAIL_DEFAULT_SEND_FROM=no-reply +RESEND_API_KEY= +RESEND_API_URL=https://api.resend.com +# smtp configuration +SMTP_SERVER=smtp.example.com +SMTP_PORT=465 +SMTP_USERNAME=123 +SMTP_PASSWORD=abc +SMTP_USE_TLS=true +SMTP_OPPORTUNISTIC_TLS=false -# Mock Switch -MOCK_SWITCH=false +# Sentry configuration +SENTRY_DSN= + +# DEBUG +DEBUG=false +SQLALCHEMY_ECHO=false + +# Notion import configuration, support public and internal +NOTION_INTEGRATION_TYPE=public +NOTION_CLIENT_SECRET=you-client-secret +NOTION_CLIENT_ID=you-client-id +NOTION_INTERNAL_SECRET=you-internal-secret + +ETL_TYPE=dify +UNSTRUCTURED_API_URL= +UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=false + +#ssrf +SSRF_PROXY_HTTP_URL= +SSRF_PROXY_HTTPS_URL= +SSRF_DEFAULT_MAX_RETRIES=3 +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 + +BATCH_UPLOAD_LIMIT=10 +KEYWORD_DATA_SOURCE_TYPE=database + +# Workflow file upload limit +WORKFLOW_FILE_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION -CODE_EXECUTION_ENDPOINT= -CODE_EXECUTION_API_KEY= +CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 +CODE_EXECUTION_API_KEY=dify-sandbox +CODE_MAX_NUMBER=9223372036854775807 +CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_STRING_LENGTH=80000 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 +CODE_MAX_STRING_ARRAY_LENGTH=30 +CODE_MAX_OBJECT_ARRAY_LENGTH=30 +CODE_MAX_NUMBER_ARRAY_LENGTH=1000 -# Volcengine MaaS Credentials -VOLC_API_KEY= -VOLC_SECRET_KEY= -VOLC_MODEL_ENDPOINT_ID= 
-VOLC_EMBEDDING_ENDPOINT_ID= +# API Tool configuration +API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 +API_TOOL_DEFAULT_READ_TIMEOUT=60 -# 360 AI Credentials -ZHINAO_API_KEY= +# HTTP Node configuration +HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 +HTTP_REQUEST_MAX_READ_TIMEOUT=600 +HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 + +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + +# Log file path +LOG_FILE= +# Log file max size, the unit is MB +LOG_FILE_MAX_SIZE=20 +# Log file max backup count +LOG_FILE_BACKUP_COUNT=5 +# Log dateformat +LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S +# Log Timezone +LOG_TZ=UTC +# Log format +LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s + +# Indexing configuration +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 + +# Workflow runtime configuration +WORKFLOW_MAX_EXECUTION_STEPS=500 +WORKFLOW_MAX_EXECUTION_TIME=1200 +WORKFLOW_CALL_MAX_DEPTH=5 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 +MAX_VARIABLE_SIZE=204800 + +# App configuration +APP_MAX_EXECUTION_TIME=1200 +APP_MAX_ACTIVE_REQUESTS=0 + +# Celery beat configuration +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= # Plugin configuration -PLUGIN_DAEMON_KEY= -PLUGIN_DAEMON_URL= +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DAEMON_URL=http://127.0.0.1:5002 +PLUGIN_REMOTE_INSTALL_PORT=5003 +PLUGIN_REMOTE_INSTALL_HOST=localhost +PLUGIN_MAX_PACKAGE_SIZE=15728640 +INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration -MARKETPLACE_API_URL= -# VESSL AI Credentials -VESSL_AI_MODEL_NAME= -VESSL_AI_API_KEY= -VESSL_AI_ENDPOINT_URL= +MARKETPLACE_ENABLED=true +MARKETPLACE_API_URL=https://marketplace.dify.ai -# GPUStack Credentials -GPUSTACK_SERVER_URL= -GPUSTACK_API_KEY= +# Endpoint configuration +ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} -# Gitee AI Credentials -GITEE_AI_API_KEY= +# Reset password token expiry minutes +RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 -# xAI Credentials -XAI_API_KEY= -XAI_API_BASE= +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 +# Lockout duration in seconds +LOGIN_LOCKOUT_DURATION=86400 + +HTTP_PROXY='http://127.0.0.1:1092' +HTTPS_PROXY='http://127.0.0.1:1092' +NO_PROXY='localhost,127.0.0.1' +LOG_LEVEL=INFO diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 6e3ab4b74b..d9f90f992e 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,19 +1,91 @@ -import os +import pathlib +import random +import secrets +from collections.abc import Generator -# Getting the absolute path of the current file's directory -ABS_PATH = os.path.dirname(os.path.abspath(__file__)) +import pytest +from flask import Flask +from flask.testing import FlaskClient +from sqlalchemy.orm import Session -# Getting the absolute path of the project's root directory -PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) +from app_factory import create_app +from models import Account, DifySetup, Tenant, TenantAccountJoin, db +from services.account_service import AccountService, RegisterService # Loading the .env file if it exists def 
_load_env() -> None: - dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env") - if os.path.exists(dotenv_path): + current_file_path = pathlib.Path(__file__).absolute() + # Items later in the list have higher precedence. + files_to_load = [".env", "vdb.env"] + + env_file_paths = [current_file_path.parent / i for i in files_to_load] + for path in env_file_paths: + if not path.exists(): + continue + from dotenv import load_dotenv - load_dotenv(dotenv_path) + # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env` + load_dotenv(str(path), override=True) _load_env() + +_CACHED_APP = create_app() + + +@pytest.fixture +def flask_app() -> Flask: + return _CACHED_APP + + +@pytest.fixture(scope="session") +def setup_account(request) -> Generator[Account, None, None]: + """`dify_setup` completes the setup process for the Dify application. + + It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database. + + Most tests in the `controllers` package may require dify has been successfully setup. + """ + with _CACHED_APP.test_request_context(): + rand_suffix = random.randint(int(1e6), int(1e7)) # noqa + name = f"test-user-{rand_suffix}" + email = f"{name}@example.com" + RegisterService.setup( + email=email, + name=name, + password=secrets.token_hex(16), + ip_address="localhost", + ) + + with _CACHED_APP.test_request_context(): + with Session(bind=db.engine, expire_on_commit=False) as session: + account = session.query(Account).filter_by(email=email).one() + + yield account + + with _CACHED_APP.test_request_context(): + db.session.query(DifySetup).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Account).delete() + db.session.query(Tenant).delete() + db.session.commit() + + +@pytest.fixture +def flask_req_ctx(): + with _CACHED_APP.test_request_context(): + yield + + +@pytest.fixture +def auth_header(setup_account) -> dict[str, str]: + token = AccountService.get_account_jwt_token(setup_account) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture +def test_client() -> Generator[FlaskClient, None, None]: + with _CACHED_APP.test_client() as client: + yield client diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py deleted file mode 100644 index 32e8c11d19..0000000000 --- a/api/tests/integration_tests/controllers/app_fixture.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest - -from app_factory import create_app -from configs import dify_config - -mock_user = type( - "MockUser", - (object,), - { - "is_authenticated": True, - "id": "123", - "is_editor": True, - "is_dataset_editor": True, - "status": "active", - "get_id": "123", - "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", - }, -) - - -@pytest.fixture -def app(): - app = create_app() - dify_config.LOGIN_DISABLED = True - return app diff --git a/api/tests/integration_tests/controllers/console/__init__.py b/api/tests/integration_tests/controllers/console/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/controllers/console/app/__init__.py b/api/tests/integration_tests/controllers/console/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..038f37af5f --- /dev/null +++ 
b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,47 @@ +import uuid +from unittest import mock + +from controllers.console.app import workflow_draft_variable as draft_variable_api +from controllers.console.app import wraps +from factories.variable_factory import build_segment +from models import App, AppMode +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService + + +def _get_mock_srv_class() -> type[WorkflowDraftVariableService]: + return mock.create_autospec(WorkflowDraftVariableService) + + +class TestWorkflowDraftNodeVariableListApi: + def test_get(self, test_client, auth_header, monkeypatch): + srv_class = _get_mock_srv_class() + mock_app_model: App = App() + mock_app_model.id = str(uuid.uuid4()) + test_node_id = "test_node_id" + mock_app_model.mode = AppMode.ADVANCED_CHAT + mock_load_app_model = mock.Mock(return_value=mock_app_model) + + monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + var1 = WorkflowDraftVariable.new_node_variable( + app_id="test_app_1", + node_id="test_node_1", + name="str_var", + value=build_segment("str_value"), + node_execution_id=str(uuid.uuid4()), + ) + srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True) + srv_class.return_value = srv_instance + srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1]) + + response = test_client.get( + f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables", + headers=auth_header, + ) + assert response.status_code == 200 + response_dict = response.json + assert isinstance(response_dict, dict) + assert "items" in response_dict + assert len(response_dict["items"]) == 1 diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py deleted file mode 100644 index 276ad3a7ed..0000000000 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ /dev/null @@ -1,9 +0,0 @@ -from unittest.mock import patch - -from app_fixture import mock_user # type: ignore - - -def test_post_requires_login(app): - with app.test_client() as client, patch("flask_login.utils._get_user", mock_user): - response = client.get("/console/api/data-source/integrates") - assert response.status_code == 200 diff --git a/api/tests/integration_tests/factories/__init__.py b/api/tests/integration_tests/factories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py new file mode 100644 index 0000000000..fecb3f6d95 --- /dev/null +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -0,0 +1,371 @@ +import unittest +from datetime import UTC, datetime +from typing import Optional +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.file import File, FileTransferMethod, FileType +from extensions.ext_database import db +from factories.file_factory import StorageKeyLoader +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestStorageKeyLoader(unittest.TestCase): + """ + Integration tests for StorageKeyLoader class. 
+ + Tests the batched loading of storage keys from the database for files + with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. + """ + + def setUp(self): + """Set up test data before each test method.""" + self.session = db.session() + self.tenant_id = str(uuid4()) + self.user_id = str(uuid4()) + self.conversation_id = str(uuid4()) + + # Create test data that will be cleaned up after each test + self.test_upload_files = [] + self.test_tool_files = [] + + # Create StorageKeyLoader instance + self.loader = StorageKeyLoader(self.session, self.tenant_id) + + def tearDown(self): + """Clean up test data after each test method.""" + self.session.rollback() + + def _create_upload_file( + self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> UploadFile: + """Helper method to create an UploadFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if storage_key is None: + storage_key = f"test_storage_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=storage_key, + name="test_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file.id = file_id + + self.session.add(upload_file) + self.session.flush() + self.test_upload_files.append(upload_file) + + return upload_file + + def _create_tool_file( + self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> ToolFile: + """Helper method to create a ToolFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if file_key is None: + file_key = f"test_file_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + tool_file = ToolFile() + tool_file.id = file_id + tool_file.user_id = self.user_id + tool_file.tenant_id = tenant_id + tool_file.conversation_id = self.conversation_id + tool_file.file_key = file_key + tool_file.mimetype = "text/plain" + tool_file.original_url = "http://example.com/file.txt" + tool_file.name = "test_tool_file.txt" + tool_file.size = 2048 + + self.session.add(tool_file) + self.session.flush() + self.test_tool_files.append(tool_file) + + return tool_file + + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: + """Helper method to create a File object for testing.""" + if tenant_id is None: + tenant_id = self.tenant_id + + # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods + file_related_id = None + remote_url = None + + if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): + file_related_id = related_id + elif transfer_method == FileTransferMethod.REMOTE_URL: + remote_url = "https://example.com/test_file.txt" + file_related_id = related_id + + return File( + id=str(uuid4()), # Generate new UUID for File.id + tenant_id=tenant_id, + type=FileType.DOCUMENT, + transfer_method=transfer_method, + related_id=file_related_id, + remote_url=remote_url, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="initial_key", + ) + + def test_load_storage_keys_local_file(self): + """Test loading storage keys for LOCAL_FILE transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, 
transfer_method=FileTransferMethod.LOCAL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_remote_url(self): + """Test loading storage keys for REMOTE_URL transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_tool_file(self): + """Test loading storage keys for TOOL_FILE transfer method.""" + # Create test data + tool_file = self._create_tool_file() + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == tool_file.file_key + + def test_load_storage_keys_mixed_methods(self): + """Test batch loading with mixed transfer methods.""" + # Create test data for different transfer methods + upload_file1 = self._create_upload_file() + upload_file2 = self._create_upload_file() + tool_file = self._create_tool_file() + + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + files = [file1, file2, file3] + + # Load storage keys + self.loader.load_storage_keys(files) + + # Verify all storage keys were loaded correctly + assert file1._storage_key == upload_file1.key + assert file2._storage_key == upload_file2.key + assert file3._storage_key == tool_file.file_key + + def test_load_storage_keys_empty_list(self): + """Test with empty file list.""" + # Should not raise any exceptions + self.loader.load_storage_keys([]) + + def test_load_storage_keys_tenant_mismatch(self): + """Test tenant_id validation.""" + # Create file with different tenant_id + upload_file = self._create_upload_file() + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) + + # Should raise ValueError for tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_missing_file_id(self): + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None + + # Should raise ValueError for None file related_id + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert str(context.value) == "file id should not be None." 
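+    # Illustrative call pattern exercised throughout this class (assumes the
+    # fixtures above):
+    #     loader = StorageKeyLoader(session, tenant_id)
+    #     loader.load_storage_keys(files)  # populates each file._storage_key in place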
+ + def test_load_storage_keys_nonexistent_upload_file_records(self): + """Test with missing UploadFile database records.""" + # Create file with non-existent upload file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_nonexistent_tool_file_records(self): + """Test with missing ToolFile database records.""" + # Create file with non-existent tool file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_invalid_uuid(self): + """Test with invalid UUID format.""" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" + + # Should raise ValueError for invalid UUID + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_batch_efficiency(self): + """Test batched operations use efficient queries.""" + # Create multiple files of different types + upload_files = [self._create_upload_file() for _ in range(3)] + tool_files = [self._create_tool_file() for _ in range(2)] + + files = [] + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) + + # Mock the session to count queries + with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: + self.loader.load_storage_keys(files) + + # Should make exactly 2 queries (one for upload_files, one for tool_files) + assert mock_scalars.call_count == 2 + + # Verify all storage keys were loaded correctly + for i, file in enumerate(files[:3]): + assert file._storage_key == upload_files[i].key + for i, file in enumerate(files[3:]): + assert file._storage_key == tool_files[i].file_key + + def test_load_storage_keys_tenant_isolation(self): + """Test that tenant isolation works correctly.""" + # Create files for different tenants + other_tenant_id = str(uuid4()) + + # Create upload file for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create upload file for other tenant (but don't add to cleanup list) + upload_file_other = UploadFile( + tenant_id=other_tenant_id, + storage_type="local", + key="other_tenant_key", + name="other_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file_other.id = str(uuid4()) + self.session.add(upload_file_other) + self.session.flush() + + # Create file for other tenant but try to load with current tenant's loader + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError due to tenant mismatch + with 
pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + # Current tenant's file should still work + self.loader.load_storage_keys([file_current]) + assert file_current._storage_key == upload_file_current.key + + def test_load_storage_keys_mixed_tenant_batch(self): + """Test batch with mixed tenant files (should fail on first mismatch).""" + # Create files for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create file for different tenant + other_tenant_id = str(uuid4()) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError on tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_current, file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_duplicate_file_ids(self): + """Test handling of duplicate file IDs in the batch.""" + # Create upload file + upload_file = self._create_upload_file() + + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should handle duplicates gracefully + self.loader.load_storage_keys([file1, file2]) + + # Both files should have the same storage key + assert file1._storage_key == upload_file.key + assert file2._storage_key == upload_file.key + + def test_load_storage_keys_session_isolation(self): + """Test that the loader uses the provided session correctly.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Create loader with different session (same underlying connection) + + with Session(bind=db.engine) as other_session: + other_loader = StorageKeyLoader(other_session, self.tenant_id) + with pytest.raises(ValueError): + other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/services/__init__.py b/api/tests/integration_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..30cd2e60cb --- /dev/null +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -0,0 +1,501 @@ +import json +import unittest +import uuid + +import pytest +from sqlalchemy.orm import Session + +from core.variables.variables import StringVariable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.nodes import NodeType +from factories.variable_factory import build_segment +from libs import datetime_utils +from models import db +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel +from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestWorkflowDraftVariableService(unittest.TestCase): + _test_app_id: str + _session: Session + 
_node1_id = "test_node_1" + _node2_id = "test_node_2" + _node_exec_id = str(uuid.uuid4()) + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._session: Session = db.session() + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node2_vars = [ + WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node2_id, + name="int_var", + value=build_segment(1), + visible=False, + node_execution_id=self._node_exec_id, + ), + WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node2_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ), + ] + node1_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ) + _variables = list(node2_vars) + _variables.extend( + [ + node1_var, + sys_var, + conv_var, + ] + ) + + db.session.add_all(_variables) + db.session.flush() + self._variable_ids = [v.id for v in _variables] + self._node1_str_var_id = node1_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + self._node2_var_ids = [v.id for v in node2_vars] + + def _get_test_srv(self) -> WorkflowDraftVariableService: + return WorkflowDraftVariableService(session=self._session) + + def tearDown(self): + self._session.rollback() + + def test_list_variables(self): + srv = self._get_test_srv() + var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2) + assert var_list.total == 5 + assert len(var_list.variables) == 2 + page1_var_ids = {v.id for v in var_list.variables} + assert page1_var_ids.issubset(self._variable_ids) + + var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2) + assert var_list_2.total is None + assert len(var_list_2.variables) == 2 + page2_var_ids = {v.id for v in var_list_2.variables} + assert page2_var_ids.isdisjoint(page1_var_ids) + assert page2_var_ids.issubset(self._variable_ids) + + def test_get_node_variable(self): + srv = self._get_test_srv() + node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var") + assert node_var is not None + assert node_var.id == self._node1_str_var_id + assert node_var.name == "str_var" + assert node_var.get_value() == build_segment("str_value") + + def test_get_system_variable(self): + srv = self._get_test_srv() + sys_var = srv.get_system_variable(self._test_app_id, "sys_var") + assert sys_var is not None + assert sys_var.id == self._sys_var_id + assert sys_var.name == "sys_var" + assert sys_var.get_value() == build_segment("sys_value") + + def test_get_conversation_variable(self): + srv = self._get_test_srv() + conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var") + assert conv_var is not None + assert conv_var.id == self._conv_var_id + assert conv_var.name == "conv_var" + assert conv_var.get_value() == build_segment("conv_value") + + def test_delete_node_variables(self): + srv = self._get_test_srv() + srv.delete_node_variables(self._test_app_id, self._node2_id) + node2_var_count = ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == self._test_app_id, + 
WorkflowDraftVariable.node_id == self._node2_id, + ) + .count() + ) + assert node2_var_count == 0 + + def test_delete_variable(self): + srv = self._get_test_srv() + node_1_var = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() + ) + srv.delete_variable(node_1_var) + exists = bool( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + ) + assert exists is False + + def test__list_node_variables(self): + srv = self._get_test_srv() + node_vars = srv._list_node_variables(self._test_app_id, self._node2_id) + assert len(node_vars.variables) == 2 + assert {v.id for v in node_vars.variables} == set(self._node2_var_ids) + + def test_get_draft_variables_by_selectors(self): + srv = self._get_test_srv() + selectors = [ + [self._node1_id, "str_var"], + [self._node2_id, "str_var"], + [self._node2_id, "int_var"], + ] + variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors) + assert len(variables) == 3 + assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids) + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestDraftVariableLoader(unittest.TestCase): + _test_app_id: str + _test_tenant_id: str + + _node1_id = "test_loader_node_1" + _node_exec_id = str(uuid.uuid4()) + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ) + _variables = [ + node_var, + sys_var, + conv_var, + ] + + with Session(bind=db.engine, expire_on_commit=False) as session: + session.add_all(_variables) + session.flush() + session.commit() + self._variable_ids = [v.id for v in _variables] + self._node_var_id = node_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + + def tearDown(self): + with Session(bind=db.engine, expire_on_commit=False) as session: + session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete( + synchronize_session=False + ) + session.commit() + + def test_variable_loader_with_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + variables = var_loader.load_variables([]) + assert len(variables) == 0 + + def test_variable_loader_with_non_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + variables = var_loader.load_variables( + [ + [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + [self._node1_id, "str_var"], + ] + ) + assert len(variables) == 3 + conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID) + assert conv_var.id == self._conv_var_id + sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.id == self._sys_var_id + node1_var = next(v for v in variables if v.selector[0] == self._node1_id) + assert node1_var.id == 
self._node_var_id + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): + """Integration tests for reset_variable functionality using real database""" + + _test_app_id: str + _test_tenant_id: str + _test_workflow_id: str + _session: Session + _node_id = "test_reset_node" + _node_exec_id: str + _workflow_node_exec_id: str + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) + self._test_workflow_id = str(uuid.uuid4()) + self._node_exec_id = str(uuid.uuid4()) + self._workflow_node_exec_id = str(uuid.uuid4()) + self._session: Session = db.session() + + # Create a workflow node execution record with outputs + # Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable + self._workflow_node_execution = WorkflowNodeExecutionModel( + id=self._node_exec_id, # This should match the node_execution_id in the variable + tenant_id=self._test_tenant_id, + app_id=self._test_app_id, + workflow_id=self._test_workflow_id, + triggered_from="workflow-run", + workflow_run_id=str(uuid.uuid4()), + index=1, + node_execution_id=self._node_exec_id, + node_id=self._node_id, + node_type=NodeType.LLM.value, + title="Test Node", + inputs='{"input": "test input"}', + process_data='{"test_var": "process_value", "other_var": "other_process"}', + outputs='{"test_var": "output_value", "other_var": "other_output"}', + status="succeeded", + elapsed_time=1.5, + created_by_role="account", + created_by=str(uuid.uuid4()), + ) + + # Create conversation variables for the workflow + self._conv_variables = [ + StringVariable( + id=str(uuid.uuid4()), + name="conv_var_1", + description="Test conversation variable 1", + value="default_value_1", + ), + StringVariable( + id=str(uuid.uuid4()), + name="conv_var_2", + description="Test conversation variable 2", + value="default_value_2", + ), + ] + + # Create test variables + self._node_var_with_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="test_var", + value=build_segment("old_value"), + node_execution_id=self._node_exec_id, + ) + self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now() + + self._node_var_without_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="no_exec_var", + value=build_segment("some_value"), + node_execution_id="temp_exec_id", + ) + # Manually set node_execution_id to None after creation + self._node_var_without_exec.node_execution_id = None + + self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="missing_exec_var", + value=build_segment("some_value"), + node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database + ) + + self._conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var_1", + value=build_segment("old_conv_value"), + ) + self._conv_var.last_edited_at = datetime_utils.naive_utc_now() + + # Add all to database + db.session.add_all( + [ + self._workflow_node_execution, + self._node_var_with_exec, + self._node_var_without_exec, + self._node_var_missing_exec, + self._conv_var, + ] + ) + db.session.flush() + + # Store IDs for assertions + self._node_var_with_exec_id = self._node_var_with_exec.id + self._node_var_without_exec_id = self._node_var_without_exec.id + self._node_var_missing_exec_id = 
self._node_var_missing_exec.id + self._conv_var_id = self._conv_var.id + + def _get_test_srv(self) -> WorkflowDraftVariableService: + return WorkflowDraftVariableService(session=self._session) + + def _create_mock_workflow(self) -> Workflow: + """Create a real workflow with conversation variables and graph""" + conversation_vars = self._conv_variables + + # Create a simple graph with the test node + graph = { + "nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}], + "edges": [], + } + + workflow = Workflow.new( + tenant_id=str(uuid.uuid4()), + app_id=self._test_app_id, + type="workflow", + version="1.0", + graph=json.dumps(graph), + features="{}", + created_by=str(uuid.uuid4()), + environment_variables=[], + conversation_variables=conversation_vars, + ) + return workflow + + def tearDown(self): + self._session.rollback() + + def test_reset_node_variable_with_valid_execution_record(self): + """Test resetting a node variable with valid execution record - should restore from execution""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_with_exec_id) + assert variable is not None + assert variable.get_value().value == "old_value" + assert variable.last_edited_at is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return the updated variable + assert result is not None + assert result.id == self._node_var_with_exec_id + assert result.node_execution_id == self._workflow_node_execution.id + assert result.last_edited_at is None # Should be reset to None + + # The returned variable should have the updated value from execution record + assert result.get_value().value == "output_value" + + # Verify the variable was updated in database + updated_variable = srv.get_variable(self._node_var_with_exec_id) + assert updated_variable is not None + # The value should be updated from the execution record's outputs + assert updated_variable.get_value().value == "output_value" + assert updated_variable.last_edited_at is None + assert updated_variable.node_execution_id == self._workflow_node_execution.id + + def test_reset_node_variable_with_no_execution_id(self): + """Test resetting a node variable with no execution ID - should delete variable""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_without_exec_id) + assert variable is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return None (variable deleted) + assert result is None + + # Verify the variable was deleted + deleted_variable = srv.get_variable(self._node_var_without_exec_id) + assert deleted_variable is None + + def test_reset_node_variable_with_missing_execution_record(self): + """Test resetting a node variable when execution record doesn't exist""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_missing_exec_id) + assert variable is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return None (variable deleted) + assert result is None + + # Verify the variable was deleted + deleted_variable = srv.get_variable(self._node_var_missing_exec_id) + assert deleted_variable is None + + def test_reset_conversation_variable(self): + 
"""Test resetting a conversation variable""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._conv_var_id) + assert variable is not None + assert variable.get_value().value == "old_conv_value" + assert variable.last_edited_at is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return the updated variable + assert result is not None + assert result.id == self._conv_var_id + assert result.last_edited_at is None # Should be reset to None + + # Verify the variable was updated with default value from workflow + updated_variable = srv.get_variable(self._conv_var_id) + assert updated_variable is not None + # The value should be updated from the workflow's conversation variable default + assert updated_variable.get_value().value == "default_value_1" + assert updated_variable.last_edited_at is None + + def test_reset_system_variable_raises_error(self): + """Test that resetting a system variable raises an error""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Create a system variable + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + db.session.add(sys_var) + db.session.flush() + + # Attempt to reset the system variable + with pytest.raises(VariableResetError) as exc_info: + srv.reset_variable(mock_workflow, sys_var) + + assert "cannot reset system variable" in str(exc_info.value) + assert sys_var.id in str(exc_info.value) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 6aa48b1cbb..389d1071f3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -8,9 +8,8 @@ from unittest.mock import MagicMock, patch import pytest -from app_factory import create_app -from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool @@ -30,21 +29,6 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.fixture(scope="session") -def app(): - # Set up storage configuration - os.environ["STORAGE_TYPE"] = "opendal" - os.environ["OPENDAL_SCHEME"] = "fs" - os.environ["OPENDAL_FS_ROOT"] = "storage" - - # Ensure storage directory exists - os.makedirs("storage", exist_ok=True) - - app = create_app() - dify_config.LOGIN_DISABLED = True - return app - - def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -102,200 +86,101 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(app): - with app.app_context(): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is 
{{#abc.output#}}.", - }, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, +def test_execute_llm(flask_req_ctx): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - ) + }, + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), - prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "langgenius/openai/openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): - return mock_model_instance, mock_model_config + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - 
def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance - with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), - ): - # execute node - result = node._run() - assert isinstance(result, Generator) + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(app, setup_code_executor_mock): +def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - with app.app_context(): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", - }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", - }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, - }, - }, - ) - - # Mock db.session.close() - db.session.close = MagicMock() - - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), - prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) - - mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") - - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) - - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result - - # Create a simple mock model config with 
required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): - return mock_model_instance, mock_model_config - - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance - - with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), - ): - # execute node - result = node._run() - - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" in json.dumps(item.run_result.process_data) - - -def test_extract_json(): node = init_llm_node( config={ "id": "llm", @@ -304,21 +189,95 @@ def test_extract_json(): "type": "llm", "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, "prompt_config": { - "structured_output": { - "enabled": True, - "schema": { - "type": "object", - "properties": {"name": {"type": "string"}, "age": {"type": "number"}}, - }, - } + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - "prompt_template": [{"role": "user", "text": "{{#sys.query#}}"}], + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], "memory": None, "context": {"enabled": False}, "vision": {"enabled": False}, }, }, ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) + + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") + + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) + + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + 
return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance + + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) + + +def test_extract_json(): llm_texts = [ '\n\n{"name": "test", "age": 123', # resoning model (deepseek-r1) '{"name":"test","age":123}', # json schema model (gpt-4o) @@ -327,4 +286,4 @@ def test_extract_json(): '{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b) ] result = {"name": "test", "age": 123} - assert all(node._parse_structured_output(item) == result for item in llm_texts) + assert all(_parse_structured_output(item) == result for item in llm_texts) diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py new file mode 100644 index 0000000000..f26be6702a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -0,0 +1,302 @@ +import datetime +import uuid +from collections import OrderedDict +from typing import Any, NamedTuple + +from flask_restful import marshal + +from controllers.console.app.workflow_draft_variable import ( + _WORKFLOW_DRAFT_VARIABLE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, +) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList + +_TEST_APP_ID = "test_app_id" +_TEST_NODE_EXEC_ID = str(uuid.uuid4()) + + +class TestWorkflowDraftVariableFields: + def test_conversation_variable(self): + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) + ) + + conv_var.id = str(uuid.uuid4()) + conv_var.visible = True + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": str(conv_var.id), + "type": conv_var.get_variable_type().value, + "name": "conv_var", + "description": "", + "selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + "value_type": "number", + "edited": False, + "visible": True, + } + ) + + assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = 1 + assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + def test_create_sys_variable(self): + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=_TEST_APP_ID, + name="sys_var", + value=build_segment("a"), + editable=True, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + sys_var.id = str(uuid.uuid4()) + sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + sys_var.visible = True + + expected_without_value = 
OrderedDict( + { + "id": str(sys_var.id), + "type": sys_var.get_variable_type().value, + "name": "sys_var", + "description": "", + "selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + "value_type": "string", + "edited": True, + "visible": True, + } + ) + assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = "a" + assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + def test_node_variable(self): + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="node_var", + value=build_segment([1, "a"]), + visible=False, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + node_var.id = str(uuid.uuid4()) + node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "node_var", + "description": "", + "selector": ["test_node", "node_var"], + "value_type": "array[any]", + "edited": True, + "visible": False, + } + ) + + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = [1, "a"] + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + +class TestWorkflowDraftVariableList: + def test_workflow_draft_variable_list(self): + class TestCase(NamedTuple): + name: str + var_list: WorkflowDraftVariableList + expected: dict + + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="test_var", + value=build_segment("a"), + visible=True, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + node_var.id = str(uuid.uuid4()) + node_var_dict = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "test_var", + "description": "", + "selector": ["test_node", "test_var"], + "value_type": "string", + "edited": False, + "visible": True, + } + ) + + cases = [ + TestCase( + name="empty variable list", + var_list=WorkflowDraftVariableList(variables=[]), + expected=OrderedDict( + { + "items": [], + "total": None, + } + ), + ), + TestCase( + name="empty variable list with total", + var_list=WorkflowDraftVariableList(variables=[], total=10), + expected=OrderedDict( + { + "items": [], + "total": 10, + } + ), + ), + TestCase( + name="non-empty variable list", + var_list=WorkflowDraftVariableList(variables=[node_var], total=None), + expected=OrderedDict( + { + "items": [node_var_dict], + "total": None, + } + ), + ), + TestCase( + name="non-empty variable list with total", + var_list=WorkflowDraftVariableList(variables=[node_var], total=10), + expected=OrderedDict( + { + "items": [node_var_dict], + "total": 10, + } + ), + ), + ] + + for idx, case in enumerate(cases, 1): + assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, ( + f"Test case {idx} failed, {case.name=}" + ) + + +def test_workflow_node_variables_fields(): + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) + ) + resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == 
"conv_var" + assert item_dict["value"] == 1 + + +def test_workflow_file_variable_with_signed_url(): + """Test that File type variables include signed URLs in API responses.""" + from core.file.enums import FileTransferMethod, FileType + from core.file.models import File + + # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_upload_file_id", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, + ) + + # Create a WorkflowDraftVariable with the File + file_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="file_var", + value=build_segment(test_file), + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + # Marshal the variable using the API fields + resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + + # Verify the response structure + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "file_var" + + # Verify the value is a dict (File.to_dict() result) and contains expected fields + value = item_dict["value"] + assert isinstance(value, dict) + + # Verify the File fields are preserved + assert value["id"] == test_file.id + assert value["filename"] == test_file.filename + assert value["type"] == test_file.type.value + assert value["transfer_method"] == test_file.transfer_method.value + assert value["size"] == test_file.size + + # Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method) + remote_url = value["remote_url"] + assert remote_url is not None + + assert isinstance(remote_url, str) + # For LOCAL_FILE, the URL should contain signature parameters + assert "timestamp=" in remote_url + assert "nonce=" in remote_url + assert "sign=" in remote_url + + +def test_workflow_file_variable_remote_url(): + """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" + from core.file.enums import FileTransferMethod, FileType + from core.file.models import File + + # Create a File object with REMOTE_URL transfer method + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/test.jpg", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, + ) + + # Create a WorkflowDraftVariable with the File + file_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="file_var", + value=build_segment(test_file), + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + # Marshal the variable using the API fields + resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + + # Verify the response structure + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "file_var" + + # Verify the value is a dict (File.to_dict() result) and contains expected fields + value = item_dict["value"] + assert isinstance(value, dict) + remote_url = value["remote_url"] + + # For REMOTE_URL, the URL should be the original remote URL + assert remote_url == test_file.remote_url diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py 
deleted file mode 100644 index e6e289c12a..0000000000 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ /dev/null @@ -1,165 +0,0 @@ -from uuid import uuid4 - -import pytest - -from core.variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, - ObjectSegment, - SecretVariable, - StringVariable, -) -from core.variables.exc import VariableError -from core.variables.segments import ArrayAnySegment -from factories import variable_factory - - -def test_string_variable(): - test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, StringVariable) - - -def test_integer_variable(): - test_data = {"value_type": "number", "name": "test_int", "value": 42} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, IntegerVariable) - - -def test_float_variable(): - test_data = {"value_type": "number", "name": "test_float", "value": 3.14} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, FloatVariable) - - -def test_secret_variable(): - test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, SecretVariable) - - -def test_invalid_value_type(): - test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} - with pytest.raises(VariableError): - variable_factory.build_conversation_variable_from_mapping(test_data) - - -def test_build_a_blank_string(): - result = variable_factory.build_conversation_variable_from_mapping( - { - "value_type": "string", - "name": "blank", - "value": "", - } - ) - assert isinstance(result, StringVariable) - assert result.value == "" - - -def test_build_a_object_variable_with_none_value(): - var = variable_factory.build_segment( - { - "key1": None, - } - ) - assert isinstance(var, ObjectSegment) - assert var.value["key1"] is None - - -def test_object_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "object", - "name": "test_object", - "description": "Description of the variable.", - "value": { - "key1": "text", - "key2": 2, - }, - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value["key1"], str) - assert isinstance(variable.value["key2"], int) - - -def test_array_string_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[string]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - "text", - "text", - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayStringVariable) - assert isinstance(variable.value[0], str) - assert isinstance(variable.value[1], str) - - -def test_array_number_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[number]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - 1, - 2.0, - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayNumberVariable) - assert isinstance(variable.value[0], int) - assert isinstance(variable.value[1], float) - - -def test_array_object_variable(): - mapping = { - "id": str(uuid4()), - 
"value_type": "array[object]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - { - "key1": "text", - "key2": 1, - }, - { - "key1": "text", - "key2": 1, - }, - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayObjectVariable) - assert isinstance(variable.value[0], dict) - assert isinstance(variable.value[1], dict) - assert isinstance(variable.value[0]["key1"], str) - assert isinstance(variable.value[0]["key2"], int) - assert isinstance(variable.value[1]["key1"], str) - assert isinstance(variable.value[1]["key2"], int) - - -def test_variable_cannot_large_than_200_kb(): - with pytest.raises(VariableError): - variable_factory.build_conversation_variable_from_mapping( - { - "id": str(uuid4()), - "value_type": "string", - "name": "test_text", - "value": "a" * 1024 * 201, - } - ) - - -def test_array_none_variable(): - var = variable_factory.build_segment([None, None, None, None]) - assert isinstance(var, ArrayAnySegment) - assert var.value == [None, None, None, None] diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py new file mode 100644 index 0000000000..3ada2087c6 --- /dev/null +++ b/api/tests/unit_tests/core/file/test_models.py @@ -0,0 +1,25 @@ +from core.file import File, FileTransferMethod, FileType + + +def test_file(): + file = File( + id="test-file", + tenant_id="test-tenant-id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="test-related-id", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) + assert file.tenant_id == "test-tenant-id" + assert file.type == FileType.IMAGE + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.related_id == "test-related-id" + assert file.filename == "image.png" + assert file.extension == ".png" + assert file.mime_type == "image/png" + assert file.size == 67 diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py similarity index 100% rename from api/tests/unit_tests/core/app/segments/test_segment.py rename to api/tests/unit_tests/core/variables/test_segment.py diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py similarity index 100% rename from api/tests/unit_tests/core/app/segments/test_variables.py rename to api/tests/unit_tests/core/variables/test_variables.py diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py new file mode 100644 index 0000000000..8712b61a23 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -0,0 +1,36 @@ +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import NodeType + +# Ensures that all node classes are imported. 
+from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + +_ = NODE_TYPE_CLASSES_MAPPING + + +def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]: + subclasses = [] + queue = [root] + while queue: + cls = queue.pop() + + subclasses.extend(cls.__subclasses__()) + queue.extend(cls.__subclasses__()) + + return subclasses + + +def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined(): + classes = _get_all_subclasses(BaseNode) # type: ignore + type_version_set: set[tuple[NodeType, str]] = set() + + for cls in classes: + # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ + assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" + node_type = cls._node_type + node_version = cls.version() + + assert isinstance(cls._node_type, NodeType) + assert isinstance(node_version, str) + node_type_and_version = (node_type, node_version) + assert node_type_and_version not in type_version_set + type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 6d854c950d..362072a3db 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -3,6 +3,7 @@ import uuid from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables.segments import ArrayAnySegment, ArrayStringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -197,7 +198,7 @@ def test_run(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 20 @@ -413,7 +414,7 @@ def test_run_parallel(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 32 @@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode(): parallel_arr.append(item) if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 32 for item in sequential_result: @@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 64 @@ -846,7 +847,7 @@ def test_iteration_run_error_handle(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs 
== {"output": [None, None]} + assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])} assert count == 14 # execute remove abnormal output @@ -857,5 +858,5 @@ def test_iteration_run_error_handle(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": []} + assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])} assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 4cb1aa93f9..66c7818adf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P from core.file import File, FileTransferMethod from core.variables import ArrayFileSegment +from core.variables.segments import ArrayStringSegment from core.variables.variables import StringVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s @pytest.mark.parametrize( ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ - ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"), + ( + "text/plain", + b"Hello, world!", + ["Hello, world!"], + FileTransferMethod.LOCAL_FILE, + ".txt", + ), ( "application/pdf", b"%PDF-1.5\n%Test PDF content", @@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s FileTransferMethod.REMOTE_URL, "", ), - ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None), + ( + "text/plain", + b"Remote content", + ["Remote content"], + FileTransferMethod.REMOTE_URL, + None, + ), ], ) def test_run_extract_text( @@ -131,7 +144,7 @@ def test_run_extract_text( assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error assert result.outputs is not None - assert result.outputs["text"] == expected_text + assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") @@ -329,3 +342,26 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert result == "" assert mock_excel_instance.parse.call_count == 2 + + +@patch("pandas.ExcelFile") +def test_extract_text_from_excel_numeric_type_column(mock_excel_file): + """Test extracting text from Excel file with numeric column names.""" + + # Test numeric type column + data = {1: ["Test"], 1.1: ["Test"]} + + df = pd.DataFrame(data) + + # Mock ExcelFile + mock_excel_instance = Mock() + mock_excel_instance.sheet_names = ["Sheet1"] + mock_excel_instance.parse.return_value = df + mock_excel_file.return_value = mock_excel_instance + + file_content = b"fake_excel_content" + result = _extract_text_from_excel(file_content) + + expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n" + + assert expected_manual == result diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 77d42e2692..7d3a1d6a2d 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node): }, ] assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - for expected_file, result_file in zip(expected_files, result.outputs["result"]): + for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type assert expected_file["tenant_id"] == result_file.tenant_id diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 9793da129d..deb3e29b86 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -5,6 +5,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -63,10 +64,11 @@ def test_overwrite_string_variable(): name="test_string_variable", value="the second value", ) + conversation_id = str(uuid.uuid4()) # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -77,6 +79,9 @@ def test_overwrite_string_variable(): input_variable, ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -91,11 +96,20 @@ def test_overwrite_string_variable(): "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_var = StringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=input_variable.value, + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -148,9 +162,10 @@ def test_append_variable_to_array(): name="test_string_variable", value="the second value", ) + conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -160,6 +175,9 
@@ def test_append_variable_to_array(): input_variable, ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -174,11 +192,22 @@ def test_append_variable_to_array(): "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_value = list(conversation_variable.value) + expected_value.append(input_variable.value) + expected_var = ArrayStringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=expected_value, + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -225,13 +254,17 @@ def test_clear_array(): value=["the first value"], ) + conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -246,11 +279,20 @@ def test_clear_array(): "input_variable_selector": [], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_var = ArrayStringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=[], + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index efbcdc760c..bb8d34fad5 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -1,8 +1,12 @@ import pytest +from pydantic import ValidationError from core.file import File, FileTransferMethod, FileType from core.variables import FileSegment, StringSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture @@ -44,3 +48,38 @@ def 
test_use_long_selector(pool): result = pool.get(("node_1", "part_1", "part_2")) assert result is not None assert result.value == "test_value" + + +class TestVariablePool: + def test_constructor(self): + pool = VariablePool() + pool = VariablePool( + variable_dictionary={}, + user_inputs={}, + system_variables={}, + environment_variables=[], + conversation_variables=[], + ) + + pool = VariablePool( + user_inputs={"key": "value"}, + system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, + environment_variables=[ + segment_to_variable( + segment=build_segment(1), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"], + name="env_var_1", + ) + ], + conversation_variables=[ + segment_to_variable( + segment=build_segment("1"), + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"], + name="conv_var_1", + ) + ], + ) + + def test_constructor_with_invalid_system_variable_key(self): + with pytest.raises(ValidationError): + VariablePool(system_variables={"invalid_key": "value"}) # type: ignore diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 2f90afcf89..28ef05edde 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,22 +1,10 @@ -from core.variables import SecretVariable +import dataclasses + from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.utils import variable_template_parser def test_extract_selectors_from_template(): - variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, - user_inputs={}, - environment_variables=[ - SecretVariable(name="secret_key", value="fake-secret-key"), - ], - conversation_variables=[], - ) - variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
) @@ -26,3 +14,35 @@ def test_extract_selectors_from_template(): VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), ] + + +def test_invalid_references(): + @dataclasses.dataclass + class TestCase: + name: str + template: str + + cases = [ + TestCase( + name="lack of closing brace", + template="Hello, {{#sys.user_id#", + ), + TestCase( + name="lack of opening brace", + template="Hello, #sys.user_id#}}", + ), + TestCase( + name="lack selector name", + template="Hello, {{#sys#}}", + ), + TestCase( + name="empty node name part", + template="Hello, {{#.user_id#}}", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + selectors = variable_template_parser.extract_selectors_from_template(c.template) + assert selectors == [], fail_msg + parser = variable_template_parser.VariableTemplateParser(c.template) + assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py new file mode 100644 index 0000000000..f1cb937bb3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -0,0 +1,148 @@ +from typing import Any + +from core.variables.segments import ObjectSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.variable_utils import append_variables_recursively + + +class TestAppendVariablesRecursively: + """Test cases for append_variables_recursively function""" + + def test_append_simple_dict_value(self): + """Test appending a simple dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["output"] + variable_value = {"name": "John", "age": 30} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Check that nested variables are added recursively + name_var = pool.get([node_id] + variable_key_list + ["name"]) + assert name_var is not None + assert name_var.value == "John" + + age_var = pool.get([node_id] + variable_key_list + ["age"]) + assert age_var is not None + assert age_var.value == 30 + + def test_append_object_segment_value(self): + """Test appending an ObjectSegment value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["result"] + + # Create an ObjectSegment + obj_data = {"status": "success", "code": 200} + variable_value = ObjectSegment(value=obj_data) + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, ObjectSegment) + assert main_var.value == obj_data + + # Check that nested variables are added recursively + status_var = pool.get([node_id] + variable_key_list + ["status"]) + assert status_var is not None + assert status_var.value == "success" + + code_var = pool.get([node_id] + variable_key_list + ["code"]) + assert code_var is not None + assert code_var.value == 200 + + def test_append_nested_dict_value(self): + """Test appending a nested dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["data"] + + 
variable_value = { + "user": { + "profile": {"name": "Alice", "email": "alice@example.com"}, + "settings": {"theme": "dark", "notifications": True}, + }, + "metadata": {"version": "1.0", "timestamp": 1234567890}, + } + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check deeply nested variables + name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) + assert name_var is not None + assert name_var.value == "Alice" + + email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) + assert email_var is not None + assert email_var.value == "alice@example.com" + + theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) + assert theme_var is not None + assert theme_var.value == "dark" + + notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) + assert notifications_var is not None + assert notifications_var.value == 1 # Boolean True is converted to integer 1 + + version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) + assert version_var is not None + assert version_var.value == "1.0" + + def test_append_non_dict_value(self): + """Test appending a non-dictionary value (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["simple"] + variable_value = "simple_string" + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_segment_non_object_value(self): + """Test appending a Segment that is not ObjectSegment (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["text"] + variable_value = StringSegment(value="Hello World") + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, StringSegment) + assert main_var.value == "Hello World" + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_empty_dict_value(self): + """Test appending an empty dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["empty"] + variable_value: dict[str, Any] = {} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == {} + + # Ensure only the main variable is created (no recursion for empty dict) + assert len(pool.variable_dictionary[node_id]) == 1 diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py new file mode 100644 index 0000000000..481fbdc91a --- /dev/null +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -0,0 +1,865 @@ +import math +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from core.file import File, FileTransferMethod, FileType +from core.variables import 
( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectSegment, + SecretVariable, + SegmentType, + StringVariable, +) +from core.variables.exc import VariableError +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) +from core.variables.types import SegmentType +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment_with_type + + +def test_string_variable(): + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, StringVariable) + + +def test_integer_variable(): + test_data = {"value_type": "number", "name": "test_int", "value": 42} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, IntegerVariable) + + +def test_float_variable(): + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, FloatVariable) + + +def test_secret_variable(): + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, SecretVariable) + + +def test_invalid_value_type(): + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping(test_data) + + +def test_build_a_blank_string(): + result = variable_factory.build_conversation_variable_from_mapping( + { + "value_type": "string", + "name": "blank", + "value": "", + } + ) + assert isinstance(result, StringVariable) + assert result.value == "" + + +def test_build_a_object_variable_with_none_value(): + var = variable_factory.build_segment( + { + "key1": None, + } + ) + assert isinstance(var, ObjectSegment) + assert var.value["key1"] is None + + +def test_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, + }, + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ObjectSegment) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) + + +def test_array_string_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayStringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) + + +def test_array_number_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + 1, + 2.0, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayNumberVariable) + assert isinstance(variable.value[0], int) + assert 
isinstance(variable.value[1], float) + + +def test_array_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + { + "key1": "text", + "key2": 1, + }, + { + "key1": "text", + "key2": 1, + }, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayObjectVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) + + +def test_variable_cannot_large_than_200_kb(): + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping( + { + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 201, + } + ) + + +def test_array_none_variable(): + var = variable_factory.build_segment([None, None, None, None]) + assert isinstance(var, ArrayAnySegment) + assert var.value == [None, None, None, None] + + +def test_build_segment_none_type(): + """Test building NoneSegment from None value.""" + segment = variable_factory.build_segment(None) + assert isinstance(segment, NoneSegment) + assert segment.value is None + assert segment.value_type == SegmentType.NONE + + +def test_build_segment_none_type_properties(): + """Test NoneSegment properties and methods.""" + segment = variable_factory.build_segment(None) + assert segment.text == "" + assert segment.log == "" + assert segment.markdown == "" + assert segment.to_object() is None + + +def test_build_segment_array_file_single_file(): + """Test building ArrayFileSegment from list with single file.""" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + segment = variable_factory.build_segment([file]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 1 + assert segment.value[0] == file + assert segment.value_type == SegmentType.ARRAY_FILE + + +def test_build_segment_array_file_multiple_files(): + """Test building ArrayFileSegment from list with multiple files.""" + file1 = File( + id="test_file_id_1", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file1.png", + filename="test-file1", + extension=".png", + mime_type="image/png", + size=1000, + ) + file2 = File( + id="test_file_id_2", + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_relation_id", + filename="test-file2", + extension=".txt", + mime_type="text/plain", + size=500, + ) + segment = variable_factory.build_segment([file1, file2]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 2 + assert segment.value[0] == file1 + assert segment.value[1] == file2 + assert segment.value_type == SegmentType.ARRAY_FILE + + +def test_build_segment_array_file_empty_list(): + """Test building ArrayFileSegment from empty list should create ArrayAnySegment.""" + segment = variable_factory.build_segment([]) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == [] 
+ assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_mixed_types(): + """Test building ArrayAnySegment from list with mixed types.""" + mixed_values = ["string", 42, 3.14, {"key": "value"}, None] + segment = variable_factory.build_segment(mixed_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == mixed_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_with_nested_arrays(): + """Test building ArrayAnySegment from list containing arrays.""" + nested_values = [["nested", "array"], [1, 2, 3], "string"] + segment = variable_factory.build_segment(nested_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == nested_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_mixed_with_files(): + """Test building ArrayAnySegment from list with files and other types.""" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + mixed_values = [file, "string", 42] + segment = variable_factory.build_segment(mixed_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == mixed_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_all_none_values(): + """Test building ArrayAnySegment from list with all None values.""" + none_values = [None, None, None] + segment = variable_factory.build_segment(none_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == none_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_file_properties(): + """Test ArrayFileSegment properties and methods.""" + file1 = File( + id="test_file_id_1", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file1.png", + filename="test-file1", + extension=".png", + mime_type="image/png", + size=1000, + ) + file2 = File( + id="test_file_id_2", + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file2.txt", + filename="test-file2", + extension=".txt", + mime_type="text/plain", + size=500, + ) + segment = variable_factory.build_segment([file1, file2]) + + # Test properties + assert segment.text == "" # ArrayFileSegment text property returns empty string + assert segment.log == "" # ArrayFileSegment log property returns empty string + assert segment.markdown == f"{file1.markdown}\n{file2.markdown}" + assert segment.to_object() == [file1, file2] + + +def test_build_segment_array_any_properties(): + """Test ArrayAnySegment properties and methods.""" + mixed_values = ["string", 42, None] + segment = variable_factory.build_segment(mixed_values) + + # Test properties + assert segment.text == str(mixed_values) + assert segment.log == str(mixed_values) + assert segment.markdown == "string\n42\nNone" + assert segment.to_object() == mixed_values + + +def test_build_segment_edge_cases(): + """Test edge cases for build_segment function.""" + # Test with complex nested structures + complex_structure = [{"nested": {"deep": [1, 2, 3]}}, [{"inner": "value"}], "mixed"] + segment = variable_factory.build_segment(complex_structure) + assert 
isinstance(segment, ArrayAnySegment) + assert segment.value == complex_structure + + # Test with single None in list + single_none = [None] + segment = variable_factory.build_segment(single_none) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == single_none + + +def test_build_segment_file_array_with_different_file_types(): + """Test ArrayFileSegment with different file types.""" + image_file = File( + id="image_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/image.png", + filename="image", + extension=".png", + mime_type="image/png", + size=1000, + ) + + video_file = File( + id="video_id", + tenant_id="test_tenant_id", + type=FileType.VIDEO, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="video_relation_id", + filename="video", + extension=".mp4", + mime_type="video/mp4", + size=5000, + ) + + audio_file = File( + id="audio_id", + tenant_id="test_tenant_id", + type=FileType.AUDIO, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="audio_relation_id", + filename="audio", + extension=".mp3", + mime_type="audio/mpeg", + size=3000, + ) + + segment = variable_factory.build_segment([image_file, video_file, audio_file]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 3 + assert segment.value[0].type == FileType.IMAGE + assert segment.value[1].type == FileType.VIDEO + assert segment.value[2].type == FileType.AUDIO + + +@st.composite +def _generate_file(draw) -> File: + file_id = draw(st.text(min_size=1, max_size=10)) + tenant_id = draw(st.text(min_size=1, max_size=10)) + file_type, mime_type, extension = draw( + st.sampled_from( + [ + (FileType.IMAGE, "image/png", ".png"), + (FileType.VIDEO, "video/mp4", ".mp4"), + (FileType.DOCUMENT, "text/plain", ".txt"), + (FileType.AUDIO, "audio/mpeg", ".mp3"), + ] + ) + ) + filename = "test-file" + size = draw(st.integers(min_value=0, max_value=1024 * 1024)) + + transfer_method = draw(st.sampled_from(list(FileTransferMethod))) + if transfer_method == FileTransferMethod.REMOTE_URL: + url = "https://test.example.com/test-file" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=file_type, + transfer_method=transfer_method, + remote_url=url, + related_id=None, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + else: + relation_id = draw(st.uuids(version=4)) + + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=file_type, + transfer_method=transfer_method, + related_id=str(relation_id), + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + return file + + +def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]: + return st.one_of( + st.none(), + st.integers(), + st.floats(), + st.text(), + _generate_file(), + ) + + +@given(_scalar_value()) +def test_build_segment_and_extract_values_for_scalar_types(value): + seg = variable_factory.build_segment(value) + # nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here. 
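+    # For example, float("nan") != float("nan") in Python, while math.isnan(float("nan")) is True.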
+ if isinstance(value, float) and math.isnan(value): + assert math.isnan(seg.value) + else: + assert seg.value == value + + +@given(st.lists(_scalar_value())) +def test_build_segment_and_extract_values_for_array_types(values): + seg = variable_factory.build_segment(values) + assert seg.value == values + + +def test_build_segment_type_for_scalar(): + @dataclass(frozen=True) + class TestCase: + value: int | float | str | File + expected_type: SegmentType + + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + cases = [ + TestCase(0, SegmentType.NUMBER), + TestCase(0.0, SegmentType.NUMBER), + TestCase("", SegmentType.STRING), + TestCase(file, SegmentType.FILE), + ] + + for idx, c in enumerate(cases, 1): + segment = variable_factory.build_segment(c.value) + assert segment.value_type == c.expected_type, f"test case {idx} failed." + + +class TestBuildSegmentWithType: + """Test cases for build_segment_with_type function.""" + + def test_string_type(self): + """Test building a string segment with correct type.""" + result = build_segment_with_type(SegmentType.STRING, "hello") + assert isinstance(result, StringSegment) + assert result.value == "hello" + assert result.value_type == SegmentType.STRING + + def test_number_type_integer(self): + """Test building a number segment with integer value.""" + result = build_segment_with_type(SegmentType.NUMBER, 42) + assert isinstance(result, IntegerSegment) + assert result.value == 42 + assert result.value_type == SegmentType.NUMBER + + def test_number_type_float(self): + """Test building a number segment with float value.""" + result = build_segment_with_type(SegmentType.NUMBER, 3.14) + assert isinstance(result, FloatSegment) + assert result.value == 3.14 + assert result.value_type == SegmentType.NUMBER + + def test_object_type(self): + """Test building an object segment with correct type.""" + test_obj = {"key": "value", "nested": {"inner": 123}} + result = build_segment_with_type(SegmentType.OBJECT, test_obj) + assert isinstance(result, ObjectSegment) + assert result.value == test_obj + assert result.value_type == SegmentType.OBJECT + + def test_file_type(self): + """Test building a file segment with correct type.""" + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + storage_key="test_storage_key", + ) + result = build_segment_with_type(SegmentType.FILE, test_file) + assert isinstance(result, FileSegment) + assert result.value == test_file + assert result.value_type == SegmentType.FILE + + def test_none_type(self): + """Test building a none segment with None value.""" + result = build_segment_with_type(SegmentType.NONE, None) + assert isinstance(result, NoneSegment) + assert result.value is None + assert result.value_type == SegmentType.NONE + + def test_empty_array_string(self): + """Test building an empty array[string] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_STRING, []) + assert isinstance(result, ArrayStringSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_STRING + + def test_empty_array_number(self): + """Test building an empty array[number] 
segment.""" + result = build_segment_with_type(SegmentType.ARRAY_NUMBER, []) + assert isinstance(result, ArrayNumberSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_NUMBER + + def test_empty_array_object(self): + """Test building an empty array[object] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_OBJECT, []) + assert isinstance(result, ArrayObjectSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_OBJECT + + def test_empty_array_file(self): + """Test building an empty array[file] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_FILE, []) + assert isinstance(result, ArrayFileSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_FILE + + def test_empty_array_any(self): + """Test building an empty array[any] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_ANY, []) + assert isinstance(result, ArrayAnySegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_ANY + + def test_array_with_values(self): + """Test building array segments with actual values.""" + # Array of strings + result = build_segment_with_type(SegmentType.ARRAY_STRING, ["hello", "world"]) + assert isinstance(result, ArrayStringSegment) + assert result.value == ["hello", "world"] + assert result.value_type == SegmentType.ARRAY_STRING + + # Array of numbers + result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3.14]) + assert isinstance(result, ArrayNumberSegment) + assert result.value == [1, 2, 3.14] + assert result.value_type == SegmentType.ARRAY_NUMBER + + # Array of objects + result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [{"a": 1}, {"b": 2}]) + assert isinstance(result, ArrayObjectSegment) + assert result.value == [{"a": 1}, {"b": 2}] + assert result.value_type == SegmentType.ARRAY_OBJECT + + def test_type_mismatch_string_to_number(self): + """Test type mismatch when expecting number but getting string.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.NUMBER, "not_a_number") + + assert "Type mismatch" in str(exc_info.value) + assert "expected number" in str(exc_info.value) + assert "str" in str(exc_info.value) + + def test_type_mismatch_number_to_string(self): + """Test type mismatch when expecting string but getting number.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, 123) + + assert "Type mismatch" in str(exc_info.value) + assert "expected string" in str(exc_info.value) + assert "int" in str(exc_info.value) + + def test_type_mismatch_none_to_string(self): + """Test type mismatch when expecting string but getting None.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, None) + + assert "Expected string, but got None" in str(exc_info.value) + + def test_type_mismatch_empty_list_to_non_array(self): + """Test type mismatch when expecting non-array type but getting empty list.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, []) + + assert "Expected string, but got empty list" in str(exc_info.value) + + def test_type_mismatch_object_to_array(self): + """Test type mismatch when expecting array but getting object.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.ARRAY_STRING, {"key": "value"}) + + assert "Type mismatch" in str(exc_info.value) + assert "expected array[string]" in 
str(exc_info.value) + + def test_compatible_number_types(self): + """Test that int and float are both compatible with NUMBER type.""" + # Integer should work + result_int = build_segment_with_type(SegmentType.NUMBER, 42) + assert isinstance(result_int, IntegerSegment) + assert result_int.value_type == SegmentType.NUMBER + + # Float should work + result_float = build_segment_with_type(SegmentType.NUMBER, 3.14) + assert isinstance(result_float, FloatSegment) + assert result_float.value_type == SegmentType.NUMBER + + @pytest.mark.parametrize( + ("segment_type", "value", "expected_class"), + [ + (SegmentType.STRING, "test", StringSegment), + (SegmentType.NUMBER, 42, IntegerSegment), + (SegmentType.NUMBER, 3.14, FloatSegment), + (SegmentType.OBJECT, {}, ObjectSegment), + (SegmentType.NONE, None, NoneSegment), + (SegmentType.ARRAY_STRING, [], ArrayStringSegment), + (SegmentType.ARRAY_NUMBER, [], ArrayNumberSegment), + (SegmentType.ARRAY_OBJECT, [], ArrayObjectSegment), + (SegmentType.ARRAY_ANY, [], ArrayAnySegment), + ], + ) + def test_parametrized_valid_types(self, segment_type, value, expected_class): + """Parametrized test for valid type combinations.""" + result = build_segment_with_type(segment_type, value) + assert isinstance(result, expected_class) + assert result.value == value + assert result.value_type == segment_type + + @pytest.mark.parametrize( + ("segment_type", "value"), + [ + (SegmentType.STRING, 123), + (SegmentType.NUMBER, "not_a_number"), + (SegmentType.OBJECT, "not_an_object"), + (SegmentType.ARRAY_STRING, "not_an_array"), + (SegmentType.STRING, None), + (SegmentType.NUMBER, None), + ], + ) + def test_parametrized_type_mismatches(self, segment_type, value): + """Parametrized test for type mismatches that should raise TypeMismatchError.""" + with pytest.raises(TypeMismatchError): + build_segment_with_type(segment_type, value) + + +# Test cases for ValueError scenarios in build_segment function +class TestBuildSegmentValueErrors: + """Test cases for ValueError scenarios in the build_segment function.""" + + @dataclass(frozen=True) + class ValueErrorTestCase: + """Test case data for ValueError scenarios.""" + + name: str + description: str + test_value: Any + + def _get_test_cases(self): + """Get all test cases for ValueError scenarios.""" + + # Define inline classes for complex test cases + class CustomType: + pass + + def unsupported_function(): + return "test" + + def gen(): + yield 1 + yield 2 + + return [ + self.ValueErrorTestCase( + name="unsupported_custom_type", + description="custom class that doesn't match any supported type", + test_value=CustomType(), + ), + self.ValueErrorTestCase( + name="unsupported_set_type", + description="set (unsupported collection type)", + test_value={1, 2, 3}, + ), + self.ValueErrorTestCase( + name="unsupported_tuple_type", description="tuple (unsupported type)", test_value=(1, 2, 3) + ), + self.ValueErrorTestCase( + name="unsupported_bytes_type", + description="bytes (unsupported type)", + test_value=b"hello world", + ), + self.ValueErrorTestCase( + name="unsupported_function_type", + description="function (unsupported type)", + test_value=unsupported_function, + ), + self.ValueErrorTestCase( + name="unsupported_module_type", description="module (unsupported type)", test_value=math + ), + self.ValueErrorTestCase( + name="array_with_unsupported_element_types", + description="array with unsupported element types", + test_value=[CustomType()], + ), + self.ValueErrorTestCase( + name="mixed_array_with_unsupported_types", + description="array 
with mix of supported and unsupported types", + test_value=["valid_string", 42, CustomType()], + ), + self.ValueErrorTestCase( + name="nested_unsupported_types", + description="nested structures containing unsupported types", + test_value=[{"valid": "data"}, CustomType()], + ), + self.ValueErrorTestCase( + name="complex_number_type", + description="complex number (unsupported type)", + test_value=3 + 4j, + ), + self.ValueErrorTestCase( + name="range_type", description="range object (unsupported type)", test_value=range(10) + ), + self.ValueErrorTestCase( + name="generator_type", + description="generator (unsupported type)", + test_value=gen(), + ), + self.ValueErrorTestCase( + name="exception_message_contains_value", + description="set to verify error message contains the actual unsupported value", + test_value={1, 2, 3}, + ), + self.ValueErrorTestCase( + name="array_with_mixed_unsupported_segment_types", + description="array processing with unsupported segment types in match", + test_value=[CustomType()], + ), + self.ValueErrorTestCase( + name="frozenset_type", + description="frozenset (unsupported type)", + test_value=frozenset([1, 2, 3]), + ), + self.ValueErrorTestCase( + name="memoryview_type", + description="memoryview (unsupported type)", + test_value=memoryview(b"hello"), + ), + self.ValueErrorTestCase( + name="slice_type", description="slice object (unsupported type)", test_value=slice(1, 10, 2) + ), + self.ValueErrorTestCase(name="type_object", description="type object (unsupported type)", test_value=type), + self.ValueErrorTestCase( + name="generic_object", description="generic object (unsupported type)", test_value=object() + ), + ] + + def test_build_segment_unsupported_types(self): + """Table-driven test for all ValueError scenarios in build_segment function.""" + test_cases = self._get_test_cases() + + for index, test_case in enumerate(test_cases, 1): + # Use test value directly + test_value = test_case.test_value + + with pytest.raises(ValueError) as exc_info: # noqa: PT012 + segment = variable_factory.build_segment(test_value) + pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but did not, result={segment}") + + error_message = str(exc_info.value) + assert "not supported value" in error_message, ( + f"Test case {index} ({test_case.name}): Expected 'not supported value' in error message, " + f"but got: {error_message}" + ) + + def test_build_segment_boolean_type_note(self): + """Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError.""" + # Boolean values in Python are subclasses of int, so they get processed as integers + # True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0) + true_segment = variable_factory.build_segment(True) + false_segment = variable_factory.build_segment(False) + + # Verify they are processed as integers, not as errors + assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1" + assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0" + assert true_segment.value_type == SegmentType.NUMBER + assert false_segment.value_type == SegmentType.NUMBER diff --git a/api/tests/unit_tests/libs/test_datetime_utils.py b/api/tests/unit_tests/libs/test_datetime_utils.py new file mode 100644 index 0000000000..e7781a5821 --- /dev/null +++ b/api/tests/unit_tests/libs/test_datetime_utils.py @@ -0,0 +1,20 @@ +import datetime + +from libs.datetime_utils import naive_utc_now + + +def
test_naive_utc_now(monkeypatch): + tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC) + + def _now_func(tz: datetime.timezone | None) -> datetime.datetime: + return tz_aware_utc_now.astimezone(tz) + + monkeypatch.setattr("libs.datetime_utils._now_func", _now_func) + + naive_datetime = naive_utc_now() + + assert naive_datetime.tzinfo is None + assert naive_datetime.date() == tz_aware_utc_now.date() + naive_time = naive_datetime.time() + utc_time = tz_aware_utc_now.time() + assert naive_time == utc_time diff --git a/api/tests/unit_tests/models/__init__.py b/api/tests/unit_tests/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index b79e95c7ed..69163d48bd 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -1,10 +1,15 @@ +import dataclasses import json from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from models.workflow import Workflow, WorkflowNodeExecutionModel +from core.variables.segments import IntegerSegment, Segment +from factories.variable_factory import build_segment +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable def test_environment_variables(): @@ -163,3 +168,147 @@ class TestWorkflowNodeExecution: original = {"a": 1, "b": ["2"]} node_exec.execution_metadata = json.dumps(original) assert node_exec.execution_metadata_dict == original + + +class TestIsSystemVariableEditable: + def test_is_system_variable(self): + cases = [ + ("query", True), + ("files", True), + ("dialogue_count", False), + ("conversation_id", False), + ("user_id", False), + ("app_id", False), + ("workflow_id", False), + ("workflow_run_id", False), + ] + for name, editable in cases: + assert editable == is_system_variable_editable(name) + + assert is_system_variable_editable("invalid_or_new_system_variable") == False + + +class TestWorkflowDraftVariableGetValue: + def test_get_value_by_case(self): + @dataclasses.dataclass + class TestCase: + name: str + value: Segment + + tenant_id = "test_tenant_id" + + test_file = File( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/example.jpg", + filename="example.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=100, + ) + cases: list[TestCase] = [ + TestCase( + name="number/int", + value=build_segment(1), + ), + TestCase( + name="number/float", + value=build_segment(1.0), + ), + TestCase( + name="string", + value=build_segment("a"), + ), + TestCase( + name="object", + value=build_segment({}), + ), + TestCase( + name="file", + value=build_segment(test_file), + ), + TestCase( + name="array[any]", + value=build_segment([1, "a"]), + ), + TestCase( + name="array[string]", + value=build_segment(["a", "b"]), + ), + TestCase( + name="array[number]/int", + value=build_segment([1, 2]), + ), + TestCase( + name="array[number]/float", + value=build_segment([1.0, 2.0]), + ), + TestCase( + name="array[number]/mixed", + value=build_segment([1, 2.0]), + ), + TestCase( + name="array[object]", + value=build_segment([{}, {"a": 1}]), + ), + TestCase( + name="none", + value=build_segment(None), + ), + ] + + for idx, c in 
enumerate(cases, 1): + fail_msg = f"test case {c.name} failed, index={idx}" + draft_var = WorkflowDraftVariable() + draft_var.set_value(c.value) + assert c.value == draft_var.get_value(), fail_msg + + def test_file_variable_preserves_all_fields(self): + """Test that File type variables preserve all fields during encoding/decoding.""" + tenant_id = "test_tenant_id" + + # Create a File with specific field values + test_file = File( + id="test_file_id", + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/test.jpg", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, # Specific size to test preservation + storage_key="test_storage_key", + ) + + # Create a FileSegment and WorkflowDraftVariable + file_segment = build_segment(test_file) + draft_var = WorkflowDraftVariable() + draft_var.set_value(file_segment) + + # Retrieve the value and verify all fields are preserved + retrieved_segment = draft_var.get_value() + retrieved_file = retrieved_segment.value + + # Verify all important fields are preserved + assert retrieved_file.id == test_file.id + assert retrieved_file.tenant_id == test_file.tenant_id + assert retrieved_file.type == test_file.type + assert retrieved_file.transfer_method == test_file.transfer_method + assert retrieved_file.remote_url == test_file.remote_url + assert retrieved_file.filename == test_file.filename + assert retrieved_file.extension == test_file.extension + assert retrieved_file.mime_type == test_file.mime_type + assert retrieved_file.size == test_file.size # This was the main issue being fixed + # Note: storage_key is not serialized in model_dump() so it won't be preserved + + # Verify the segments have the same type and the important fields match + assert file_segment.value_type == retrieved_segment.value_type + + def test_get_and_set_value(self): + draft_var = WorkflowDraftVariable() + int_var = IntegerSegment(value=1) + draft_var.set_value(int_var) + value = draft_var.get_value() + assert value == int_var diff --git a/api/tests/unit_tests/services/test_dataset_permission.py b/api/tests/unit_tests/services/test_dataset_permission.py index 066f541c1b..a67252e856 100644 --- a/api/tests/unit_tests/services/test_dataset_permission.py +++ b/api/tests/unit_tests/services/test_dataset_permission.py @@ -8,151 +8,298 @@ from services.dataset_service import DatasetService from services.errors.account import NoPermissionError +class DatasetPermissionTestDataFactory: + """Factory class for creating test data and mock objects for dataset permission tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "test-tenant-123", + created_by: str = "creator-456", + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + **kwargs, + ) -> Mock: + """Create a mock dataset with specified attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.created_by = created_by + dataset.permission = permission + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-789", + tenant_id: str = "test-tenant-123", + role: TenantAccountRole = TenantAccountRole.NORMAL, + **kwargs, + ) -> Mock: + """Create a mock user with specified attributes.""" + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + user.current_role = role + for key, value in kwargs.items(): 
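+            # Apply any extra attributes the individual test supplies via kwargs.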
+ setattr(user, key, value) + return user + + @staticmethod + def create_dataset_permission_mock( + dataset_id: str = "dataset-123", + account_id: str = "user-789", + **kwargs, + ) -> Mock: + """Create a mock dataset permission record.""" + permission = Mock(spec=DatasetPermission) + permission.dataset_id = dataset_id + permission.account_id = account_id + for key, value in kwargs.items(): + setattr(permission, key, value) + return permission + + class TestDatasetPermissionService: - """Test cases for dataset permission checking functionality""" + """ + Comprehensive unit tests for DatasetService.check_dataset_permission method. - def setup_method(self): - """Set up test fixtures""" - # Mock tenant and user - self.tenant_id = "test-tenant-123" - self.creator_id = "creator-456" - self.normal_user_id = "normal-789" - self.owner_user_id = "owner-999" + This test suite covers all permission scenarios including: + - Cross-tenant access restrictions + - Owner privilege checks + - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) + - Explicit permission checks for PARTIAL_TEAM + - Error conditions and logging + """ - # Mock dataset - self.dataset = Mock(spec=Dataset) - self.dataset.id = "dataset-123" - self.dataset.tenant_id = self.tenant_id - self.dataset.created_by = self.creator_id + @pytest.fixture + def mock_dataset_service_dependencies(self): + """Common mock setup for dataset service dependencies.""" + with patch("services.dataset_service.db.session") as mock_session: + yield { + "db_session": mock_session, + } - # Mock users - self.creator_user = Mock(spec=Account) - self.creator_user.id = self.creator_id - self.creator_user.current_tenant_id = self.tenant_id - self.creator_user.current_role = TenantAccountRole.EDITOR - - self.normal_user = Mock(spec=Account) - self.normal_user.id = self.normal_user_id - self.normal_user.current_tenant_id = self.tenant_id - self.normal_user.current_role = TenantAccountRole.NORMAL - - self.owner_user = Mock(spec=Account) - self.owner_user.id = self.owner_user_id - self.owner_user.current_tenant_id = self.tenant_id - self.owner_user.current_role = TenantAccountRole.OWNER - - def test_permission_check_different_tenant_should_fail(self): - """Test that users from different tenants cannot access dataset""" - self.normal_user.current_tenant_id = "different-tenant" - - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): - DatasetService.check_dataset_permission(self.dataset, self.normal_user) - - def test_owner_can_access_any_dataset(self): - """Test that tenant owners can access any dataset regardless of permission""" - self.dataset.permission = DatasetPermissionEnum.ONLY_ME + @pytest.fixture + def mock_logging_dependencies(self): + """Mock setup for logging tests.""" + with patch("services.dataset_service.logging") as mock_logging: + yield { + "logging": mock_logging, + } + def _assert_permission_check_passes(self, dataset: Mock, user: Mock): + """Helper method to verify that permission check passes without raising exceptions.""" # Should not raise any exception - DatasetService.check_dataset_permission(self.dataset, self.owner_user) + DatasetService.check_dataset_permission(dataset, user) - def test_only_me_permission_creator_can_access(self): - """Test ONLY_ME permission allows only creator to access""" - self.dataset.permission = DatasetPermissionEnum.ONLY_ME + def _assert_permission_check_fails( + self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this 
dataset." + ): + """Helper method to verify that permission check fails with expected error.""" + with pytest.raises(NoPermissionError, match=expected_message): + DatasetService.check_dataset_permission(dataset, user) - # Creator should be able to access - DatasetService.check_dataset_permission(self.dataset, self.creator_user) + def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str): + """Helper method to verify database query calls for permission checks.""" + mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id) - def test_only_me_permission_others_cannot_access(self): - """Test ONLY_ME permission denies access to non-creators""" - self.dataset.permission = DatasetPermissionEnum.ONLY_ME - - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): - DatasetService.check_dataset_permission(self.dataset, self.normal_user) - - def test_all_team_permission_allows_access(self): - """Test ALL_TEAM permission allows any team member to access""" - self.dataset.permission = DatasetPermissionEnum.ALL_TEAM - - # Should not raise any exception for team members - DatasetService.check_dataset_permission(self.dataset, self.normal_user) - DatasetService.check_dataset_permission(self.dataset, self.creator_user) - - @patch("services.dataset_service.db.session") - def test_partial_team_permission_creator_can_access(self, mock_session): - """Test PARTIAL_TEAM permission allows creator to access""" - self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM - - # Should not raise any exception for creator - DatasetService.check_dataset_permission(self.dataset, self.creator_user) - - # Should not query database for creator + def _assert_database_query_not_called(self, mock_session: Mock): + """Helper method to verify that database query was not called.""" mock_session.query.assert_not_called() - @patch("services.dataset_service.db.session") - def test_partial_team_permission_with_explicit_permission(self, mock_session): - """Test PARTIAL_TEAM permission allows users with explicit permission""" - self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM + # ==================== Cross-Tenant Access Tests ==================== + + def test_permission_check_different_tenant_should_fail(self): + """Test that users from different tenants cannot access dataset regardless of other permissions.""" + # Create dataset and user from different tenants + dataset = DatasetPermissionTestDataFactory.create_dataset_mock( + tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM + ) + user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR + ) + + # Should fail due to different tenant + self._assert_permission_check_fails(dataset, user) + + # ==================== Owner Privilege Tests ==================== + + def test_owner_can_access_any_dataset(self): + """Test that tenant owners can access any dataset regardless of permission level.""" + # Create dataset with restrictive permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) + + # Create owner user + owner_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="owner-999", role=TenantAccountRole.OWNER + ) + + # Owner should have access regardless of dataset permission + self._assert_permission_check_passes(dataset, owner_user) + + # ==================== ONLY_ME Permission Tests 
==================== + + def test_only_me_permission_creator_can_access(self): + """Test ONLY_ME permission allows only the dataset creator to access.""" + # Create dataset with ONLY_ME permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock( + created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME + ) + + # Create creator user + creator_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="creator-456", role=TenantAccountRole.EDITOR + ) + + # Creator should be able to access + self._assert_permission_check_passes(dataset, creator_user) + + def test_only_me_permission_others_cannot_access(self): + """Test ONLY_ME permission denies access to non-creators.""" + # Create dataset with ONLY_ME permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock( + created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME + ) + + # Create normal user (not the creator) + normal_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="normal-789", role=TenantAccountRole.NORMAL + ) + + # Non-creator should be denied access + self._assert_permission_check_fails(dataset, normal_user) + + # ==================== ALL_TEAM Permission Tests ==================== + + def test_all_team_permission_allows_access(self): + """Test ALL_TEAM permission allows any team member to access the dataset.""" + # Create dataset with ALL_TEAM permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM) + + # Create different types of team members + normal_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="normal-789", role=TenantAccountRole.NORMAL + ) + editor_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="editor-456", role=TenantAccountRole.EDITOR + ) + + # All team members should have access + self._assert_permission_check_passes(dataset, normal_user) + self._assert_permission_check_passes(dataset, editor_user) + + # ==================== PARTIAL_TEAM Permission Tests ==================== + + def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies): + """Test PARTIAL_TEAM permission allows creator to access without database query.""" + # Create dataset with PARTIAL_TEAM permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock( + created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM + ) + + # Create creator user + creator_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="creator-456", role=TenantAccountRole.EDITOR + ) + + # Creator should have access without database query + self._assert_permission_check_passes(dataset, creator_user) + self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"]) + + def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies): + """Test PARTIAL_TEAM permission allows users with explicit permission records.""" + # Create dataset with PARTIAL_TEAM permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) + + # Create normal user (not the creator) + normal_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="normal-789", role=TenantAccountRole.NORMAL + ) # Mock database query to return a permission record - mock_permission = Mock(spec=DatasetPermission) - mock_session.query().filter_by().first.return_value = mock_permission + mock_permission = 
DatasetPermissionTestDataFactory.create_dataset_permission_mock( + dataset_id=dataset.id, account_id=normal_user.id + ) + mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission - # Should not raise any exception - DatasetService.check_dataset_permission(self.dataset, self.normal_user) + # User with explicit permission should have access + self._assert_permission_check_passes(dataset, normal_user) + self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - # Verify database was queried correctly - mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id) + def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies): + """Test PARTIAL_TEAM permission denies users without explicit permission records.""" + # Create dataset with PARTIAL_TEAM permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - @patch("services.dataset_service.db.session") - def test_partial_team_permission_without_explicit_permission(self, mock_session): - """Test PARTIAL_TEAM permission denies users without explicit permission""" - self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM + # Create normal user (not the creator) + normal_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="normal-789", role=TenantAccountRole.NORMAL + ) # Mock database query to return None (no permission record) - mock_session.query().filter_by().first.return_value = None + mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): - DatasetService.check_dataset_permission(self.dataset, self.normal_user) + # User without explicit permission should be denied access + self._assert_permission_check_fails(dataset, normal_user) + self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - # Verify database was queried correctly - mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id) - - @patch("services.dataset_service.db.session") - def test_partial_team_permission_non_creator_without_permission_fails(self, mock_session): - """Test that non-creators without explicit permission are denied access""" - self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM + def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies): + """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets.""" + # Create dataset with PARTIAL_TEAM permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock( + created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM + ) # Create a different user (not the creator) - other_user = Mock(spec=Account) - other_user.id = "other-user-123" - other_user.current_tenant_id = self.tenant_id - other_user.current_role = TenantAccountRole.NORMAL + other_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="other-user-123", role=TenantAccountRole.NORMAL + ) # Mock database query to return None (no permission record) - mock_session.query().filter_by().first.return_value = None + mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - with 
pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): - DatasetService.check_dataset_permission(self.dataset, other_user) + # Non-creator without explicit permission should be denied access + self._assert_permission_check_fails(dataset, other_user) + self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id) + + # ==================== Enum Usage Tests ==================== def test_partial_team_permission_uses_correct_enum(self): - """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM""" - # This test ensures we're using the enum instead of string literals - self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM - - # Creator should always have access - DatasetService.check_dataset_permission(self.dataset, self.creator_user) - - @patch("services.dataset_service.logging") - @patch("services.dataset_service.db.session") - def test_permission_denied_logs_debug_message(self, mock_session, mock_logging): - """Test that permission denied events are logged""" - self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM - mock_session.query().filter_by().first.return_value = None - - with pytest.raises(NoPermissionError): - DatasetService.check_dataset_permission(self.dataset, self.normal_user) - - # Verify debug message was logged - mock_logging.debug.assert_called_with( - f"User {self.normal_user.id} does not have permission to access dataset {self.dataset.id}" + """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals.""" + # Create dataset with PARTIAL_TEAM permission using enum + dataset = DatasetPermissionTestDataFactory.create_dataset_mock( + created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM + ) + + # Create creator user + creator_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="creator-456", role=TenantAccountRole.EDITOR + ) + + # Creator should always have access regardless of permission level + self._assert_permission_check_passes(dataset, creator_user) + + # ==================== Logging Tests ==================== + + def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies): + """Test that permission denied events are properly logged for debugging purposes.""" + # Create dataset with PARTIAL_TEAM permission + dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) + + # Create normal user (not the creator) + normal_user = DatasetPermissionTestDataFactory.create_user_mock( + user_id="normal-789", role=TenantAccountRole.NORMAL + ) + + # Mock database query to return None (no permission record) + mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None + + # Attempt permission check (should fail) + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, normal_user) + + # Verify debug message was logged with correct user and dataset information + mock_logging_dependencies["logging"].debug.assert_called_with( + f"User {normal_user.id} does not have permission to access dataset {dataset.id}" ) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py deleted file mode 100644 index f22500cfe4..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ /dev/null @@ -1,1238 +0,0 @@ -import datetime -import unittest - -# Mock 
redis_client before importing dataset_service -from unittest.mock import Mock, call, patch - -import pytest - -from models.dataset import Dataset, Document -from services.dataset_service import DocumentService -from services.errors.document import DocumentIndexingError -from tests.unit_tests.conftest import redis_mock - - -class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): - """ - Comprehensive unit tests for DocumentService.batch_update_document_status method. - - This test suite covers all supported actions (enable, disable, archive, un_archive), - error conditions, edge cases, and validates proper interaction with Redis cache, - database operations, and async task triggers. - """ - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_enable_documents_success(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): - """ - Test successful enabling of disabled documents. - - Verifies that: - 1. Only disabled documents are processed (already enabled documents are skipped) - 2. Document attributes are updated correctly (enabled=True, metadata cleared) - 3. Database changes are committed for each document - 4. Redis cache keys are set to prevent concurrent indexing - 5. Async indexing task is triggered for each enabled document - 6. Timestamp fields are properly updated - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock disabled document - mock_disabled_doc_1 = Mock(spec=Document) - mock_disabled_doc_1.id = "doc-1" - mock_disabled_doc_1.name = "disabled_document.pdf" - mock_disabled_doc_1.enabled = False - mock_disabled_doc_1.archived = False - mock_disabled_doc_1.indexing_status = "completed" - mock_disabled_doc_1.completed_at = datetime.datetime.now() - - mock_disabled_doc_2 = Mock(spec=Document) - mock_disabled_doc_2.id = "doc-2" - mock_disabled_doc_2.name = "disabled_document.pdf" - mock_disabled_doc_2.enabled = False - mock_disabled_doc_2.archived = False - mock_disabled_doc_2.indexing_status = "completed" - mock_disabled_doc_2.completed_at = datetime.datetime.now() - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - # Mock document retrieval to return disabled documents - mock_get_doc.side_effect = [mock_disabled_doc_1, mock_disabled_doc_2] - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to enable documents - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1", "doc-2"], action="enable", user=mock_user - ) - - # Verify document attributes were updated correctly - for mock_doc in [mock_disabled_doc_1, mock_disabled_doc_2]: - # Check that document was enabled - assert mock_doc.enabled == True - # Check that disable metadata was cleared - assert mock_doc.disabled_at is None - assert mock_doc.disabled_by is None - # Check that update timestamp was set - assert mock_doc.updated_at == current_time.replace(tzinfo=None) - - # Verify Redis cache operations - expected_cache_calls = [call("document_doc-1_indexing"), call("document_doc-2_indexing")] - 
redis_mock.get.assert_has_calls(expected_cache_calls) - - # Verify Redis cache was set to prevent concurrent indexing (600 seconds) - expected_setex_calls = [call("document_doc-1_indexing", 600, 1), call("document_doc-2_indexing", 600, 1)] - redis_mock.setex.assert_has_calls(expected_setex_calls) - - # Verify async tasks were triggered for indexing - expected_task_calls = [call("doc-1"), call("doc-2")] - mock_add_task.delay.assert_has_calls(expected_task_calls) - - # Verify database add counts (one add for one document) - assert mock_db.add.call_count == 2 - # Verify database commits (one commit for the batch operation) - assert mock_db.commit.call_count == 1 - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.remove_document_from_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_disable_documents_success(self, mock_datetime, mock_get_doc, mock_remove_task, mock_db): - """ - Test successful disabling of enabled and completed documents. - - Verifies that: - 1. Only completed and enabled documents can be disabled - 2. Document attributes are updated correctly (enabled=False, disable metadata set) - 3. User ID is recorded in disabled_by field - 4. Database changes are committed for each document - 5. Redis cache keys are set to prevent concurrent indexing - 6. Async task is triggered to remove documents from index - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock enabled document - mock_enabled_doc_1 = Mock(spec=Document) - mock_enabled_doc_1.id = "doc-1" - mock_enabled_doc_1.name = "enabled_document.pdf" - mock_enabled_doc_1.enabled = True - mock_enabled_doc_1.archived = False - mock_enabled_doc_1.indexing_status = "completed" - mock_enabled_doc_1.completed_at = datetime.datetime.now() - - mock_enabled_doc_2 = Mock(spec=Document) - mock_enabled_doc_2.id = "doc-2" - mock_enabled_doc_2.name = "enabled_document.pdf" - mock_enabled_doc_2.enabled = True - mock_enabled_doc_2.archived = False - mock_enabled_doc_2.indexing_status = "completed" - mock_enabled_doc_2.completed_at = datetime.datetime.now() - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - # Mock document retrieval to return enabled, completed documents - mock_get_doc.side_effect = [mock_enabled_doc_1, mock_enabled_doc_2] - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to disable documents - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1", "doc-2"], action="disable", user=mock_user - ) - - # Verify document attributes were updated correctly - for mock_doc in [mock_enabled_doc_1, mock_enabled_doc_2]: - # Check that document was disabled - assert mock_doc.enabled == False - # Check that disable metadata was set correctly - assert mock_doc.disabled_at == current_time.replace(tzinfo=None) - assert mock_doc.disabled_by == mock_user.id - # Check that update timestamp was set - assert mock_doc.updated_at == current_time.replace(tzinfo=None) - - # Verify Redis cache operations for indexing prevention - expected_setex_calls = [call("document_doc-1_indexing", 600, 1), 
call("document_doc-2_indexing", 600, 1)] - redis_mock.setex.assert_has_calls(expected_setex_calls) - - # Verify async tasks were triggered to remove from index - expected_task_calls = [call("doc-1"), call("doc-2")] - mock_remove_task.delay.assert_has_calls(expected_task_calls) - - # Verify database add counts (one add for one document) - assert mock_db.add.call_count == 2 - # Verify database commits (totally 1 for any batch operation) - assert mock_db.commit.call_count == 1 - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.remove_document_from_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_archive_documents_success(self, mock_datetime, mock_get_doc, mock_remove_task, mock_db): - """ - Test successful archiving of unarchived documents. - - Verifies that: - 1. Only unarchived documents are processed (already archived are skipped) - 2. Document attributes are updated correctly (archived=True, archive metadata set) - 3. User ID is recorded in archived_by field - 4. If documents are enabled, they are removed from the index - 5. Redis cache keys are set only for enabled documents being archived - 6. Database changes are committed for each document - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create unarchived enabled document - unarchived_doc = Mock(spec=Document) - # Manually set attributes to ensure they can be modified - unarchived_doc.id = "doc-1" - unarchived_doc.name = "unarchived_document.pdf" - unarchived_doc.enabled = True - unarchived_doc.archived = False - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - mock_get_doc.return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to archive documents - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="archive", user=mock_user - ) - - # Verify document attributes were updated correctly - assert unarchived_doc.archived == True - assert unarchived_doc.archived_at == current_time.replace(tzinfo=None) - assert unarchived_doc.archived_by == mock_user.id - assert unarchived_doc.updated_at == current_time.replace(tzinfo=None) - - # Verify Redis cache was set (because document was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to remove from index (because enabled) - mock_remove_task.delay.assert_called_once_with("doc-1") - - # Verify database add - mock_db.add.assert_called_once() - # Verify database commit - mock_db.commit.assert_called_once() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_unarchive_documents_success(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): - """ - Test successful unarchiving of archived documents. - - Verifies that: - 1. Only archived documents are processed (already unarchived are skipped) - 2. 
Document attributes are updated correctly (archived=False, archive metadata cleared) - 3. If documents are enabled, they are added back to the index - 4. Redis cache keys are set only for enabled documents being unarchived - 5. Database changes are committed for each document - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock archived document - mock_archived_doc = Mock(spec=Document) - mock_archived_doc.id = "doc-3" - mock_archived_doc.name = "archived_document.pdf" - mock_archived_doc.enabled = True - mock_archived_doc.archived = True - mock_archived_doc.indexing_status = "completed" - mock_archived_doc.completed_at = datetime.datetime.now() - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - mock_get_doc.return_value = mock_archived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to unarchive documents - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-3"], action="un_archive", user=mock_user - ) - - # Verify document attributes were updated correctly - assert mock_archived_doc.archived == False - assert mock_archived_doc.archived_at is None - assert mock_archived_doc.archived_by is None - assert mock_archived_doc.updated_at == current_time.replace(tzinfo=None) - - # Verify Redis cache was set (because document is enabled) - redis_mock.setex.assert_called_once_with("document_doc-3_indexing", 600, 1) - - # Verify async task was triggered to add back to index (because enabled) - mock_add_task.delay.assert_called_once_with("doc-3") - - # Verify database add - mock_db.add.assert_called_once() - # Verify database commit - mock_db.commit.assert_called_once() - - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_get_doc): - """ - Test that DocumentIndexingError is raised when documents are currently being indexed. - - Verifies that: - 1. The method checks Redis cache for active indexing operations - 2. DocumentIndexingError is raised if any document is being indexed - 3. Error message includes the document name for user feedback - 4. 
No further processing occurs when indexing is detected - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock enabled document - mock_enabled_doc = Mock(spec=Document) - mock_enabled_doc.id = "doc-1" - mock_enabled_doc.name = "enabled_document.pdf" - mock_enabled_doc.enabled = True - mock_enabled_doc.archived = False - mock_enabled_doc.indexing_status = "completed" - mock_enabled_doc.completed_at = datetime.datetime.now() - - # Set up mock to indicate document is being indexed - mock_get_doc.return_value = mock_enabled_doc - - # Reset module-level Redis mock, set to indexing status - redis_mock.reset_mock() - redis_mock.get.return_value = "indexing" - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user - ) - - # Verify error message contains document name - assert "enabled_document.pdf" in str(exc_info.value) - assert "is being indexed" in str(exc_info.value) - - # Verify Redis cache was checked - redis_mock.get.assert_called_once_with("document_doc-1_indexing") - - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_disable_non_completed_document_error(self, mock_get_doc): - """ - Test that DocumentIndexingError is raised when trying to disable non-completed documents. - - Verifies that: - 1. Only completed documents can be disabled - 2. DocumentIndexingError is raised for non-completed documents - 3. Error message indicates the document is not completed - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create a document that's not completed - non_completed_doc = Mock(spec=Document) - # Manually set attributes to ensure they can be modified - non_completed_doc.id = "doc-1" - non_completed_doc.name = "indexing_document.pdf" - non_completed_doc.enabled = True - non_completed_doc.indexing_status = "indexing" # Not completed - non_completed_doc.completed_at = None # Not completed - - mock_get_doc.return_value = non_completed_doc - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="disable", user=mock_user - ) - - # Verify error message indicates document is not completed - assert "is not completed" in str(exc_info.value) - - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_empty_document_list(self, mock_get_doc): - """ - Test batch operations with an empty document ID list. - - Verifies that: - 1. The method handles empty input gracefully - 2. No document operations are performed with empty input - 3. No errors are raised with empty input - 4. 
Method returns early without processing - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Call method with empty document list - result = DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=[], action="enable", user=mock_user - ) - - # Verify no document lookups were performed - mock_get_doc.assert_not_called() - - # Verify method returns None (early return) - assert result is None - - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_document_not_found_skipped(self, mock_get_doc): - """ - Test behavior when some documents don't exist in the database. - - Verifies that: - 1. Non-existent documents are gracefully skipped - 2. Processing continues for existing documents - 3. No errors are raised for missing document IDs - 4. Method completes successfully despite missing documents - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Mock document service to return None (document not found) - mock_get_doc.return_value = None - - # Call method with non-existent document ID - # This should not raise an error, just skip the missing document - try: - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["non-existent-doc"], action="enable", user=mock_user - ) - except Exception as e: - pytest.fail(f"Method should not raise exception for missing documents: {e}") - - # Verify document lookup was attempted - mock_get_doc.assert_called_once_with(mock_dataset.id, "non-existent-doc") - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_enable_already_enabled_document_skipped(self, mock_get_doc, mock_db): - """ - Test enabling documents that are already enabled. - - Verifies that: - 1. Already enabled documents are skipped (no unnecessary operations) - 2. No database commits occur for already enabled documents - 3. No Redis cache operations occur for skipped documents - 4. No async tasks are triggered for skipped documents - 5. 
Method completes successfully - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock enabled document - mock_enabled_doc = Mock(spec=Document) - mock_enabled_doc.id = "doc-1" - mock_enabled_doc.name = "enabled_document.pdf" - mock_enabled_doc.enabled = True - mock_enabled_doc.archived = False - mock_enabled_doc.indexing_status = "completed" - mock_enabled_doc.completed_at = datetime.datetime.now() - - # Mock document that is already enabled - mock_get_doc.return_value = mock_enabled_doc # Already enabled - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to enable already enabled document - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user - ) - - # Verify no database operations occurred (document was skipped) - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_archive_already_archived_document_skipped(self, mock_get_doc, mock_db): - """ - Test archiving documents that are already archived. - - Verifies that: - 1. Already archived documents are skipped (no unnecessary operations) - 2. No database commits occur for already archived documents - 3. No Redis cache operations occur for skipped documents - 4. No async tasks are triggered for skipped documents - 5. Method completes successfully - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock archived document - mock_archived_doc = Mock(spec=Document) - mock_archived_doc.id = "doc-3" - mock_archived_doc.name = "archived_document.pdf" - mock_archived_doc.enabled = True - mock_archived_doc.archived = True - mock_archived_doc.indexing_status = "completed" - mock_archived_doc.completed_at = datetime.datetime.now() - - # Mock document that is already archived - mock_get_doc.return_value = mock_archived_doc # Already archived - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to archive already archived document - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-3"], action="archive", user=mock_user - ) - - # Verify no database operations occurred (document was skipped) - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.remove_document_from_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_mixed_document_states_and_actions( - self, mock_datetime, mock_get_doc, mock_remove_task, mock_add_task, mock_db - ): - """ - Test batch operations on documents with mixed states and various scenarios. - - Verifies that: - 1. Each document is processed according to its current state - 2. 
Some documents may be skipped while others are processed - 3. Different async tasks are triggered based on document states - 4. Method handles mixed scenarios gracefully - 5. Database commits occur only for documents that were actually modified - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock documents with different states - mock_disabled_doc = Mock(spec=Document) - mock_disabled_doc.id = "doc-1" - mock_disabled_doc.name = "disabled_document.pdf" - mock_disabled_doc.enabled = False - mock_disabled_doc.archived = False - mock_disabled_doc.indexing_status = "completed" - mock_disabled_doc.completed_at = datetime.datetime.now() - - mock_enabled_doc = Mock(spec=Document) - mock_enabled_doc.id = "doc-2" - mock_enabled_doc.name = "enabled_document.pdf" - mock_enabled_doc.enabled = True - mock_enabled_doc.archived = False - mock_enabled_doc.indexing_status = "completed" - mock_enabled_doc.completed_at = datetime.datetime.now() - - mock_archived_doc = Mock(spec=Document) - mock_archived_doc.id = "doc-3" - mock_archived_doc.name = "archived_document.pdf" - mock_archived_doc.enabled = True - mock_archived_doc.archived = True - mock_archived_doc.indexing_status = "completed" - mock_archived_doc.completed_at = datetime.datetime.now() - - # Set up mixed document states - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - # Mix of different document states - documents = [ - mock_disabled_doc, # Will be enabled - mock_enabled_doc, # Already enabled, will be skipped - mock_archived_doc, # Archived but enabled, will be skipped for enable action - ] - - mock_get_doc.side_effect = documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform enable operation on mixed state documents - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=mock_user - ) - - # Verify only the disabled document was processed - # (enabled and archived documents should be skipped for enable action) - - # Only one add should occur (for the disabled document that was enabled) - mock_db.add.assert_called_once() - # Only one commit should occur - mock_db.commit.assert_called_once() - - # Only one Redis setex should occur (for the document that was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Only one async task should be triggered (for the document that was enabled) - mock_add_task.delay.assert_called_once_with("doc-1") - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.remove_document_from_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_archive_disabled_document_no_index_removal( - self, mock_datetime, mock_get_doc, mock_remove_task, mock_db - ): - """ - Test archiving disabled documents (should not trigger index removal). - - Verifies that: - 1. Disabled documents can be archived - 2. Archive metadata is set correctly - 3. No index removal task is triggered (because document is disabled) - 4. No Redis cache key is set (because document is disabled) - 5. 
Database commit still occurs - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Set up disabled, unarchived document - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - disabled_unarchived_doc = Mock(spec=Document) - # Manually set attributes to ensure they can be modified - disabled_unarchived_doc.id = "doc-1" - disabled_unarchived_doc.name = "disabled_document.pdf" - disabled_unarchived_doc.enabled = False # Disabled - disabled_unarchived_doc.archived = False # Not archived - - mock_get_doc.return_value = disabled_unarchived_doc - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Archive the disabled document - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="archive", user=mock_user - ) - - # Verify document was archived - assert disabled_unarchived_doc.archived == True - assert disabled_unarchived_doc.archived_at == current_time.replace(tzinfo=None) - assert disabled_unarchived_doc.archived_by == mock_user.id - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index removal task was triggered (document is disabled) - mock_remove_task.delay.assert_not_called() - - # Verify database add still occurred - mock_db.add.assert_called_once() - # Verify database commit still occurred - mock_db.commit.assert_called_once() - - @patch("services.dataset_service.DocumentService.get_document") - def test_batch_update_invalid_action_error(self, mock_get_doc): - """ - Test that ValueError is raised when an invalid action is provided. - - Verifies that: - 1. Invalid actions are rejected with ValueError - 2. Error message includes the invalid action name - 3. No document processing occurs with invalid actions - 4. Method fails fast on invalid input - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock document - mock_doc = Mock(spec=Document) - mock_doc.id = "doc-1" - mock_doc.name = "test_document.pdf" - mock_doc.enabled = True - mock_doc.archived = False - - mock_get_doc.return_value = mock_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Test with invalid action - invalid_action = "invalid_action" - with pytest.raises(ValueError) as exc_info: - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action=invalid_action, user=mock_user - ) - - # Verify error message contains the invalid action - assert invalid_action in str(exc_info.value) - assert "Invalid action" in str(exc_info.value) - - # Verify no Redis operations occurred - redis_mock.setex.assert_not_called() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_disable_already_disabled_document_skipped( - self, mock_datetime, mock_get_doc, mock_add_task, mock_db - ): - """ - Test disabling documents that are already disabled. - - Verifies that: - 1. 
Already disabled documents are skipped (no unnecessary operations) - 2. No database commits occur for already disabled documents - 3. No Redis cache operations occur for skipped documents - 4. No async tasks are triggered for skipped documents - 5. Method completes successfully - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock disabled document - mock_disabled_doc = Mock(spec=Document) - mock_disabled_doc.id = "doc-1" - mock_disabled_doc.name = "disabled_document.pdf" - mock_disabled_doc.enabled = False # Already disabled - mock_disabled_doc.archived = False - mock_disabled_doc.indexing_status = "completed" - mock_disabled_doc.completed_at = datetime.datetime.now() - - # Mock document that is already disabled - mock_get_doc.return_value = mock_disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to disable already disabled document - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="disable", user=mock_user - ) - - # Verify no database operations occurred (document was skipped) - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_add_task.delay.assert_not_called() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_unarchive_already_unarchived_document_skipped( - self, mock_datetime, mock_get_doc, mock_add_task, mock_db - ): - """ - Test unarchiving documents that are already unarchived. - - Verifies that: - 1. Already unarchived documents are skipped (no unnecessary operations) - 2. No database commits occur for already unarchived documents - 3. No Redis cache operations occur for skipped documents - 4. No async tasks are triggered for skipped documents - 5. 
Method completes successfully - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock unarchived document - mock_unarchived_doc = Mock(spec=Document) - mock_unarchived_doc.id = "doc-1" - mock_unarchived_doc.name = "unarchived_document.pdf" - mock_unarchived_doc.enabled = True - mock_unarchived_doc.archived = False # Already unarchived - mock_unarchived_doc.indexing_status = "completed" - mock_unarchived_doc.completed_at = datetime.datetime.now() - - # Mock document that is already unarchived - mock_get_doc.return_value = mock_unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to unarchive already unarchived document - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="un_archive", user=mock_user - ) - - # Verify no database operations occurred (document was skipped) - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_add_task.delay.assert_not_called() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_unarchive_disabled_document_no_index_addition( - self, mock_datetime, mock_get_doc, mock_add_task, mock_db - ): - """ - Test unarchiving disabled documents (should not trigger index addition). - - Verifies that: - 1. Disabled documents can be unarchived - 2. Unarchive metadata is cleared correctly - 3. No index addition task is triggered (because document is disabled) - 4. No Redis cache key is set (because document is disabled) - 5. 
Database commit still occurs - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock archived but disabled document - mock_archived_disabled_doc = Mock(spec=Document) - mock_archived_disabled_doc.id = "doc-1" - mock_archived_disabled_doc.name = "archived_disabled_document.pdf" - mock_archived_disabled_doc.enabled = False # Disabled - mock_archived_disabled_doc.archived = True # Archived - mock_archived_disabled_doc.indexing_status = "completed" - mock_archived_disabled_doc.completed_at = datetime.datetime.now() - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - mock_get_doc.return_value = mock_archived_disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Unarchive the disabled document - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="un_archive", user=mock_user - ) - - # Verify document was unarchived - assert mock_archived_disabled_doc.archived == False - assert mock_archived_disabled_doc.archived_at is None - assert mock_archived_disabled_doc.archived_by is None - assert mock_archived_disabled_doc.updated_at == current_time.replace(tzinfo=None) - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index addition task was triggered (document is disabled) - mock_add_task.delay.assert_not_called() - - # Verify database add still occurred - mock_db.add.assert_called_once() - # Verify database commit still occurred - mock_db.commit.assert_called_once() - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_async_task_error_handling(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): - """ - Test handling of async task errors during batch operations. - - Verifies that: - 1. Async task errors are properly handled - 2. Database operations complete successfully - 3. Redis cache operations complete successfully - 4. 
Method continues processing despite async task errors - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create mock disabled document - mock_disabled_doc = Mock(spec=Document) - mock_disabled_doc.id = "doc-1" - mock_disabled_doc.name = "disabled_document.pdf" - mock_disabled_doc.enabled = False - mock_disabled_doc.archived = False - mock_disabled_doc.indexing_status = "completed" - mock_disabled_doc.completed_at = datetime.datetime.now() - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - mock_get_doc.return_value = mock_disabled_doc - - # Mock async task to raise an exception - mock_add_task.delay.side_effect = Exception("Celery task error") - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Verify that async task error is propagated - with pytest.raises(Exception) as exc_info: - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user - ) - - # Verify error message - assert "Celery task error" in str(exc_info.value) - - # Verify database operations completed successfully - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # Verify Redis cache was set successfully - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify document was updated - assert mock_disabled_doc.enabled == True - assert mock_disabled_doc.disabled_at is None - assert mock_disabled_doc.disabled_by is None - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_large_document_list_performance(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): - """ - Test batch operations with a large number of documents. - - Verifies that: - 1. Method can handle large document lists efficiently - 2. All documents are processed correctly - 3. Database commits occur for each document - 4. Redis cache operations occur for each document - 5. Async tasks are triggered for each document - 6. 
Performance remains consistent with large inputs - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create large list of document IDs - document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents - - # Create mock documents - mock_documents = [] - for i in range(1, 101): - mock_doc = Mock(spec=Document) - mock_doc.id = f"doc-{i}" - mock_doc.name = f"document_{i}.pdf" - mock_doc.enabled = False # All disabled, will be enabled - mock_doc.archived = False - mock_doc.indexing_status = "completed" - mock_doc.completed_at = datetime.datetime.now() - mock_documents.append(mock_doc) - - # Set up mock return values - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - mock_get_doc.side_effect = mock_documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform batch enable operation - DocumentService.batch_update_document_status( - dataset=mock_dataset, document_ids=document_ids, action="enable", user=mock_user - ) - - # Verify all documents were processed - assert mock_get_doc.call_count == 100 - - # Verify all documents were updated - for mock_doc in mock_documents: - assert mock_doc.enabled == True - assert mock_doc.disabled_at is None - assert mock_doc.disabled_by is None - assert mock_doc.updated_at == current_time.replace(tzinfo=None) - - # Verify database commits, one add for one document - assert mock_db.add.call_count == 100 - # Verify database commits, one commit for the batch operation - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for each document - assert redis_mock.setex.call_count == 100 - - # Verify async tasks were triggered for each document - assert mock_add_task.delay.call_count == 100 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task calls - expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] - mock_add_task.delay.assert_has_calls(expected_task_calls) - - @patch("extensions.ext_database.db.session") - @patch("services.dataset_service.add_document_to_index_task") - @patch("services.dataset_service.DocumentService.get_document") - @patch("services.dataset_service.datetime") - def test_batch_update_mixed_document_states_complex_scenario( - self, mock_datetime, mock_get_doc, mock_add_task, mock_db - ): - """ - Test complex batch operations with documents in various states. - - Verifies that: - 1. Each document is processed according to its current state - 2. Some documents are skipped while others are processed - 3. Different actions trigger different async tasks - 4. Database commits occur only for modified documents - 5. Redis cache operations occur only for relevant documents - 6. 
Method handles complex mixed scenarios correctly - """ - # Create mock dataset - mock_dataset = Mock(spec=Dataset) - mock_dataset.id = "dataset-123" - mock_dataset.tenant_id = "tenant-456" - - # Create mock user - mock_user = Mock() - mock_user.id = "user-789" - - # Create documents in various states - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC - - # Document 1: Disabled, will be enabled - doc1 = Mock(spec=Document) - doc1.id = "doc-1" - doc1.name = "disabled_doc.pdf" - doc1.enabled = False - doc1.archived = False - doc1.indexing_status = "completed" - doc1.completed_at = datetime.datetime.now() - - # Document 2: Already enabled, will be skipped - doc2 = Mock(spec=Document) - doc2.id = "doc-2" - doc2.name = "enabled_doc.pdf" - doc2.enabled = True - doc2.archived = False - doc2.indexing_status = "completed" - doc2.completed_at = datetime.datetime.now() - - # Document 3: Enabled and completed, will be disabled - doc3 = Mock(spec=Document) - doc3.id = "doc-3" - doc3.name = "enabled_completed_doc.pdf" - doc3.enabled = True - doc3.archived = False - doc3.indexing_status = "completed" - doc3.completed_at = datetime.datetime.now() - - # Document 4: Unarchived, will be archived - doc4 = Mock(spec=Document) - doc4.id = "doc-4" - doc4.name = "unarchived_doc.pdf" - doc4.enabled = True - doc4.archived = False - doc4.indexing_status = "completed" - doc4.completed_at = datetime.datetime.now() - - # Document 5: Archived, will be unarchived - doc5 = Mock(spec=Document) - doc5.id = "doc-5" - doc5.name = "archived_doc.pdf" - doc5.enabled = True - doc5.archived = True - doc5.indexing_status = "completed" - doc5.completed_at = datetime.datetime.now() - - # Document 6: Non-existent, will be skipped - doc6 = None - - mock_get_doc.side_effect = [doc1, doc2, doc3, doc4, doc5, doc6] - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform mixed batch operations - DocumentService.batch_update_document_status( - dataset=mock_dataset, - document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"], - action="enable", # This will only affect doc1 and doc3 (doc3 will be enabled then disabled) - user=mock_user, - ) - - # Verify document 1 was enabled - assert doc1.enabled == True - assert doc1.disabled_at is None - assert doc1.disabled_by is None - - # Verify document 2 was skipped (already enabled) - assert doc2.enabled == True # No change - - # Verify document 3 was skipped (already enabled) - assert doc3.enabled == True - - # Verify document 4 was skipped (not affected by enable action) - assert doc4.enabled == True # No change - - # Verify document 5 was skipped (not affected by enable action) - assert doc5.enabled == True # No change - - # Verify database commits occurred for processed documents - # Only doc1 should be added (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist) - assert mock_db.add.call_count == 1 - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for processed documents - # Only doc1 should have Redis operations - assert redis_mock.setex.call_count == 1 - - # Verify async tasks were triggered for processed documents - # Only doc1 should trigger tasks - assert mock_add_task.delay.call_count == 1 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call("document_doc-1_indexing", 600, 1)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task 
calls
-        expected_task_calls = [call("doc-1")]
-        mock_add_task.delay.assert_has_calls(expected_task_calls)
diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
new file mode 100644
index 0000000000..dc09aca5b2
--- /dev/null
+++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py
@@ -0,0 +1,804 @@
+import datetime
+from typing import Optional
+
+# Mock redis_client before importing dataset_service
+from unittest.mock import Mock, call, patch
+
+import pytest
+
+from models.dataset import Dataset, Document
+from services.dataset_service import DocumentService
+from services.errors.document import DocumentIndexingError
+from tests.unit_tests.conftest import redis_mock
+
+
+class DocumentBatchUpdateTestDataFactory:
+    """Factory class for creating test data and mock objects for document batch update tests."""
+
+    @staticmethod
+    def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock:
+        """Create a mock dataset with specified attributes."""
+        dataset = Mock(spec=Dataset)
+        dataset.id = dataset_id
+        dataset.tenant_id = tenant_id
+        return dataset
+
+    @staticmethod
+    def create_user_mock(user_id: str = "user-789") -> Mock:
+        """Create a mock user."""
+        user = Mock()
+        user.id = user_id
+        return user
+
+    @staticmethod
+    def create_document_mock(
+        document_id: str = "doc-1",
+        name: str = "test_document.pdf",
+        enabled: bool = True,
+        archived: bool = False,
+        indexing_status: str = "completed",
+        completed_at: Optional[datetime.datetime] = None,
+        **kwargs,
+    ) -> Mock:
+        """Create a mock document with specified attributes."""
+        document = Mock(spec=Document)
+        document.id = document_id
+        document.name = name
+        document.enabled = enabled
+        document.archived = archived
+        document.indexing_status = indexing_status
+        document.completed_at = completed_at or datetime.datetime.now()
+
+        # Set default values for optional fields
+        document.disabled_at = None
+        document.disabled_by = None
+        document.archived_at = None
+        document.archived_by = None
+        document.updated_at = None
+
+        for key, value in kwargs.items():
+            setattr(document, key, value)
+        return document
+
+    @staticmethod
+    def create_multiple_documents(
+        document_ids: list[str], enabled: bool = True, archived: bool = False, indexing_status: str = "completed"
+    ) -> list[Mock]:
+        """Create multiple mock documents with specified attributes."""
+        documents = []
+        for doc_id in document_ids:
+            doc = DocumentBatchUpdateTestDataFactory.create_document_mock(
+                document_id=doc_id,
+                name=f"document_{doc_id}.pdf",
+                enabled=enabled,
+                archived=archived,
+                indexing_status=indexing_status,
+            )
+            documents.append(doc)
+        return documents
+
+
+class TestDatasetServiceBatchUpdateDocumentStatus:
+    """
+    Comprehensive unit tests for DocumentService.batch_update_document_status method.
+
+    This test suite covers all supported actions (enable, disable, archive, un_archive),
+    error conditions, edge cases, and validates proper interaction with Redis cache,
+    database operations, and async task triggers.
+ """ + + @pytest.fixture + def mock_document_service_dependencies(self): + """Common mock setup for document service dependencies.""" + with ( + patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, + patch("extensions.ext_database.db.session") as mock_db, + patch("services.dataset_service.datetime") as mock_datetime, + ): + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + yield { + "get_document": mock_get_doc, + "db_session": mock_db, + "datetime": mock_datetime, + "current_time": current_time, + } + + @pytest.fixture + def mock_async_task_dependencies(self): + """Mock setup for async task dependencies.""" + with ( + patch("services.dataset_service.add_document_to_index_task") as mock_add_task, + patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task, + ): + yield {"add_task": mock_add_task, "remove_task": mock_remove_task} + + def _assert_document_enabled(self, document: Mock, user_id: str, current_time: datetime.datetime): + """Helper method to verify document was enabled correctly.""" + assert document.enabled == True + assert document.disabled_at is None + assert document.disabled_by is None + assert document.updated_at == current_time.replace(tzinfo=None) + + def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime): + """Helper method to verify document was disabled correctly.""" + assert document.enabled == False + assert document.disabled_at == current_time.replace(tzinfo=None) + assert document.disabled_by == user_id + assert document.updated_at == current_time.replace(tzinfo=None) + + def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime): + """Helper method to verify document was archived correctly.""" + assert document.archived == True + assert document.archived_at == current_time.replace(tzinfo=None) + assert document.archived_by == user_id + assert document.updated_at == current_time.replace(tzinfo=None) + + def _assert_document_unarchived(self, document: Mock): + """Helper method to verify document was unarchived correctly.""" + assert document.archived == False + assert document.archived_at is None + assert document.archived_by is None + + def _assert_redis_cache_operations(self, document_ids: list[str], action: str = "setex"): + """Helper method to verify Redis cache operations.""" + if action == "setex": + expected_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + redis_mock.setex.assert_has_calls(expected_calls) + elif action == "get": + expected_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + redis_mock.get.assert_has_calls(expected_calls) + + def _assert_async_task_calls(self, mock_task, document_ids: list[str], task_type: str): + """Helper method to verify async task calls.""" + expected_calls = [call(doc_id) for doc_id in document_ids] + if task_type in {"add", "remove"}: + mock_task.delay.assert_has_calls(expected_calls) + + # ==================== Enable Document Tests ==================== + + def test_batch_update_enable_documents_success( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test successful enabling of disabled documents.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create disabled documents + disabled_docs = 
DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False) + mock_document_service_dependencies["get_document"].side_effect = disabled_docs + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to enable documents + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1", "doc-2"], action="enable", user=user + ) + + # Verify document attributes were updated correctly + for doc in disabled_docs: + self._assert_document_enabled(doc, user.id, mock_document_service_dependencies["current_time"]) + + # Verify Redis cache operations + self._assert_redis_cache_operations(["doc-1", "doc-2"], "get") + self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") + + # Verify async tasks were triggered for indexing + self._assert_async_task_calls(mock_async_task_dependencies["add_task"], ["doc-1", "doc-2"], "add") + + # Verify database operations + mock_db = mock_document_service_dependencies["db_session"] + assert mock_db.add.call_count == 2 + assert mock_db.commit.call_count == 1 + + def test_batch_update_enable_already_enabled_document_skipped(self, mock_document_service_dependencies): + """Test enabling documents that are already enabled.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create already enabled document + enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) + mock_document_service_dependencies["get_document"].return_value = enabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to enable already enabled document + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="enable", user=user + ) + + # Verify no database operations occurred (document was skipped) + mock_db = mock_document_service_dependencies["db_session"] + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # ==================== Disable Document Tests ==================== + + def test_batch_update_disable_documents_success( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test successful disabling of enabled and completed documents.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create enabled documents + enabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=True) + mock_document_service_dependencies["get_document"].side_effect = enabled_docs + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to disable documents + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1", "doc-2"], action="disable", user=user + ) + + # Verify document attributes were updated correctly + for doc in enabled_docs: + self._assert_document_disabled(doc, user.id, mock_document_service_dependencies["current_time"]) + + # Verify Redis cache operations for indexing prevention + self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") + + # Verify async tasks were triggered to remove from index + self._assert_async_task_calls(mock_async_task_dependencies["remove_task"], ["doc-1", "doc-2"], "remove") + + 
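The assertions in this enable-path test pin down a specific control flow. The sketch below is inferred only from those assertions, not taken from the actual Dify implementation: the dependency names mirror the patch targets used in this file (the get_document lookup, the db session, the redis client, and add_document_to_index_task), while the function name enable_documents_sketch, its signature, and the internal ordering are illustrative assumptions.

import datetime

from services.errors.document import DocumentIndexingError


def enable_documents_sketch(
    dataset, document_ids, user, *, get_document, db, redis_client, add_document_to_index_task
):
    """Hedged reconstruction of the 'enable' flow implied by the assertions in this file."""
    if not document_ids:
        return None  # empty input: return early, no lookups at all
    touched_ids = []
    for document_id in document_ids:
        document = get_document(dataset.id, document_id)
        if document is None:
            continue  # missing documents are skipped without raising
        if redis_client.get(f"document_{document.id}_indexing"):
            raise DocumentIndexingError(f"Document: {document.name} is being indexed, please try again later")
        if document.enabled:
            continue  # already enabled: no add, no commit, no cache key, no task
        now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        document.enabled = True
        document.disabled_at = None  # `user` is only recorded on disable/archive in these tests
        document.disabled_by = None
        document.updated_at = now
        db.add(document)  # one add per modified document
        touched_ids.append(document.id)
    if touched_ids:
        db.commit()  # a single commit for the whole batch
    for document_id in touched_ids:
        redis_client.setex(f"document_{document_id}_indexing", 600, 1)  # 600-second guard key
        add_document_to_index_task.delay(document_id)  # re-index asynchronously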
# Verify database operations + mock_db = mock_document_service_dependencies["db_session"] + assert mock_db.add.call_count == 2 + assert mock_db.commit.call_count == 1 + + def test_batch_update_disable_already_disabled_document_skipped( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test disabling documents that are already disabled.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create already disabled document + disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) + mock_document_service_dependencies["get_document"].return_value = disabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to disable already disabled document + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="disable", user=user + ) + + # Verify no database operations occurred (document was skipped) + mock_db = mock_document_service_dependencies["db_session"] + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # Verify no async tasks were triggered (document was skipped) + mock_async_task_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_disable_non_completed_document_error(self, mock_document_service_dependencies): + """Test that DocumentIndexingError is raised when trying to disable non-completed documents.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create a document that's not completed + non_completed_doc = DocumentBatchUpdateTestDataFactory.create_document_mock( + enabled=True, + indexing_status="indexing", # Not completed + completed_at=None, # Not completed + ) + mock_document_service_dependencies["get_document"].return_value = non_completed_doc + + # Verify that DocumentIndexingError is raised + with pytest.raises(DocumentIndexingError) as exc_info: + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="disable", user=user + ) + + # Verify error message indicates document is not completed + assert "is not completed" in str(exc_info.value) + + # ==================== Archive Document Tests ==================== + + def test_batch_update_archive_documents_success( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test successful archiving of unarchived documents.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create unarchived enabled document + unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) + mock_document_service_dependencies["get_document"].return_value = unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to archive documents + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="archive", user=user + ) + + # Verify document attributes were updated correctly + self._assert_document_archived(unarchived_doc, user.id, mock_document_service_dependencies["current_time"]) + + # Verify Redis cache was set (because document was enabled) + 
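The `document_<id>_indexing` guard key is the protocol that the module-level redis_mock from conftest stands in for, so it can be verified with plain unittest.mock assertions. A minimal, self-contained illustration of the same check-then-set pattern follows; fake_redis and the literal key are hypothetical stand-ins, while the 600-second TTL matches the values asserted here.

from unittest.mock import MagicMock, call

fake_redis = MagicMock()
fake_redis.get.return_value = None  # simulate "no indexing currently in progress"

# the code under test does roughly this around every state change:
if not fake_redis.get("document_doc-1_indexing"):
    fake_redis.setex("document_doc-1_indexing", 600, 1)  # guard key with a 600-second TTL

fake_redis.get.assert_called_once_with("document_doc-1_indexing")
fake_redis.setex.assert_has_calls([call("document_doc-1_indexing", 600, 1)])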
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify async task was triggered to remove from index (because enabled) + mock_async_task_dependencies["remove_task"].delay.assert_called_once_with("doc-1") + + # Verify database operations + mock_db = mock_document_service_dependencies["db_session"] + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + def test_batch_update_archive_already_archived_document_skipped(self, mock_document_service_dependencies): + """Test archiving documents that are already archived.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create already archived document + archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) + mock_document_service_dependencies["get_document"].return_value = archived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to archive already archived document + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-3"], action="archive", user=user + ) + + # Verify no database operations occurred (document was skipped) + mock_db = mock_document_service_dependencies["db_session"] + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + def test_batch_update_archive_disabled_document_no_index_removal( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test archiving disabled documents (should not trigger index removal).""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Set up disabled, unarchived document + disabled_unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=False) + mock_document_service_dependencies["get_document"].return_value = disabled_unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Archive the disabled document + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="archive", user=user + ) + + # Verify document was archived + self._assert_document_archived( + disabled_unarchived_doc, user.id, mock_document_service_dependencies["current_time"] + ) + + # Verify no Redis cache was set (document is disabled) + redis_mock.setex.assert_not_called() + + # Verify no index removal task was triggered (document is disabled) + mock_async_task_dependencies["remove_task"].delay.assert_not_called() + + # Verify database operations still occurred + mock_db = mock_document_service_dependencies["db_session"] + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + # ==================== Unarchive Document Tests ==================== + + def test_batch_update_unarchive_documents_success( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test successful unarchiving of archived documents.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create mock archived document + archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) + mock_document_service_dependencies["get_document"].return_value = archived_doc + + # 
Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to unarchive documents + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user + ) + + # Verify document attributes were updated correctly + self._assert_document_unarchived(archived_doc) + assert archived_doc.updated_at == mock_document_service_dependencies["current_time"].replace(tzinfo=None) + + # Verify Redis cache was set (because document is enabled) + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify async task was triggered to add back to index (because enabled) + mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") + + # Verify database operations + mock_db = mock_document_service_dependencies["db_session"] + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + def test_batch_update_unarchive_already_unarchived_document_skipped( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test unarchiving documents that are already unarchived.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create already unarchived document + unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) + mock_document_service_dependencies["get_document"].return_value = unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to unarchive already unarchived document + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user + ) + + # Verify no database operations occurred (document was skipped) + mock_db = mock_document_service_dependencies["db_session"] + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # Verify no async tasks were triggered (document was skipped) + mock_async_task_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_unarchive_disabled_document_no_index_addition( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test unarchiving disabled documents (should not trigger index addition).""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create mock archived but disabled document + archived_disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=True) + mock_document_service_dependencies["get_document"].return_value = archived_disabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Unarchive the disabled document + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user + ) + + # Verify document was unarchived + self._assert_document_unarchived(archived_disabled_doc) + assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"].replace( + tzinfo=None + ) + + # Verify no Redis cache was set (document is disabled) + redis_mock.setex.assert_not_called() + + # Verify no index addition task was triggered (document is disabled) + mock_async_task_dependencies["add_task"].delay.assert_not_called() + + # 
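Taken together, the archive and un_archive tests imply a symmetric per-document update: archiving stamps the archive metadata and removes enabled documents from the index, while unarchiving clears that metadata and re-indexes only enabled documents. The helper below is a hedged sketch of that behaviour with the patched dependencies injected; set_archived_sketch is an illustrative name, and the single batch commit is assumed to happen in the caller.

import datetime


def set_archived_sketch(document, archived, user, *, db, redis_client, remove_task, add_task):
    """Apply archive/un_archive to a single document the way these tests expect."""
    if document.archived == archived:
        return False  # already in the target state: skipped entirely (no add, cache key, or task)
    now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    document.archived = archived
    document.archived_at = now if archived else None
    document.archived_by = user.id if archived else None
    document.updated_at = now
    db.add(document)
    if document.enabled:  # disabled documents never touch Redis or the index tasks
        redis_client.setex(f"document_{document.id}_indexing", 600, 1)
        if archived:
            remove_task.delay(document.id)  # archived documents leave the index
        else:
            add_task.delay(document.id)  # unarchived documents are re-indexed
    return True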
Verify database operations still occurred + mock_db = mock_document_service_dependencies["db_session"] + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + # ==================== Error Handling Tests ==================== + + def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_document_service_dependencies): + """Test that DocumentIndexingError is raised when documents are currently being indexed.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create mock enabled document + enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) + mock_document_service_dependencies["get_document"].return_value = enabled_doc + + # Set up mock to indicate document is being indexed + redis_mock.reset_mock() + redis_mock.get.return_value = "indexing" + + # Verify that DocumentIndexingError is raised + with pytest.raises(DocumentIndexingError) as exc_info: + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="enable", user=user + ) + + # Verify error message contains document name + assert "test_document.pdf" in str(exc_info.value) + assert "is being indexed" in str(exc_info.value) + + # Verify Redis cache was checked + redis_mock.get.assert_called_once_with("document_doc-1_indexing") + + def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): + """Test that ValueError is raised when an invalid action is provided.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create mock document + doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) + mock_document_service_dependencies["get_document"].return_value = doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Test with invalid action + invalid_action = "invalid_action" + with pytest.raises(ValueError) as exc_info: + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user + ) + + # Verify error message contains the invalid action + assert invalid_action in str(exc_info.value) + assert "Invalid action" in str(exc_info.value) + + # Verify no Redis operations occurred + redis_mock.setex.assert_not_called() + + def test_batch_update_async_task_error_handling( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test handling of async task errors during batch operations.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create mock disabled document + disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) + mock_document_service_dependencies["get_document"].return_value = disabled_doc + + # Mock async task to raise an exception + mock_async_task_dependencies["add_task"].delay.side_effect = Exception("Celery task error") + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Verify that async task error is propagated + with pytest.raises(Exception) as exc_info: + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1"], action="enable", user=user + ) + + # Verify error message + assert "Celery task error" in str(exc_info.value) + + # Verify database operations completed 
successfully + mock_db = mock_document_service_dependencies["db_session"] + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + # Verify Redis cache was set successfully + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify document was updated + self._assert_document_enabled(disabled_doc, user.id, mock_document_service_dependencies["current_time"]) + + # ==================== Edge Case Tests ==================== + + def test_batch_update_empty_document_list(self, mock_document_service_dependencies): + """Test batch operations with an empty document ID list.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Call method with empty document list + result = DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[], action="enable", user=user + ) + + # Verify no document lookups were performed + mock_document_service_dependencies["get_document"].assert_not_called() + + # Verify method returns None (early return) + assert result is None + + def test_batch_update_document_not_found_skipped(self, mock_document_service_dependencies): + """Test behavior when some documents don't exist in the database.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Mock document service to return None (document not found) + mock_document_service_dependencies["get_document"].return_value = None + + # Call method with non-existent document ID + # This should not raise an error, just skip the missing document + try: + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["non-existent-doc"], action="enable", user=user + ) + except Exception as e: + pytest.fail(f"Method should not raise exception for missing documents: {e}") + + # Verify document lookup was attempted + mock_document_service_dependencies["get_document"].assert_called_once_with(dataset.id, "non-existent-doc") + + def test_batch_update_mixed_document_states_and_actions( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test batch operations on documents with mixed states and various scenarios.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create documents in various states + disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) + enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-2", enabled=True) + archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-3", enabled=True, archived=True) + + # Mix of different document states + documents = [disabled_doc, enabled_doc, archived_doc] + mock_document_service_dependencies["get_document"].side_effect = documents + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform enable operation on mixed state documents + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=user + ) + + # Verify only the disabled document was processed + # (enabled and archived documents should be skipped for enable action) + + # Only one add should occur (for the disabled document that was enabled) + mock_db = mock_document_service_dependencies["db_session"] + mock_db.add.assert_called_once() + # Only one commit 
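The large-batch and mixed-state tests lean on unittest.mock's call objects and assert_has_calls to verify one cache write and one task dispatch per processed document while the commit count stays at one. A minimal runnable example of that verification pattern, where task and the doc-i ids are hypothetical:

from unittest.mock import MagicMock, call

task = MagicMock()
for i in range(1, 4):
    task.delay(f"doc-{i}")

assert task.delay.call_count == 3
# assert_has_calls verifies the expected calls appear, in order, among the recorded calls
task.delay.assert_has_calls([call(f"doc-{i}") for i in range(1, 4)])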
should occur + mock_db.commit.assert_called_once() + + # Only one Redis setex should occur (for the document that was enabled) + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Only one async task should be triggered (for the document that was enabled) + mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") + + # ==================== Performance Tests ==================== + + def test_batch_update_large_document_list_performance( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test batch operations with a large number of documents.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create large list of document IDs + document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents + + # Create mock documents + mock_documents = DocumentBatchUpdateTestDataFactory.create_multiple_documents( + document_ids, + enabled=False, # All disabled, will be enabled + ) + mock_document_service_dependencies["get_document"].side_effect = mock_documents + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform batch enable operation + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=document_ids, action="enable", user=user + ) + + # Verify all documents were processed + assert mock_document_service_dependencies["get_document"].call_count == 100 + + # Verify all documents were updated + for mock_doc in mock_documents: + self._assert_document_enabled(mock_doc, user.id, mock_document_service_dependencies["current_time"]) + + # Verify database operations + mock_db = mock_document_service_dependencies["db_session"] + assert mock_db.add.call_count == 100 + assert mock_db.commit.call_count == 1 + + # Verify Redis cache operations occurred for each document + assert redis_mock.setex.call_count == 100 + + # Verify async tasks were triggered for each document + assert mock_async_task_dependencies["add_task"].delay.call_count == 100 + + # Verify correct Redis cache keys were set + expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] + redis_mock.setex.assert_has_calls(expected_redis_calls) + + # Verify correct async task calls + expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] + mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) + + def test_batch_update_mixed_document_states_complex_scenario( + self, mock_document_service_dependencies, mock_async_task_dependencies + ): + """Test complex batch operations with documents in various states.""" + dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() + user = DocumentBatchUpdateTestDataFactory.create_user_mock() + + # Create documents in various states + doc1 = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) # Will be enabled + doc2 = DocumentBatchUpdateTestDataFactory.create_document_mock( + "doc-2", enabled=True + ) # Already enabled, will be skipped + doc3 = DocumentBatchUpdateTestDataFactory.create_document_mock( + "doc-3", enabled=True + ) # Already enabled, will be skipped + doc4 = DocumentBatchUpdateTestDataFactory.create_document_mock( + "doc-4", enabled=True + ) # Not affected by enable action + doc5 = DocumentBatchUpdateTestDataFactory.create_document_mock( + "doc-5", enabled=True, archived=True + ) # Not affected by enable action + doc6 = None # Non-existent, will be 
skipped + + mock_document_service_dependencies["get_document"].side_effect = [doc1, doc2, doc3, doc4, doc5, doc6] + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform mixed batch operations + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"], + action="enable", # This will only affect doc1 + user=user, + ) + + # Verify document 1 was enabled + self._assert_document_enabled(doc1, user.id, mock_document_service_dependencies["current_time"]) + + # Verify other documents were skipped appropriately + assert doc2.enabled == True # No change + assert doc3.enabled == True # No change + assert doc4.enabled == True # No change + assert doc5.enabled == True # No change + + # Verify database commits occurred for processed documents + # Only doc1 should be added (others were skipped, doc6 doesn't exist) + mock_db = mock_document_service_dependencies["db_session"] + assert mock_db.add.call_count == 1 + assert mock_db.commit.call_count == 1 + + # Verify Redis cache operations occurred for processed documents + # Only doc1 should have Redis operations + assert redis_mock.setex.call_count == 1 + + # Verify async tasks were triggered for processed documents + # Only doc1 should trigger tasks + assert mock_async_task_dependencies["add_task"].delay.call_count == 1 + + # Verify correct Redis cache keys were set + expected_redis_calls = [call("document_doc-1_indexing", 600, 1)] + redis_mock.setex.assert_has_calls(expected_redis_calls) + + # Verify correct async task calls + expected_task_calls = [call("doc-1")] + mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py new file mode 100644 index 0000000000..cdbb439c85 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -0,0 +1,633 @@ +import datetime +from typing import Any, Optional + +# Mock redis_client before importing dataset_service +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.model_entities import ModelType +from models.dataset import Dataset, ExternalKnowledgeBindings +from services.dataset_service import DatasetService +from services.errors.account import NoPermissionError +from tests.unit_tests.conftest import redis_mock + + +class DatasetUpdateTestDataFactory: + """Factory class for creating test data and mock objects for dataset update tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + provider: str = "vendor", + name: str = "old_name", + description: str = "old_description", + indexing_technique: str = "high_quality", + retrieval_model: str = "old_model", + embedding_model_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + collection_binding_id: Optional[str] = None, + **kwargs, + ) -> Mock: + """Create a mock dataset with specified attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.provider = provider + dataset.name = name + dataset.description = description + dataset.indexing_technique = indexing_technique + dataset.retrieval_model = retrieval_model + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + dataset.collection_binding_id = collection_binding_id + for key, value in kwargs.items(): + 
setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock(user_id: str = "user-789") -> Mock: + """Create a mock user.""" + user = Mock() + user.id = user_id + return user + + @staticmethod + def create_external_binding_mock( + external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id" + ) -> Mock: + """Create a mock external knowledge binding.""" + binding = Mock(spec=ExternalKnowledgeBindings) + binding.external_knowledge_id = external_knowledge_id + binding.external_knowledge_api_id = external_knowledge_api_id + return binding + + @staticmethod + def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: + """Create a mock embedding model.""" + embedding_model = Mock() + embedding_model.model = model + embedding_model.provider = provider + return embedding_model + + @staticmethod + def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock: + """Create a mock collection binding.""" + binding = Mock() + binding.id = binding_id + return binding + + @staticmethod + def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: + """Create a mock current user.""" + current_user = Mock() + current_user.current_tenant_id = tenant_id + return current_user + + +class TestDatasetServiceUpdateDataset: + """ + Comprehensive unit tests for DatasetService.update_dataset method. + + This test suite covers all supported scenarios including: + - External dataset updates + - Internal dataset updates with different indexing techniques + - Embedding model updates + - Permission checks + - Error conditions and edge cases + """ + + @pytest.fixture + def mock_dataset_service_dependencies(self): + """Common mock setup for dataset service dependencies.""" + with ( + patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, + patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, + patch("extensions.ext_database.db.session") as mock_db, + patch("services.dataset_service.datetime") as mock_datetime, + ): + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + yield { + "get_dataset": mock_get_dataset, + "check_permission": mock_check_perm, + "db_session": mock_db, + "datetime": mock_datetime, + "current_time": current_time, + } + + @pytest.fixture + def mock_external_provider_dependencies(self): + """Mock setup for external provider tests.""" + with patch("services.dataset_service.Session") as mock_session: + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + session_mock = Mock() + mock_session.return_value.__enter__.return_value = session_mock + yield session_mock + + @pytest.fixture + def mock_internal_provider_dependencies(self): + """Mock setup for internal provider tests.""" + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_get_binding, + patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + patch("services.dataset_service.current_user") as mock_current_user, + ): + mock_current_user.current_tenant_id = "tenant-123" + yield { + "model_manager": mock_model_manager, + "get_binding": mock_get_binding, + "task": mock_task, + "current_user": mock_current_user, + } + + def 
_assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]): + """Helper method to verify database update calls.""" + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates) + mock_db.commit.assert_called_once() + + def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]): + """Helper method to verify external dataset updates.""" + assert mock_dataset.name == update_data.get("name", mock_dataset.name) + assert mock_dataset.description == update_data.get("description", mock_dataset.description) + assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model) + + if "external_knowledge_id" in update_data: + assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"] + if "external_knowledge_api_id" in update_data: + assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] + + # ==================== External Dataset Tests ==================== + + def test_update_external_dataset_success( + self, mock_dataset_service_dependencies, mock_external_provider_dependencies + ): + """Test successful update of external dataset.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock( + provider="external", name="old_name", description="old_description", retrieval_model="old_model" + ) + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + binding = DatasetUpdateTestDataFactory.create_external_binding_mock() + + # Mock external knowledge binding query + mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding + + update_data = { + "name": "new_name", + "description": "new_description", + "external_retrieval_model": "new_model", + "permission": "only_me", + "external_knowledge_id": "new_knowledge_id", + "external_knowledge_api_id": "new_api_id", + } + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify permission check was called + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + + # Verify dataset and binding updates + self._assert_external_dataset_update(dataset, binding, update_data) + + # Verify database operations + mock_db = mock_dataset_service_dependencies["db_session"] + mock_db.add.assert_any_call(dataset) + mock_db.add.assert_any_call(binding) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == dataset + + def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies): + """Test error when external knowledge id is missing.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external") + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, user) + + assert "External knowledge id is required" in str(context.value) + + def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies): + """Test error when external knowledge api id is missing.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external") + 
mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, user) + + assert "External knowledge api id is required" in str(context.value) + + def test_update_external_dataset_binding_not_found_error( + self, mock_dataset_service_dependencies, mock_external_provider_dependencies + ): + """Test error when external knowledge binding is not found.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external") + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + # Mock external knowledge binding query returning None + mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None + + update_data = { + "name": "new_name", + "external_knowledge_id": "knowledge_id", + "external_knowledge_api_id": "api_id", + } + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, user) + + assert "External knowledge binding not found" in str(context.value) + + # ==================== Internal Dataset Basic Tests ==================== + + def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies): + """Test successful update of internal dataset with basic fields.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock( + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id="binding-123", + ) + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + update_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + } + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify permission check was called + mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) + + # Verify database update was called with correct filtered data + expected_filtered_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + self._assert_database_update_called( + mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data + ) + + # Verify return value + assert result == dataset + + def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies): + """Test that None values are filtered out except for description field.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality") + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + update_data = { + "name": "new_name", + "description": None, # Should be included + "indexing_technique": "high_quality", + 
"retrieval_model": "new_model", + "embedding_model_provider": None, # Should be filtered out + "embedding_model": None, # Should be filtered out + } + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify database update was called with filtered data + expected_filtered_data = { + "name": "new_name", + "description": None, # Description should be included even if None + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + actual_call_args = mock_dataset_service_dependencies[ + "db_session" + ].query.return_value.filter_by.return_value.update.call_args[0][0] + # Remove timestamp for comparison as it's dynamic + del actual_call_args["updated_at"] + del expected_filtered_data["updated_at"] + + assert actual_call_args == expected_filtered_data + + # Verify return value + assert result == dataset + + # ==================== Indexing Technique Switch Tests ==================== + + def test_update_internal_dataset_indexing_technique_to_economy( + self, mock_dataset_service_dependencies, mock_internal_provider_dependencies + ): + """Test updating internal dataset indexing technique to economy.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality") + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"} + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify database update was called with embedding model fields cleared + expected_filtered_data = { + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + "retrieval_model": "new_model", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + self._assert_database_update_called( + mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data + ) + + # Verify return value + assert result == dataset + + def test_update_internal_dataset_indexing_technique_to_high_quality( + self, mock_dataset_service_dependencies, mock_internal_provider_dependencies + ): + """Test updating internal dataset indexing technique to high_quality.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy") + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + # Mock embedding model + embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock() + mock_internal_provider_dependencies[ + "model_manager" + ].return_value.get_model_instance.return_value = embedding_model + + # Mock collection binding + binding = DatasetUpdateTestDataFactory.create_collection_binding_mock() + mock_internal_provider_dependencies["get_binding"].return_value = binding + + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "retrieval_model": "new_model", + } + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify embedding model was validated + 
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with( + tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-ada-002", + ) + + # Verify collection binding was retrieved + mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-ada-002", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-456", + "retrieval_model": "new_model", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + self._assert_database_update_called( + mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data + ) + + # Verify vector index task was triggered + mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add") + + # Verify return value + assert result == dataset + + # ==================== Embedding Model Update Tests ==================== + + def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies): + """Test updating internal dataset without changing embedding model.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock( + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id="binding-123", + ) + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"} + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify database update was called with existing embedding model preserved + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + self._assert_database_update_called( + mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data + ) + + # Verify return value + assert result == dataset + + def test_update_internal_dataset_embedding_model_update( + self, mock_dataset_service_dependencies, mock_internal_provider_dependencies + ): + """Test updating internal dataset with new embedding model.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock( + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + ) + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + # Mock embedding model + embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small") + mock_internal_provider_dependencies[ + "model_manager" + ].return_value.get_model_instance.return_value = embedding_model + + # Mock collection binding + binding = 
DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789") + mock_internal_provider_dependencies["get_binding"].return_value = binding + + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + "retrieval_model": "new_model", + } + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify embedding model was validated + mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with( + tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + # Verify collection binding was retrieved + mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-3-small", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-789", + "retrieval_model": "new_model", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + self._assert_database_update_called( + mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data + ) + + # Verify vector index task was triggered + mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update") + + # Verify return value + assert result == dataset + + def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies): + """Test updating internal dataset without changing indexing technique.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock( + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id="binding-123", + ) + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + update_data = { + "name": "new_name", + "indexing_technique": "high_quality", # Same as current + "retrieval_model": "new_model", + } + + result = DatasetService.update_dataset("dataset-123", update_data, user) + + # Verify database update was called with correct data + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": user.id, + "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None), + } + + self._assert_database_update_called( + mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data + ) + + # Verify return value + assert result == dataset + + # ==================== Error Handling Tests ==================== + + def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies): + """Test error when dataset is not found.""" + mock_dataset_service_dependencies["get_dataset"].return_value = None + + user = DatasetUpdateTestDataFactory.create_user_mock() + update_data = {"name": "new_name"} + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, user) + + assert "Dataset not found" in 
str(context.value) + + def test_update_dataset_permission_error(self, mock_dataset_service_dependencies): + """Test error when user doesn't have permission.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock() + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission") + + update_data = {"name": "new_name"} + + with pytest.raises(NoPermissionError): + DatasetService.update_dataset("dataset-123", update_data, user) + + def test_update_internal_dataset_embedding_model_error( + self, mock_dataset_service_dependencies, mock_internal_provider_dependencies + ): + """Test error when embedding model is not available.""" + dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy") + mock_dataset_service_dependencies["get_dataset"].return_value = dataset + + user = DatasetUpdateTestDataFactory.create_user_mock() + + # Mock model manager to raise error + mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception( + "No Embedding Model available" + ) + + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "invalid_provider", + "embedding_model": "invalid_model", + "retrieval_model": "new_model", + } + + with pytest.raises(Exception) as context: + DatasetService.update_dataset("dataset-123", update_data, user) + + assert "No Embedding Model available".lower() in str(context.value).lower() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..8ae69c8d64 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -0,0 +1,222 @@ +import dataclasses +import secrets +from unittest import mock +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables.types import SegmentType +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.nodes import NodeType +from models.enums import DraftVariableType +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel +from services.workflow_draft_variable_service import ( + DraftVariableSaver, + VariableResetError, + WorkflowDraftVariableService, +) + + +class TestDraftVariableSaver: + def _get_test_app_id(self): + suffix = secrets.token_hex(6) + return f"test_app_id_{suffix}" + + def test__should_variable_be_visible(self): + mock_session = mock.MagicMock(spec=Session) + test_app_id = self._get_test_app_id() + saver = DraftVariableSaver( + session=mock_session, + app_id=test_app_id, + node_id="test_node_id", + node_type=NodeType.START, + invoke_from=InvokeFrom.DEBUGGER, + node_execution_id="test_execution_id", + ) + assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False + assert saver._should_variable_be_visible("123", NodeType.START, "output") == True + + def test__normalize_variable_for_start_node(self): + @dataclasses.dataclass(frozen=True) + class TestCase: + name: str + input_node_id: str + input_name: str + expected_node_id: str + expected_name: str + + _NODE_ID = "1747228642872" + cases = [ + TestCase( + name="name with `sys.` prefix should return 
the system node_id", + input_node_id=_NODE_ID, + input_name="sys.workflow_id", + expected_node_id=SYSTEM_VARIABLE_NODE_ID, + expected_name="workflow_id", + ), + TestCase( + name="name without `sys.` prefix should return the original input node_id", + input_node_id=_NODE_ID, + input_name="start_input", + expected_node_id=_NODE_ID, + expected_name="start_input", + ), + TestCase( + name="dummy_variable should return the original input node_id", + input_node_id=_NODE_ID, + input_name="__dummy__", + expected_node_id=_NODE_ID, + expected_name="__dummy__", + ), + ] + + mock_session = mock.MagicMock(spec=Session) + test_app_id = self._get_test_app_id() + saver = DraftVariableSaver( + session=mock_session, + app_id=test_app_id, + node_id=_NODE_ID, + node_type=NodeType.START, + invoke_from=InvokeFrom.DEBUGGER, + node_execution_id="test_execution_id", + ) + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + node_id, name = saver._normalize_variable_for_start_node(c.input_name) + assert node_id == c.expected_node_id, fail_msg + assert name == c.expected_name, fail_msg + + +class TestWorkflowDraftVariableService: + def _get_test_app_id(self): + suffix = secrets.token_hex(6) + return f"test_app_id_{suffix}" + + def test_reset_conversation_variable(self): + """Test resetting a conversation variable""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + # Mock the _reset_conv_var method + expected_result = Mock(spec=WorkflowDraftVariable) + with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv: + result = service.reset_variable(mock_workflow, mock_variable) + + mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable) + assert result == expected_result + + def test_reset_node_variable_with_no_execution_id(self): + """Test resetting a node variable with no execution ID - should delete variable""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with no execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = None + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Should delete the variable and return None + mock_session.delete.assert_called_once_with(instance=mock_variable) + mock_session.flush.assert_called_once() + assert result is None + + def test_reset_node_variable_with_missing_execution_record(self): + """Test resetting a node variable when execution record doesn't exist""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = "exec-id" + mock_variable.id = "var-id" + mock_variable.name = 
"test_var" + + # Mock session.scalars to return None (no execution record found) + mock_scalars = Mock() + mock_scalars.first.return_value = None + mock_session.scalars.return_value = mock_scalars + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Should delete the variable and return None + mock_session.delete.assert_called_once_with(instance=mock_variable) + mock_session.flush.assert_called_once() + assert result is None + + def test_reset_node_variable_with_valid_execution_record(self): + """Test resetting a node variable with valid execution record - should restore from execution""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = "exec-id" + mock_variable.id = "var-id" + mock_variable.name = "test_var" + mock_variable.node_id = "node-id" + mock_variable.value_type = SegmentType.STRING + + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.process_data_dict = {"test_var": "process_value"} + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock session.scalars to return the execution record + mock_scalars = Mock() + mock_scalars.first.return_value = mock_execution + mock_session.scalars.return_value = mock_scalars + + # Mock workflow methods + mock_node_config = {"type": "test_node"} + mock_workflow.get_node_config_by_id.return_value = mock_node_config + mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Verify variable.set_value was called with the correct value + mock_variable.set_value.assert_called_once() + # Verify last_edited_at was reset + assert mock_variable.last_edited_at is None + # Verify session.flush was called + mock_session.flush.assert_called() + + # Should return the updated variable + assert result == mock_variable + + def test_reset_system_variable_raises_error(self): + """Test that resetting a system variable raises an error""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test + mock_variable.id = "var-id" + + with pytest.raises(VariableResetError) as exc_info: + service.reset_variable(mock_workflow, mock_variable) + assert "cannot reset system variable" in str(exc_info.value) + assert "variable_id=var-id" in str(exc_info.value) diff --git a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py index f788a9756b..293ac253f5 100644 --- a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py +++ b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py @@ -1,3 +1,5 @@ +import json + from werkzeug import Request from werkzeug.datastructures import Headers from werkzeug.test import EnvironBuilder @@ -15,6 +17,59 @@ def test_oauth_convert_request_to_raw_data(): request = Request(builder.get_environ()) 
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
- assert b"GET /test HTTP/1.1" in raw_request_bytes
+ assert b"GET /test? HTTP/1.1" in raw_request_bytes
assert b"Content-Type: application/json" in raw_request_bytes
assert b"\r\n\r\n" in raw_request_bytes
+
+
+def test_oauth_convert_request_to_raw_data_with_query_params():
+ oauth_handler = OAuthHandler()
+ builder = EnvironBuilder(
+ method="GET",
+ path="/test",
+ query_string="code=abc123&state=xyz789",
+ headers=Headers({"Content-Type": "application/json"}),
+ )
+ request = Request(builder.get_environ())
+ raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
+
+ assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes
+ assert b"Content-Type: application/json" in raw_request_bytes
+ assert b"\r\n\r\n" in raw_request_bytes
+
+
+def test_oauth_convert_request_to_raw_data_with_post_body():
+ oauth_handler = OAuthHandler()
+ builder = EnvironBuilder(
+ method="POST",
+ path="/test",
+ data="param1=value1&param2=value2",
+ headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}),
+ )
+ request = Request(builder.get_environ())
+ raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
+
+ assert b"POST /test? HTTP/1.1" in raw_request_bytes
+ assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes
+ assert b"\r\n\r\n" in raw_request_bytes
+ assert b"param1=value1&param2=value2" in raw_request_bytes
+
+
+def test_oauth_convert_request_to_raw_data_with_json_body():
+ oauth_handler = OAuthHandler()
+ json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"}
+ builder = EnvironBuilder(
+ method="POST",
+ path="/test",
+ data=json.dumps(json_data),
+ headers=Headers({"Content-Type": "application/json"}),
+ )
+ request = Request(builder.get_environ())
+ raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
+
+ assert b"POST /test? 
HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/json" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + assert b'"code": "abc123"' in raw_request_bytes + assert b'"state": "xyz789"' in raw_request_bytes + assert b'"grant_type": "authorization_code"' in raw_request_bytes diff --git a/api/uv.lock b/api/uv.lock index a03929510e..6b9deffa0e 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1284,6 +1284,7 @@ dev = [ { name = "coverage" }, { name = "dotenv-linter" }, { name = "faker" }, + { name = "hypothesis" }, { name = "lxml-stubs" }, { name = "mypy" }, { name = "pandas-stubs" }, @@ -1322,6 +1323,7 @@ dev = [ { name = "types-pymysql" }, { name = "types-pyopenssl" }, { name = "types-python-dateutil" }, + { name = "types-python-http-client" }, { name = "types-pywin32" }, { name = "types-pyyaml" }, { name = "types-regex" }, @@ -1461,6 +1463,7 @@ dev = [ { name = "coverage", specifier = "~=7.2.4" }, { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, + { name = "hypothesis", specifier = ">=6.131.15" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.16.0" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, @@ -1499,6 +1502,7 @@ dev = [ { name = "types-pymysql", specifier = "~=1.1.0" }, { name = "types-pyopenssl", specifier = ">=24.1.0" }, { name = "types-python-dateutil", specifier = "~=2.9.0" }, + { name = "types-python-http-client", specifier = ">=3.3.7.20240910" }, { name = "types-pywin32", specifier = "~=310.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-regex", specifier = "~=2024.11.6" }, @@ -1543,7 +1547,7 @@ vdb = [ { name = "pymochow", specifier = "==1.3.1" }, { name = "pyobvector", specifier = "~=0.1.6" }, { name = "qdrant-client", specifier = "==1.9.0" }, - { name = "tablestore", specifier = "==6.1.0" }, + { name = "tablestore", specifier = "==6.2.0" }, { name = "tcvectordb", specifier = "~=1.6.4" }, { name = "tidb-vector", specifier = "==0.0.9" }, { name = "upstash-vector", specifier = "==0.6.0" }, @@ -1611,9 +1615,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793 } wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, + { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607 }, ] [[package]] @@ -1650,15 +1654,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/db/a0335710caaa6d0aebdaa65ad4df789c15d89b7babd9a30277838a7d9aac/emoji-2.14.1-py3-none-any.whl", hash = "sha256:35a8a486c1460addb1499e3bf7929d3889b2e2841a57401903699fef595e942b", 
size = 590617 }, ] -[[package]] -name = "enum34" -version = "1.1.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/c4/2da1f4952ba476677a42f25cd32ab8aaf0e1c0d0e00b89822b835c7e654c/enum34-1.1.10.tar.gz", hash = "sha256:cce6a7477ed816bd2542d03d53db9f0db935dd013b70f336a95c73979289f248", size = 28187 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/63/f6/ccb1c83687756aeabbf3ca0f213508fcfb03883ff200d201b3a4c60cedcc/enum34-1.1.10-py3-none-any.whl", hash = "sha256:c3858660960c984d6ab0ebad691265180da2b43f07e061c0f8dca9ef3cffd328", size = 11224 }, -] - [[package]] name = "esdk-obs-python" version = "3.24.6.1" @@ -2556,6 +2551,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007 }, ] +[[package]] +name = "hypothesis" +version = "6.131.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/6f/1e291f80627f3e043b19a86f9f6b172b910e3575577917d3122a6558410d/hypothesis-6.131.15.tar.gz", hash = "sha256:11849998ae5eecc8c586c6c98e47677fcc02d97475065f62768cfffbcc15ef7a", size = 436596 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/c7/78597bcec48e1585ea9029deb2bf2341516e90dd615a3db498413d68a4cc/hypothesis-6.131.15-py3-none-any.whl", hash = "sha256:e02e67e9f3cfd4cd4a67ccc03bf7431beccc1a084c5e90029799ddd36ce006d7", size = 501128 }, +] + [[package]] name = "idna" version = "3.10" @@ -4615,9 +4623,9 @@ wheels = [ name = "python-http-client" version = "3.3.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/56/fa/284e52a8c6dcbe25671f02d217bf2f85660db940088faf18ae7a05e97313/python_http_client-3.3.7.tar.gz", hash = "sha256:bf841ee45262747e00dec7ee9971dfb8c7d83083f5713596488d67739170cea0", size = 9377, upload-time = "2022-03-09T20:23:56.386Z" } +sdist = { url = "https://files.pythonhosted.org/packages/56/fa/284e52a8c6dcbe25671f02d217bf2f85660db940088faf18ae7a05e97313/python_http_client-3.3.7.tar.gz", hash = "sha256:bf841ee45262747e00dec7ee9971dfb8c7d83083f5713596488d67739170cea0", size = 9377 } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/31/9b360138f4e4035ee9dac4fe1132b6437bd05751aaf1db2a2d83dc45db5f/python_http_client-3.3.7-py3-none-any.whl", hash = "sha256:ad371d2bbedc6ea15c26179c6222a78bc9308d272435ddf1d5c84f068f249a36", size = 8352, upload-time = "2022-03-09T20:23:54.862Z" }, + { url = "https://files.pythonhosted.org/packages/29/31/9b360138f4e4035ee9dac4fe1132b6437bd05751aaf1db2a2d83dc45db5f/python_http_client-3.3.7-py3-none-any.whl", hash = "sha256:ad371d2bbedc6ea15c26179c6222a78bc9308d272435ddf1d5c84f068f249a36", size = 8352 }, ] [[package]] @@ -5103,9 +5111,9 @@ dependencies = [ { name = "python-http-client" }, { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/11/31/62e00433878dccf33edf07f8efa417b9030a2464eb3b04bbd797a11b4447/sendgrid-6.12.4.tar.gz", hash = "sha256:9e88b849daf0fa4bdf256c3b5da9f5a3272402c0c2fd6b1928c9de440db0a03d", size = 50271, upload-time = "2025-06-12T10:29:37.213Z" } +sdist = { url = "https://files.pythonhosted.org/packages/11/31/62e00433878dccf33edf07f8efa417b9030a2464eb3b04bbd797a11b4447/sendgrid-6.12.4.tar.gz", hash = 
"sha256:9e88b849daf0fa4bdf256c3b5da9f5a3272402c0c2fd6b1928c9de440db0a03d", size = 50271 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/9c/45d068fd831a65e6ed1e2ab3233de58784842afdc62fdcdd0a01bbb6b39d/sendgrid-6.12.4-py3-none-any.whl", hash = "sha256:9a211b96241e63bd5b9ed9afcc8608f4bcac426e4a319b3920ab877c8426e92c", size = 102122, upload-time = "2025-06-12T10:29:35.457Z" }, + { url = "https://files.pythonhosted.org/packages/c2/9c/45d068fd831a65e6ed1e2ab3233de58784842afdc62fdcdd0a01bbb6b39d/sendgrid-6.12.4-py3-none-any.whl", hash = "sha256:9a211b96241e63bd5b9ed9afcc8608f4bcac426e4a319b3920ab877c8426e92c", size = 102122 }, ] [[package]] @@ -5241,6 +5249,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/37/c3/6eeb6034408dac0fa653d126c9204ade96b819c936e136c5e8a6897eee9c/socksio-1.0.0-py3-none-any.whl", hash = "sha256:95dc1f15f9b34e8d7b16f06d74b8ccf48f609af32ab33c608d08761c5dcbb1f3", size = 12763 }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575 }, +] + [[package]] name = "soupsieve" version = "2.7" @@ -5349,12 +5366,11 @@ wheels = [ [[package]] name = "tablestore" -version = "6.1.0" +version = "6.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "crc32c" }, - { name = "enum34" }, { name = "flatbuffers" }, { name = "future" }, { name = "numpy" }, @@ -5362,7 +5378,10 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/ed/5bdd906ec9d2dbae3909525dbb7602558c377e0cbcdddb6405d2d0d3f1af/tablestore-6.1.0.tar.gz", hash = "sha256:bfe6a3e0fe88a230729723c357f4a46b8869a06a4b936db20692ed587a721c1c", size = 135690 } +sdist = { url = "https://files.pythonhosted.org/packages/a1/58/48d65d181a69f7db19f7cdee01d252168fbfbad2d1bb25abed03e6df3b05/tablestore-6.2.0.tar.gz", hash = "sha256:0773e77c00542be1bfebbc3c7a85f72a881c63e4e7df7c5a9793a54144590e68", size = 85942 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/da/30451712a769bcf417add8e81163d478a4d668b0e8d489a9d667260d55df/tablestore-6.2.0-py3-none-any.whl", hash = "sha256:6af496d841ab1ff3f78b46abbd87b95a08d89605c51664d2b30933b1d1c5583a", size = 106297 }, +] [[package]] name = "tabulate" @@ -5865,6 +5884,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c5/3f/b0e8db149896005adc938a1e7f371d6d7e9eca4053a29b108978ed15e0c2/types_python_dateutil-2.9.0.20250516-py3-none-any.whl", hash = "sha256:2b2b3f57f9c6a61fba26a9c0ffb9ea5681c9b83e69cd897c6b5f668d9c0cab93", size = 14356 }, ] +[[package]] +name = "types-python-http-client" +version = "3.3.7.20240910" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/d7/bb2754c2d1b20c1890593ec89799c99e8875b04f474197c41354f41e9d31/types-python-http-client-3.3.7.20240910.tar.gz", hash = "sha256:8a6ebd30ad4b90a329ace69c240291a6176388624693bc971a5ecaa7e9b05074", size = 2804 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/64/95/8f492d37d99630e096acbb4071788483282a34a73ae89dd1a5727f4189cc/types_python_http_client-3.3.7.20240910-py3-none-any.whl", hash = "sha256:58941bd986fb8bb0f4f782ef376be145ece8023f391364fbcd22bd26b13a140e", size = 3917 }, +] + [[package]] name = "types-pytz" version = "2025.2.0.20250516" diff --git a/docker/.env.example b/docker/.env.example index 5a2a426338..a024566c8f 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -285,6 +285,7 @@ BROKER_USE_SSL=false # If you are using Redis Sentinel for high availability, configure the following settings. CELERY_USE_SENTINEL=false CELERY_SENTINEL_MASTER_NAME= +CELERY_SENTINEL_PASSWORD= CELERY_SENTINEL_SOCKET_TIMEOUT=0.1 # ------------------------------ @@ -798,6 +799,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 HTTP_REQUEST_NODE_SSL_VERIFY=True +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + # SSRF Proxy server HTTP URL SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 # SSRF Proxy server HTTPS URL diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index f370934a3a..a6a4ed959a 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.4.3 + image: langgenius/dify-api:1.5.0 restart: always environment: # Use the shared environment variables. @@ -31,7 +31,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.4.3 + image: langgenius/dify-api:1.5.0 restart: always environment: # Use the shared environment variables. @@ -57,7 +57,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.4.3 + image: langgenius/dify-web:1.5.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 5f13060658..0d187176b0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -79,6 +79,7 @@ x-shared-env: &shared-api-worker-env BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} + CELERY_SENTINEL_PASSWORD: ${CELERY_SENTINEL_PASSWORD:-} CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} @@ -355,6 +356,7 @@ x-shared-env: &shared-api-worker-env HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} + RESPECT_XFORWARD_HEADERS_ENABLED: ${RESPECT_XFORWARD_HEADERS_ENABLED:-false} SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} @@ -515,7 +517,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.4.3 + image: langgenius/dify-api:1.5.0 restart: always environment: # Use the shared environment variables. @@ -544,7 +546,7 @@ services: # worker service # The Celery worker for processing the queue. 
worker: - image: langgenius/dify-api:1.4.3 + image: langgenius/dify-api:1.5.0 restart: always environment: # Use the shared environment variables. @@ -570,7 +572,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.4.3 + image: langgenius/dify-web:1.5.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/images/GitHub_README_if.png b/images/GitHub_README_if.png index 2a4e67264e..10c9d87b08 100644 Binary files a/images/GitHub_README_if.png and b/images/GitHub_README_if.png differ diff --git a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py new file mode 100644 index 0000000000..47c175acd7 --- /dev/null +++ b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py @@ -0,0 +1,248 @@ +import threading +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.entities.provider_entities import QuotaUnit +from events.event_handlers.update_provider_when_message_created import ( + handle, + get_update_stats, +) +from models.provider import ProviderType +from sqlalchemy.exc import OperationalError + + +class TestProviderUpdateDeadlockPrevention: + """Test suite for deadlock prevention in Provider updates.""" + + def setup_method(self): + """Setup test fixtures.""" + self.mock_message = Mock() + self.mock_message.answer_tokens = 100 + + self.mock_app_config = Mock() + self.mock_app_config.tenant_id = "test-tenant-123" + + self.mock_model_conf = Mock() + self.mock_model_conf.provider = "openai" + + self.mock_system_config = Mock() + self.mock_system_config.current_quota_type = QuotaUnit.TOKENS + + self.mock_provider_config = Mock() + self.mock_provider_config.using_provider_type = ProviderType.SYSTEM + self.mock_provider_config.system_configuration = self.mock_system_config + + self.mock_provider_bundle = Mock() + self.mock_provider_bundle.configuration = self.mock_provider_config + + self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle + + self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity) + self.mock_generate_entity.app_config = self.mock_app_config + self.mock_generate_entity.model_conf = self.mock_model_conf + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_consolidated_handler_basic_functionality(self, mock_db): + """Test that the consolidated handler performs both updates correctly.""" + # Setup mock query chain + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 # 1 row affected + + # Call the handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify db.session.query was called + assert mock_db.session.query.called + + # Verify commit was called + mock_db.session.commit.assert_called_once() + + # Verify no rollback was called + assert not mock_db.session.rollback.called + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_deadlock_retry_mechanism(self, mock_db): + """Test that deadlock errors trigger retry logic.""" + # Setup mock to raise deadlock error on first attempt, succeed on second + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value 
= 1 + + # First call raises deadlock, second succeeds + mock_db.session.commit.side_effect = [ + OperationalError("deadlock detected", None, None), + None, # Success on retry + ] + + # Call the handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify commit was called twice (original + retry) + assert mock_db.session.commit.call_count == 2 + + # Verify rollback was called once (after first failure) + mock_db.session.rollback.assert_called_once() + + @patch("events.event_handlers.update_provider_when_message_created.db") + @patch("events.event_handlers.update_provider_when_message_created.time.sleep") + def test_exponential_backoff_timing(self, mock_sleep, mock_db): + """Test that retry delays follow exponential backoff pattern.""" + # Setup mock to fail twice, succeed on third attempt + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + mock_db.session.commit.side_effect = [ + OperationalError("deadlock detected", None, None), + OperationalError("deadlock detected", None, None), + None, # Success on third attempt + ] + + # Call the handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify sleep was called twice with increasing delays + assert mock_sleep.call_count == 2 + + # First delay should be around 0.1s + jitter + first_delay = mock_sleep.call_args_list[0][0][0] + assert 0.1 <= first_delay <= 0.3 + + # Second delay should be around 0.2s + jitter + second_delay = mock_sleep.call_args_list[1][0][0] + assert 0.2 <= second_delay <= 0.4 + + def test_concurrent_handler_execution(self): + """Test that multiple handlers can run concurrently without deadlock.""" + results = [] + errors = [] + + def run_handler(): + try: + with patch( + "events.event_handlers.update_provider_when_message_created.db" + ) as mock_db: + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + handle( + self.mock_message, + application_generate_entity=self.mock_generate_entity, + ) + results.append("success") + except Exception as e: + errors.append(str(e)) + + # Run multiple handlers concurrently + threads = [] + for _ in range(5): + thread = threading.Thread(target=run_handler) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=5) + + # Verify all handlers completed successfully + assert len(results) == 5 + assert len(errors) == 0 + + def test_performance_stats_tracking(self): + """Test that performance statistics are tracked correctly.""" + # Reset stats + stats = get_update_stats() + initial_total = stats["total_updates"] + + with patch( + "events.event_handlers.update_provider_when_message_created.db" + ) as mock_db: + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # Call handler + handle( + self.mock_message, application_generate_entity=self.mock_generate_entity + ) + + # Check that stats were updated + updated_stats = get_update_stats() + assert updated_stats["total_updates"] == initial_total + 1 + assert updated_stats["successful_updates"] >= initial_total + 1 + + def 
test_non_chat_entity_ignored(self): + """Test that non-chat entities are ignored by the handler.""" + # Create a non-chat entity + mock_non_chat_entity = Mock() + mock_non_chat_entity.__class__.__name__ = "NonChatEntity" + + with patch( + "events.event_handlers.update_provider_when_message_created.db" + ) as mock_db: + # Call handler with non-chat entity + handle(self.mock_message, application_generate_entity=mock_non_chat_entity) + + # Verify no database operations were performed + assert not mock_db.session.query.called + assert not mock_db.session.commit.called + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_quota_calculation_tokens(self, mock_db): + """Test quota calculation for token-based quotas.""" + # Setup token-based quota + self.mock_system_config.current_quota_type = QuotaUnit.TOKENS + self.mock_message.answer_tokens = 150 + + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # Call handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify update was called with token count + update_calls = mock_query.update.call_args_list + + # Should have at least one call with quota_used update + quota_update_found = False + for call in update_calls: + values = call[0][0] # First argument to update() + if "quota_used" in values: + quota_update_found = True + break + + assert quota_update_found + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_quota_calculation_times(self, mock_db): + """Test quota calculation for times-based quotas.""" + # Setup times-based quota + self.mock_system_config.current_quota_type = QuotaUnit.TIMES + + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # Call handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify update was called + assert mock_query.update.called + assert mock_db.session.commit.called diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 7b6e66f7e7..c6d0e776dd 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -39,16 +39,19 @@ import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigge export type IAppInfoProps = { expand: boolean + onlyShowDetail?: boolean + openState?: boolean + onDetailExpand?: (expand: boolean) => void } -const AppInfo = ({ expand }: IAppInfoProps) => { +const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailExpand }: IAppInfoProps) => { const { t } = useTranslation() const { notify } = useContext(ToastContext) const { replace } = useRouter() const { onPlanInfoChanged } = useProviderContext() const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) - const [open, setOpen] = useState(false) + const [open, setOpen] = useState(openState) const [showEditModal, setShowEditModal] = useState(false) const [showDuplicateModal, setShowDuplicateModal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) @@ -193,43 +196,48 @@ const AppInfo = ({ expand }: IAppInfoProps) => { return (
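Note: the new test suite above pins down a commit/rollback retry loop with exponential backoff and jitter in the Python handler `events/event_handlers/update_provider_when_message_created.py` (first retry delay roughly 0.1s, second roughly 0.2s, each plus a small jitter), and exercises the `quota_used` update path for both token-based and times-based quotas. The snippet below is only a hedged, language-neutral sketch of that retry pattern, written in TypeScript for consistency with the other examples in this changeset; the function names are hypothetical and the real handler uses SQLAlchemy's `db.session.commit()` / `db.session.rollback()`.

```typescript
// Hypothetical sketch of the retry-with-backoff behaviour asserted by the tests
// above; the real handler is Python + SQLAlchemy and only retries deadlock-type
// OperationalErrors.
const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms))

async function commitWithDeadlockRetry(
  commit: () => Promise<void>,   // stands in for db.session.commit
  rollback: () => Promise<void>, // stands in for db.session.rollback
  maxRetries = 3,
  baseDelaySeconds = 0.1,
): Promise<void> {
  for (let attempt = 0; ; attempt++) {
    try {
      await commit()
      return // success: no rollback, no sleep
    }
    catch (e) {
      await rollback()
      if (attempt >= maxRetries)
        throw e
      // Exponential backoff with jitter: ~0.1s, ~0.2s, ~0.4s, ... plus up to 0.1s.
      const delaySeconds = baseDelaySeconds * 2 ** attempt + Math.random() * 0.1
      await sleep(delaySeconds * 1000)
    }
  }
}
```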
- + ) + } +
+ + )} setOpen(false)} + show={onlyShowDetail ? openState : open} + onClose={() => { + setOpen(false) + onDetailExpand?.(false) + }} className='absolute bottom-2 left-2 top-2 flex w-[420px] flex-col rounded-2xl !p-0' >
@@ -248,7 +256,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => {
{/* description */} {appDetail.description && ( -
{appDetail.description}
+
{appDetail.description}
)} {/* operations */}
@@ -258,6 +266,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { className='gap-[1px]' onClick={() => { setOpen(false) + onDetailExpand?.(false) setShowEditModal(true) }} > @@ -270,6 +279,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { className='gap-[1px]' onClick={() => { setOpen(false) + onDetailExpand?.(false) setShowDuplicateModal(true) }}> @@ -308,6 +318,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { &&
{ setOpen(false) + onDetailExpand?.(false) setShowImportDSLModal(true) }}> @@ -319,6 +330,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { &&
{ setOpen(false) + onDetailExpand?.(false) setShowSwitchModal(true) }}> @@ -345,6 +357,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { className='gap-0.5' onClick={() => { setOpen(false) + onDetailExpand?.(false) setShowConfirmDelete(true) }} > diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx new file mode 100644 index 0000000000..b1da43ae14 --- /dev/null +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -0,0 +1,125 @@ +import React, { useCallback, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useAppContext } from '@/context/app-context' +import { + RiEqualizer2Line, + RiMenuLine, +} from '@remixicon/react' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import AppIcon from '../base/app-icon' +import Divider from '../base/divider' +import AppInfo from './app-info' +import NavLink from './navLink' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { NavIcon } from './navLink' +import cn from '@/utils/classnames' + +type Props = { + navigation: Array<{ + name: string + href: string + icon: NavIcon + selectedIcon: NavIcon + }> +} + +const AppSidebarDropdown = ({ navigation }: Props) => { + const { t } = useTranslation() + const { isCurrentWorkspaceEditor } = useAppContext() + const appDetail = useAppStore(state => state.appDetail) + const [detailExpand, setDetailExpand] = useState(false) + + const [open, doSetOpen] = useState(false) + const openRef = useRef(open) + const setOpen = useCallback((v: boolean) => { + doSetOpen(v) + openRef.current = v + }, [doSetOpen]) + const handleTrigger = useCallback(() => { + setOpen(!openRef.current) + }, [setOpen]) + + if (!appDetail) + return null + + return ( + <> +
+ + +
+ + +
+
+ +
+
+
{ + setDetailExpand(true) + setOpen(false) + }} + > +
+ +
+
+ +
+
+
+
+
+
{appDetail.name}
+
+
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
+
+
+
+
+ +
+ +
+
+
+
+
+ +
+ + ) +} + +export default AppSidebarDropdown diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index f58985ed96..b6bfc0e9ac 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -1,4 +1,5 @@ -import React, { useEffect } from 'react' +import React, { useEffect, useState } from 'react' +import { usePathname } from 'next/navigation' import { useShallow } from 'zustand/react/shallow' import { RiLayoutLeft2Line, RiLayoutRight2Line } from '@remixicon/react' import NavLink from './navLink' @@ -6,8 +7,10 @@ import type { NavIcon } from './navLink' import AppBasic from './basic' import AppInfo from './app-info' import DatasetInfo from './dataset-info' +import AppSidebarDropdown from './app-sidebar-dropdown' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { useStore as useAppStore } from '@/app/components/app/store' +import { useEventEmitterContextContext } from '@/context/event-emitter' import cn from '@/utils/classnames' export type IAppDetailNavProps = { @@ -39,6 +42,18 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati setAppSiderbarExpand(state === 'expand' ? 'collapse' : 'expand') } + // // Check if the current path is a workflow canvas & fullscreen + const pathname = usePathname() + const inWorkflowCanvas = pathname.endsWith('/workflow') + const workflowCanvasMaximize = localStorage.getItem('workflow-canvas-maximize') === 'true' + const [hideHeader, setHideHeader] = useState(workflowCanvasMaximize) + const { eventEmitter } = useEventEmitterContextContext() + + eventEmitter?.useSubscription((v: any) => { + if (v?.type === 'workflow-canvas-maximize') + setHideHeader(v.payload) + }) + useEffect(() => { if (appSidebarExpand) { localStorage.setItem('app-detail-collapse-or-expand', appSidebarExpand) @@ -46,6 +61,14 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati } }, [appSidebarExpand, setAppSiderbarExpand]) + if (inWorkflowCanvas && hideHeader) { + return ( +
+ +
+ ) +} + return (
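For context on the `onlyShowDetail` / `openState` / `onDetailExpand` props added to `AppInfo` above: they let a parent such as the new `AppSidebarDropdown` skip the built-in trigger, render only the slide-out detail panel, and control its visibility. A minimal illustrative wrapper (the component name and button are invented for the example; only the `AppInfo` props come from the diff):

```typescript
import React, { useState } from 'react'
import AppInfo from '@/app/components/app-sidebar/app-info'

// Hypothetical parent that opens the AppInfo detail panel on demand.
const CollapsedSidebarEntry = () => {
  const [detailExpand, setDetailExpand] = useState(false)

  return (
    <>
      <button onClick={() => setDetailExpand(true)}>Open app detail</button>
      <AppInfo
        expand
        onlyShowDetail            // render only the slide-out detail panel
        openState={detailExpand}  // visibility is controlled by the parent
        onDetailExpand={setDetailExpand}
      />
    </>
  )
}

export default CollapsedSidebarEntry
```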
= ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api').map(item => ({ + variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api' && item.key && item.key.trim() && item.name && item.name.trim()).map(item => ({ name: item.name, value: item.key, })), diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 2a9a15296e..eb0f524386 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -97,20 +97,31 @@ const Prompt: FC = ({ }, }) } - const promptVariablesObj = (() => { - const obj: Record = {} - promptVariables.forEach((item) => { - obj[item.key] = true - }) - return obj - })() const [newPromptVariables, setNewPromptVariables] = React.useState(promptVariables) const [newTemplates, setNewTemplates] = React.useState('') const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) const handleChange = (newTemplates: string, keys: string[]) => { - const newPromptVariables = keys.filter(key => !(key in promptVariablesObj) && !externalDataToolsConfig.find(item => item.variable === key)).map(key => getNewVar(key, '')) + // Filter out keys that are not properly defined (either not exist or exist but without valid name) + const newPromptVariables = keys.filter((key) => { + // Check if key exists in external data tools + if (externalDataToolsConfig.find((item: ExternalDataTool) => item.variable === key)) + return false + + // Check if key exists in prompt variables + const existingVar = promptVariables.find((item: PromptVariable) => item.key === key) + if (!existingVar) { + // Variable doesn't exist at all + return true + } + + // Variable exists but check if it has valid name and key + return !existingVar.name || !existingVar.name.trim() || !existingVar.key || !existingVar.key.trim() + + return false + }).map(key => getNewVar(key, '')) + if (newPromptVariables.length > 0) { setNewPromptVariables(newPromptVariables) setNewTemplates(newTemplates) @@ -210,14 +221,14 @@ const Prompt: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api').map(item => ({ + variables: modelConfig.configs.prompt_variables.filter((item: PromptVariable) => item.type !== 'api' && item.key && item.key.trim() && item.name && item.name.trim()).map((item: PromptVariable) => ({ name: item.name, value: item.key, })), }} externalToolBlock={{ show: true, - externalTools: modelConfig.configs.prompt_variables.filter(item => item.type === 'api').map(item => ({ + externalTools: modelConfig.configs.prompt_variables.filter((item: PromptVariable) => item.type === 'api').map((item: PromptVariable) => ({ name: item.name, variableName: item.key, icon: item.icon, diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index 7f7f140eda..579b7c4d64 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -107,7 +107,7 @@ const Editor: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.map(item => ({ + variables: modelConfig.configs.prompt_variables.filter(item => item.key && 
item.key.trim() && item.name && item.name.trim()).map(item => ({ name: item.name, value: item.key, })), diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 208fddecd1..47f8c09e39 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -41,6 +41,7 @@ import { buildChatItemTree, getThreadMessages } from '@/app/components/base/chat import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import cn from '@/utils/classnames' import { noop } from 'lodash-es' +import PromptLogModal from '../../base/prompt-log-modal' dayjs.extend(utc) dayjs.extend(timezone) @@ -190,11 +191,14 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { userProfile: { timezone } } = useAppContext() const { formatTime } = useTimestamp() const { onClose, appDetail } = useContext(DrawerContext) - const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ + const { notify } = useContext(ToastContext) + const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ currentLogItem: state.currentLogItem, setCurrentLogItem: state.setCurrentLogItem, showMessageLogModal: state.showMessageLogModal, setShowMessageLogModal: state.setShowMessageLogModal, + showPromptLogModal: state.showPromptLogModal, + setShowPromptLogModal: state.setShowPromptLogModal, currentLogModalActiveTab: state.currentLogModalActiveTab, }))) const { t } = useTranslation() @@ -309,18 +313,34 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { return item })) }, [allChatItems]) - const handleAnnotationRemoved = useCallback((index: number) => { - setAllChatItems(allChatItems.map((item, i) => { - if (i === index) { - return { - ...item, - content: item.content, - annotation: undefined, - } + const handleAnnotationRemoved = useCallback(async (index: number): Promise => { + const annotation = allChatItems[index]?.annotation + + try { + if (annotation?.id) { + const { delAnnotation } = await import('@/service/annotation') + await delAnnotation(appDetail?.id || '', annotation.id) } - return item - })) - }, [allChatItems]) + + setAllChatItems(allChatItems.map((item, i) => { + if (i === index) { + return { + ...item, + content: item.content, + annotation: undefined, + } + } + return item + })) + + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + return true + } + catch { + notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) + return false + } + }, [allChatItems, appDetail?.id, t]) const fetchInitiated = useRef(false) @@ -516,6 +536,16 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { defaultTab={currentLogModalActiveTab} /> )} + {!isChatMode && showPromptLogModal && ( + { + setCurrentLogItem() + setShowPromptLogModal(false) + }} + /> + )}
) } diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index aa3ffa33c4..92d86351e0 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -171,7 +171,7 @@ const GenerationItem: FC = ({ appId: params.appId as string, messageId: messageId!, }) - const logItem = { + const logItem = Array.isArray(data.message) ? { ...data, log: [ ...data.message, @@ -185,6 +185,11 @@ const GenerationItem: FC = ({ ] : []), ], + } : { + ...data, + log: [typeof data.message === 'string' ? { + text: data.message, + } : data.message], } setCurrentLogItem(logItem) setShowPromptLogModal(true) diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx index 975f8aeb6c..8e66cd38cf 100644 --- a/web/app/components/base/app-icon-picker/index.tsx +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -112,7 +112,7 @@ const AppIconPicker: FC = ({ isShow closable={false} wrapperClassName={className} - className={cn(s.container, '!w-[362px] !p-0')} + className={cn(s.container, '!h-[462px] !w-[362px] !p-0')} > {!DISABLE_UPLOAD_IMAGE_AS_ICON &&
@@ -131,8 +131,8 @@ const AppIconPicker: FC = ({
} - - + {activeTab === 'emoji' && } + {activeTab === 'image' && }
diff --git a/web/app/components/base/emoji-picker/Inner.tsx b/web/app/components/base/emoji-picker/Inner.tsx index 34ce3f7da0..8d05967f33 100644 --- a/web/app/components/base/emoji-picker/Inner.tsx +++ b/web/app/components/base/emoji-picker/Inner.tsx @@ -5,6 +5,8 @@ import data from '@emoji-mart/data' import type { EmojiMartData } from '@emoji-mart/data' import { init } from 'emoji-mart' import { + ChevronDownIcon, + ChevronUpIcon, MagnifyingGlassIcon, } from '@heroicons/react/24/outline' import Input from '@/app/components/base/input' @@ -60,16 +62,20 @@ const EmojiPickerInner: FC = ({ const { categories } = data as EmojiMartData const [selectedEmoji, setSelectedEmoji] = useState('') const [selectedBackground, setSelectedBackground] = useState(backgroundColors[0]) + const [showStyleColors, setShowStyleColors] = useState(false) const [searchedEmojis, setSearchedEmojis] = useState([]) const [isSearching, setIsSearching] = useState(false) React.useEffect(() => { - if (selectedEmoji && selectedBackground) - onSelect?.(selectedEmoji, selectedBackground) + if (selectedEmoji) { + setShowStyleColors(true) + if (selectedBackground) + onSelect?.(selectedEmoji, selectedBackground) + } }, [onSelect, selectedEmoji, selectedBackground]) - return
+ return
@@ -95,7 +101,7 @@ const EmojiPickerInner: FC = ({
-
+
{isSearching && <>

Search

@@ -141,33 +147,34 @@ const EmojiPickerInner: FC = ({
{/* Color Select */} -
+

Choose Style

-
- {backgroundColors.map((color) => { - return
{ - setSelectedBackground(color) - }} - > -
- {selectedEmoji !== '' && } -
-
- })} -
+ {showStyleColors ? setShowStyleColors(!showStyleColors)} /> : setShowStyleColors(!showStyleColors)} />}
+ {showStyleColors &&
+ {backgroundColors.map((color) => { + return
{ + setSelectedBackground(color) + }} + > +
+ {selectedEmoji !== '' && } +
+
+ })} +
}
} export default EmojiPickerInner diff --git a/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx b/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx index ab4e2aaa42..02bb3ad673 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx @@ -26,9 +26,11 @@ type Option = { icon: React.JSX.Element } type FileUploaderInAttachmentProps = { + isDisabled?: boolean fileConfig: FileUpload } const FileUploaderInAttachment = ({ + isDisabled, fileConfig, }: FileUploaderInAttachmentProps) => { const { t } = useTranslation() @@ -89,16 +91,18 @@ const FileUploaderInAttachment = ({ return (
-
- {options.map(renderOption)} -
+ {!isDisabled && ( +
+ {options.map(renderOption)} +
+ )}
{ files.map(file => ( handleRemoveFile(file.id)} onReUpload={() => handleReUploadFile(file.id)} @@ -114,18 +118,20 @@ type FileUploaderInAttachmentWrapperProps = { value?: FileEntity[] onChange: (files: FileEntity[]) => void fileConfig: FileUpload + isDisabled?: boolean } const FileUploaderInAttachmentWrapper = ({ value, onChange, fileConfig, + isDisabled, }: FileUploaderInAttachmentWrapperProps) => { return ( - + ) } diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index 9b5a449481..e870f9edab 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -154,7 +154,7 @@ export const getProcessedFilesFromResponse = (files: FileResponse[]) => { transferMethod: fileItem.transfer_method, supportFileType: fileItem.type, uploadedId: fileItem.upload_file_id || fileItem.related_id, - url: fileItem.url, + url: fileItem.url || fileItem.remote_url, } }) } diff --git a/web/app/components/base/icons/assets/public/llm/openai-teal.svg b/web/app/components/base/icons/assets/public/llm/openai-teal.svg new file mode 100644 index 0000000000..359cb532b6 --- /dev/null +++ b/web/app/components/base/icons/assets/public/llm/openai-teal.svg @@ -0,0 +1,4 @@ + + + + diff --git a/web/app/components/base/icons/assets/public/llm/openai-yellow.svg b/web/app/components/base/icons/assets/public/llm/openai-yellow.svg new file mode 100644 index 0000000000..015eb74adc --- /dev/null +++ b/web/app/components/base/icons/assets/public/llm/openai-yellow.svg @@ -0,0 +1,4 @@ + + + + diff --git a/web/app/components/base/input-number/index.spec.tsx b/web/app/components/base/input-number/index.spec.tsx index 8dfd1184b0..891cbd21e3 100644 --- a/web/app/components/base/input-number/index.spec.tsx +++ b/web/app/components/base/input-number/index.spec.tsx @@ -18,7 +18,7 @@ describe('InputNumber Component', () => { it('renders input with default values', () => { render() - const input = screen.getByRole('textbox') + const input = screen.getByRole('spinbutton') expect(input).toBeInTheDocument() }) @@ -56,7 +56,7 @@ describe('InputNumber Component', () => { it('handles direct input changes', () => { render() - const input = screen.getByRole('textbox') + const input = screen.getByRole('spinbutton') fireEvent.change(input, { target: { value: '42' } }) expect(defaultProps.onChange).toHaveBeenCalledWith(42) @@ -64,7 +64,7 @@ describe('InputNumber Component', () => { it('handles empty input', () => { render() - const input = screen.getByRole('textbox') + const input = screen.getByRole('spinbutton') fireEvent.change(input, { target: { value: '' } }) expect(defaultProps.onChange).toHaveBeenCalledWith(undefined) @@ -72,7 +72,7 @@ describe('InputNumber Component', () => { it('handles invalid input', () => { render() - const input = screen.getByRole('textbox') + const input = screen.getByRole('spinbutton') fireEvent.change(input, { target: { value: 'abc' } }) expect(defaultProps.onChange).not.toHaveBeenCalled() @@ -86,7 +86,7 @@ describe('InputNumber Component', () => { it('disables controls when disabled prop is true', () => { render() - const input = screen.getByRole('textbox') + const input = screen.getByRole('spinbutton') const incrementBtn = screen.getByRole('button', { name: /increment/i }) const decrementBtn = screen.getByRole('button', { name: /decrement/i }) diff --git a/web/app/components/base/input-number/index.tsx b/web/app/components/base/input-number/index.tsx index 98efc94462..5fd45944db 100644 --- 
a/web/app/components/base/input-number/index.tsx +++ b/web/app/components/base/input-number/index.tsx @@ -55,8 +55,8 @@ export const InputNumber: FC = (props) => { return
= (props) => { size={size} />
- + const CustomButton = useMemo(() => ( + <> + + + + ), [viewNewlyAddedChunk, t]) - const isQAModel = useMemo(() => { - return docForm === ChunkingMode.qa - }, [docForm]) - - const handleCancel = (actionType: 'esc' | 'add' = 'esc') => { + const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => { if (actionType === 'esc' || !addAnother) onCancel() - } + }, [onCancel, addAnother]) const { mutateAsync: addSegment } = useAddSegment() - const handleSave = async () => { + const handleSave = useCallback(async () => { const params: SegmentUpdater = { content: '' } - if (isQAModel) { + if (docForm === ChunkingMode.qa) { if (!question.trim()) { return notify({ type: 'error', @@ -129,21 +128,27 @@ const NewSegmentModal: FC = ({ setLoading(false) }, }) - } + }, [docForm, keywords, addSegment, datasetId, documentId, question, answer, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave]) const wordCountText = useMemo(() => { - const count = isQAModel ? (question.length + answer.length) : question.length + const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [question.length, answer.length, isQAModel]) + }, [question.length, answer.length, docForm, t]) + + const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL return (
-
+
-
{ - t('datasetDocuments.segment.addChunk') - }
+
+ {t('datasetDocuments.segment.addChunk')} +
@@ -171,8 +176,8 @@ const NewSegmentModal: FC = ({
-
-
+
+
= ({ isEditMode={true} />
- {mode === 'custom' && { } } + if (name.length > 255) { + return { + errorMsg: t(`${i18nPrefix}.tooLong`, { max: 255 }), + } + } + return { errorMsg: '', } diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 12dd9b3b5b..0ee7b26114 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -19,12 +19,14 @@ export enum FormTypeEnum { toolSelector = 'tool-selector', multiToolSelector = 'array[tools]', appSelector = 'app-selector', + dynamicSelect = 'dynamic-select', } export type FormOption = { label: TypeWithI18N value: string show_on: FormShowOnObject[] + icon?: string } export enum ModelTypeEnum { diff --git a/web/app/components/header/header-wrapper.tsx b/web/app/components/header/header-wrapper.tsx index dd0ec77b82..cd4d229b05 100644 --- a/web/app/components/header/header-wrapper.tsx +++ b/web/app/components/header/header-wrapper.tsx @@ -1,6 +1,8 @@ 'use client' +import React, { useState } from 'react' import { usePathname } from 'next/navigation' import s from './index.module.css' +import { useEventEmitterContextContext } from '@/context/event-emitter' import classNames from '@/utils/classnames' type HeaderWrapperProps = { @@ -12,10 +14,23 @@ const HeaderWrapper = ({ }: HeaderWrapperProps) => { const pathname = usePathname() const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools'].includes(pathname) + // // Check if the current path is a workflow canvas & fullscreen + const inWorkflowCanvas = pathname.endsWith('/workflow') + const workflowCanvasMaximize = localStorage.getItem('workflow-canvas-maximize') === 'true' + const [hideHeader, setHideHeader] = useState(workflowCanvasMaximize) + const { eventEmitter } = useEventEmitterContextContext() + + eventEmitter?.useSubscription((v: any) => { + if (v?.type === 'workflow-canvas-maximize') + setHideHeader(v.payload) + }) + + if (hideHeader && inWorkflowCanvas) + return null return (
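The `HeaderWrapper` change above (and the matching change in `app-sidebar/index.tsx`) follow the same pattern: read the persisted `workflow-canvas-maximize` flag from localStorage for the initial state, then subscribe to `workflow-canvas-maximize` events on the shared event emitter to hide or restore the chrome. A condensed sketch of that wiring, extracted into a hypothetical hook for illustration (the real components inline this logic):

```typescript
import { useState } from 'react'
import { useEventEmitterContextContext } from '@/context/event-emitter'

// The diff uses the same string as both the localStorage key and the event type.
const WORKFLOW_CANVAS_MAXIMIZE = 'workflow-canvas-maximize'

// Hypothetical hook; HeaderWrapper and AppDetailNav inline this logic instead.
const useHideWhenCanvasMaximized = () => {
  // Initial value comes from the persisted preference (client component).
  const [hidden, setHidden] = useState(localStorage.getItem(WORKFLOW_CANVAS_MAXIMIZE) === 'true')
  const { eventEmitter } = useEventEmitterContextContext()

  // Later toggles (e.g. the `f` shortcut via useWorkflowCanvasMaximize) arrive here.
  eventEmitter?.useSubscription((v: any) => {
    if (v?.type === WORKFLOW_CANVAS_MAXIMIZE)
      setHidden(v.payload)
  })

  return hidden
}

export default useHideWhenCanvasMaximized
```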
void onSelect: (app: App) => void + apps: App[] + isLoading: boolean + hasMore: boolean + onLoadMore: () => void + searchText: string + onSearchChange: (text: string) => void } const AppPicker: FC = ({ scope, - appList, disabled, trigger, placement = 'right-start', @@ -37,19 +41,81 @@ const AppPicker: FC = ({ isShow, onShowChange, onSelect, + apps, + isLoading, + hasMore, + onLoadMore, + searchText, + onSearchChange, }) => { - const [searchText, setSearchText] = useState('') - const filteredAppList = useMemo(() => { - return (appList || []) - .filter(app => app.name.toLowerCase().includes(searchText.toLowerCase())) - .filter(app => (app.mode !== 'advanced-chat' && app.mode !== 'workflow') || !!app.workflow) - .filter(app => scope === 'all' - || (scope === 'completion' && app.mode === 'completion') - || (scope === 'workflow' && app.mode === 'workflow') - || (scope === 'chat' && app.mode === 'advanced-chat') - || (scope === 'chat' && app.mode === 'agent-chat') - || (scope === 'chat' && app.mode === 'chat')) - }, [appList, scope, searchText]) + const { t } = useTranslation() + const observerTarget = useRef(null) + const observerRef = useRef(null) + const loadingRef = useRef(false) + + const handleIntersection = useCallback((entries: IntersectionObserverEntry[]) => { + const target = entries[0] + if (!target.isIntersecting || loadingRef.current || !hasMore || isLoading) return + + loadingRef.current = true + onLoadMore() + // Reset loading state + setTimeout(() => { + loadingRef.current = false + }, 500) + }, [hasMore, isLoading, onLoadMore]) + + useEffect(() => { + if (!isShow) { + if (observerRef.current) { + observerRef.current.disconnect() + observerRef.current = null + } + return + } + + let mutationObserver: MutationObserver | null = null + + const setupIntersectionObserver = () => { + if (!observerTarget.current) return + + // Create new observer + observerRef.current = new IntersectionObserver(handleIntersection, { + root: null, + rootMargin: '100px', + threshold: 0.1, + }) + + observerRef.current.observe(observerTarget.current) + } + + // Set up MutationObserver to watch DOM changes + mutationObserver = new MutationObserver((mutations) => { + if (observerTarget.current) { + setupIntersectionObserver() + mutationObserver?.disconnect() + } + }) + + // Watch body changes since Portal adds content to body + mutationObserver.observe(document.body, { + childList: true, + subtree: true, + }) + + // If element exists, set up IntersectionObserver directly + if (observerTarget.current) + setupIntersectionObserver() + + return () => { + if (observerRef.current) { + observerRef.current.disconnect() + observerRef.current = null + } + mutationObserver?.disconnect() + } + }, [isShow, handleIntersection]) + const getAppType = (app: App) => { switch (app.mode) { case 'advanced-chat': @@ -84,18 +150,18 @@ const AppPicker: FC = ({ -
+
setSearchText(e.target.value)} - onClear={() => setSearchText('')} + onChange={e => onSearchChange(e.target.value)} + onClear={() => onSearchChange('')} />
-
- {filteredAppList.map(app => ( +
+ {apps.map(app => (
= ({
{getAppType(app)}
))} +
+ {isLoading && ( +
+
{t('common.loading')}
+
+ )} +
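The `AppPicker` changes above implement infinite scroll: a sentinel `<div ref={observerTarget} />` at the end of the list is watched with an `IntersectionObserver` (plus a `MutationObserver` on `document.body`, since the list renders through a portal and the sentinel may not exist yet when the effect runs), and `onLoadMore` fires when the sentinel approaches the viewport. A reduced sketch of the core observer wiring, with the portal/MutationObserver handling and the `loadingRef` guard omitted:

```typescript
import { useCallback, useEffect, useRef } from 'react'

// Reduced illustration of the load-more sentinel used by AppPicker; the real
// component also guards against overlapping loads and portal-mounted content.
function useLoadMoreSentinel(hasMore: boolean, isLoading: boolean, onLoadMore: () => void) {
  const sentinelRef = useRef<HTMLDivElement | null>(null)

  const handleIntersection = useCallback((entries: IntersectionObserverEntry[]) => {
    if (entries[0]?.isIntersecting && hasMore && !isLoading)
      onLoadMore()
  }, [hasMore, isLoading, onLoadMore])

  useEffect(() => {
    const target = sentinelRef.current
    if (!target)
      return
    // rootMargin starts fetching shortly before the sentinel becomes visible.
    const observer = new IntersectionObserver(handleIntersection, {
      root: null,
      rootMargin: '100px',
      threshold: 0.1,
    })
    observer.observe(target)
    return () => observer.disconnect()
  }, [handleIntersection])

  return sentinelRef // attach to an empty <div> placed after the last list item
}

export default useLoadMoreSentinel
```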
diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx index 39ca957953..9c16c66a70 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useMemo, useState } from 'react' +import React, { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { PortalToFollowElem, @@ -10,12 +10,36 @@ import { import AppTrigger from '@/app/components/plugins/plugin-detail-panel/app-selector/app-trigger' import AppPicker from '@/app/components/plugins/plugin-detail-panel/app-selector/app-picker' import AppInputsPanel from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel' -import { useAppFullList } from '@/service/use-apps' import type { App } from '@/types/app' import type { OffsetOptions, Placement, } from '@floating-ui/react' +import useSWRInfinite from 'swr/infinite' +import { fetchAppList } from '@/service/apps' +import type { AppListResponse } from '@/models/app' + +const PAGE_SIZE = 20 + +const getKey = ( + pageIndex: number, + previousPageData: AppListResponse, + searchText: string, +) => { + if (pageIndex === 0 || (previousPageData && previousPageData.has_more)) { + const params: any = { + url: 'apps', + params: { + page: pageIndex + 1, + limit: PAGE_SIZE, + name: searchText, + }, + } + + return params + } + return null +} type Props = { value?: { @@ -34,6 +58,7 @@ type Props = { }) => void supportAddCustomTool?: boolean } + const AppSelector: FC = ({ value, scope, @@ -44,18 +69,47 @@ const AppSelector: FC = ({ }) => { const { t } = useTranslation() const [isShow, onShowChange] = useState(false) + const [searchText, setSearchText] = useState('') + const [isLoadingMore, setIsLoadingMore] = useState(false) + + const { data, isLoading, setSize } = useSWRInfinite( + (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, searchText), + fetchAppList, + { + revalidateFirstPage: true, + shouldRetryOnError: false, + dedupingInterval: 500, + errorRetryCount: 3, + }, + ) + + const displayedApps = useMemo(() => { + if (!data) return [] + return data.flatMap(({ data: apps }) => apps) + }, [data]) + + const hasMore = data?.at(-1)?.has_more ?? 
true + + const handleLoadMore = useCallback(async () => { + if (isLoadingMore || !hasMore) return + + setIsLoadingMore(true) + try { + await setSize((size: number) => size + 1) + } + finally { + // Add a small delay to ensure state updates are complete + setTimeout(() => { + setIsLoadingMore(false) + }, 300) + } + }, [isLoadingMore, hasMore, setSize]) + const handleTriggerClick = () => { if (disabled) return onShowChange(true) } - const { data: appList } = useAppFullList() - const currentAppInfo = useMemo(() => { - if (!appList?.data || !value) - return undefined - return appList.data.find(app => app.id === value.app_id) - }, [appList?.data, value]) - const [isShowChooseApp, setIsShowChooseApp] = useState(false) const handleSelectApp = (app: App) => { const clearValue = app.id !== value?.app_id @@ -67,6 +121,7 @@ const AppSelector: FC = ({ onSelect(appValue) setIsShowChooseApp(false) } + const handleFormChange = (inputs: Record) => { const newFiles = inputs['#image#'] delete inputs['#image#'] @@ -88,6 +143,12 @@ const AppSelector: FC = ({ } }, [value]) + const currentAppInfo = useMemo(() => { + if (!displayedApps || !value) + return undefined + return displayedApps.find(app => app.id === value.app_id) + }, [displayedApps, value]) + return ( <> = ({ isShow={isShowChooseApp} onShowChange={setIsShowChooseApp} disabled={false} - appList={appList?.data || []} onSelect={handleSelectApp} scope={scope || 'all'} + apps={displayedApps} + isLoading={isLoading || isLoadingMore} + hasMore={hasMore} + onLoadMore={handleLoadMore} + searchText={searchText} + onSearchChange={setSearchText} />
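On the `AppSelector` side above, the previous `useAppFullList` call is replaced with server-side pagination via `useSWRInfinite`: `getKey` keeps producing `{ url, params }` requests while the previous page reports `has_more`, and returns `null` to stop; because `searchText` is part of the key, editing the search box restarts from page 1. A condensed sketch of that contract (the hook name is illustrative; `fetchAppList`, `PAGE_SIZE` and the param shape follow the diff):

```typescript
import useSWRInfinite from 'swr/infinite'
import { fetchAppList } from '@/service/apps'
import type { AppListResponse } from '@/models/app'

const PAGE_SIZE = 20

// Illustrative hook; AppSelector inlines this with extra loading-state handling.
const useAppPages = (searchText: string) => {
  const getKey = (pageIndex: number, previousPageData: AppListResponse) => {
    if (pageIndex === 0 || (previousPageData && previousPageData.has_more)) {
      return {
        url: 'apps',
        params: { page: pageIndex + 1, limit: PAGE_SIZE, name: searchText },
      }
    }
    return null // previous page reported has_more: false, stop paginating
  }

  const { data, isLoading, setSize } = useSWRInfinite(getKey, fetchAppList)

  const apps = data?.flatMap(page => page.data) ?? []
  const hasMore = data?.at(-1)?.has_more ?? true
  const loadMore = () => setSize(size => size + 1)

  return { apps, isLoading, hasMore, loadMore }
}

export default useAppPages
```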
{/* app inputs config panel */} @@ -140,4 +206,5 @@ const AppSelector: FC = ({ ) } + export default React.memo(AppSelector) diff --git a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx index 7f5f22896a..fef79644cd 100644 --- a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx @@ -117,6 +117,7 @@ const MultipleToolSelector = ({ )} {!disabled && ( { + setCollapse(false) setOpen(!open) setPanelShowState(true) }}> @@ -126,23 +127,6 @@ const MultipleToolSelector = ({
{!collapse && ( <> -
- } - panelShowState={panelShowState} - onPanelShowStateChange={setPanelShowState} - isEdit={false} - /> {value.length === 0 && (
{t('plugin.detailPanel.toolSelector.empty')}
)} @@ -164,6 +148,23 @@ const MultipleToolSelector = ({ ))} )} +
+ } + panelShowState={panelShowState} + onPanelShowStateChange={setPanelShowState} + isEdit={false} + /> ) } diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx index ca802414f3..350fe50933 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx @@ -275,7 +275,7 @@ const ToolSelector: FC = ({ /> )} - +
{!isShowSettingAuth && ( <> diff --git a/web/app/components/workflow-app/hooks/use-fetch-workflow-inspect-vars.ts b/web/app/components/workflow-app/hooks/use-fetch-workflow-inspect-vars.ts new file mode 100644 index 0000000000..9d3ff84929 --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-fetch-workflow-inspect-vars.ts @@ -0,0 +1,68 @@ +import type { NodeWithVar, VarInInspect } from '@/types/workflow' +import { useWorkflowStore } from '../../workflow/store' +import { useStoreApi } from 'reactflow' +import type { Node } from '@/app/components/workflow/types' +import { fetchAllInspectVars } from '@/service/workflow' +import { useInvalidateConversationVarValues, useInvalidateSysVarValues } from '@/service/use-workflow' +import { useNodesInteractionsWithoutSync } from '../../workflow/hooks/use-nodes-interactions-without-sync' +const useSetWorkflowVarsWithValue = () => { + const workflowStore = useWorkflowStore() + const { setNodesWithInspectVars, appId } = workflowStore.getState() + const store = useStoreApi() + const invalidateConversationVarValues = useInvalidateConversationVarValues(appId) + const invalidateSysVarValues = useInvalidateSysVarValues(appId) + const { handleCancelAllNodeSuccessStatus } = useNodesInteractionsWithoutSync() + + const setInspectVarsToStore = (inspectVars: VarInInspect[]) => { + const { getNodes } = store.getState() + const nodeArr = getNodes() + const nodesKeyValue: Record = {} + nodeArr.forEach((node) => { + nodesKeyValue[node.id] = node + }) + + const withValueNodeIds: Record = {} + inspectVars.forEach((varItem) => { + const nodeId = varItem.selector[0] + + const node = nodesKeyValue[nodeId] + if (!node) + return + withValueNodeIds[nodeId] = true + }) + const withValueNodes = Object.keys(withValueNodeIds).map((nodeId) => { + return nodesKeyValue[nodeId] + }) + + const res: NodeWithVar[] = withValueNodes.map((node) => { + const nodeId = node.id + const varsUnderTheNode = inspectVars.filter((varItem) => { + return varItem.selector[0] === nodeId + }) + const nodeWithVar = { + nodeId, + nodePayload: node.data, + nodeType: node.data.type, + title: node.data.title, + vars: varsUnderTheNode, + isSingRunRunning: false, + isValueFetched: false, + } + return nodeWithVar + }) + setNodesWithInspectVars(res) + } + + const fetchInspectVars = async () => { + invalidateConversationVarValues() + invalidateSysVarValues() + const data = await fetchAllInspectVars(appId) + setInspectVarsToStore(data) + handleCancelAllNodeSuccessStatus() // to make sure clear node output show the unset status + } + return { + fetchInspectVars, + } +} + +export default useSetWorkflowVarsWithValue diff --git a/web/app/components/workflow-app/hooks/use-workflow-init.ts b/web/app/components/workflow-app/hooks/use-workflow-init.ts index e1c4c25a4e..6d16dc5c44 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-init.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-init.ts @@ -17,7 +17,6 @@ import { } from '@/service/workflow' import type { FetchWorkflowDraftResponse } from '@/types/workflow' import { useWorkflowConfig } from '@/service/use-workflow' - export const useWorkflowInit = () => { const workflowStore = useWorkflowStore() const { diff --git a/web/app/components/workflow-app/hooks/use-workflow-run.ts b/web/app/components/workflow-app/hooks/use-workflow-run.ts index 1e484d0760..99b88238f1 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-run.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-run.ts @@ -19,6 +19,8 @@ import { AudioPlayerManager } 
from '@/app/components/base/audio-btn/audio.player import type { VersionHistory } from '@/types/workflow' import { noop } from 'lodash-es' import { useNodesSyncDraft } from './use-nodes-sync-draft' +import { useInvalidAllLastRun } from '@/service/use-workflow' +import useSetWorkflowVarsWithValue from './use-fetch-workflow-inspect-vars' export const useWorkflowRun = () => { const store = useStoreApi() @@ -28,6 +30,9 @@ export const useWorkflowRun = () => { const { doSyncWorkflowDraft } = useNodesSyncDraft() const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() const pathname = usePathname() + const appId = useAppStore.getState().appDetail?.id + const invalidAllLastRun = useInvalidAllLastRun(appId as string) + const { fetchInspectVars } = useSetWorkflowVarsWithValue() const { handleWorkflowStarted, @@ -140,11 +145,13 @@ export const useWorkflowRun = () => { clientHeight, } = workflowContainer! + const isInWorkflowDebug = appDetail?.mode === 'workflow' + let url = '' if (appDetail?.mode === 'advanced-chat') url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run` - if (appDetail?.mode === 'workflow') + if (isInWorkflowDebug) url = `/apps/${appDetail.id}/workflows/draft/run` const { @@ -189,6 +196,10 @@ export const useWorkflowRun = () => { if (onWorkflowFinished) onWorkflowFinished(params) + if (isInWorkflowDebug) { + fetchInspectVars() + invalidAllLastRun() + } }, onError: (params) => { handleWorkflowFailed() @@ -292,26 +303,7 @@ export const useWorkflowRun = () => { ...restCallback, }, ) - }, [ - store, - workflowStore, - doSyncWorkflowDraft, - handleWorkflowStarted, - handleWorkflowFinished, - handleWorkflowFailed, - handleWorkflowNodeStarted, - handleWorkflowNodeFinished, - handleWorkflowNodeIterationStarted, - handleWorkflowNodeIterationNext, - handleWorkflowNodeIterationFinished, - handleWorkflowNodeLoopStarted, - handleWorkflowNodeLoopNext, - handleWorkflowNodeLoopFinished, - handleWorkflowNodeRetry, - handleWorkflowTextChunk, - handleWorkflowTextReplace, - handleWorkflowAgentLog, - pathname], + }, [store, doSyncWorkflowDraft, workflowStore, pathname, handleWorkflowStarted, handleWorkflowFinished, fetchInspectVars, invalidAllLastRun, handleWorkflowFailed, handleWorkflowNodeStarted, handleWorkflowNodeFinished, handleWorkflowNodeIterationStarted, handleWorkflowNodeIterationNext, handleWorkflowNodeIterationFinished, handleWorkflowNodeLoopStarted, handleWorkflowNodeLoopNext, handleWorkflowNodeLoopFinished, handleWorkflowNodeRetry, handleWorkflowAgentLog, handleWorkflowTextChunk, handleWorkflowTextReplace], ) const handleStopRun = useCallback((taskId: string) => { diff --git a/web/app/components/workflow-app/hooks/use-workflow-template.ts b/web/app/components/workflow-app/hooks/use-workflow-template.ts index 9f47b981dc..2bab08f205 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-template.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-template.ts @@ -22,7 +22,7 @@ export const useWorkflowTemplate = () => { ...nodesInitialData.llm, memory: { window: { enabled: false, size: 10 }, - query_prompt_template: '{{#sys.query#}}', + query_prompt_template: '{{#sys.query#}}\n\n{{#sys.files#}}', }, selected: true, }, diff --git a/web/app/components/workflow/block-selector/types.ts b/web/app/components/workflow/block-selector/types.ts index 0abf7b9031..f1bdbbfbd9 100644 --- a/web/app/components/workflow/block-selector/types.ts +++ b/web/app/components/workflow/block-selector/types.ts @@ -36,7 +36,7 @@ export type ToolValue = { provider_name: string tool_name: string 
tool_label: string - tool_description: string + tool_description?: string settings?: Record parameters?: Record enabled?: boolean diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index b7432f1203..304295cfbf 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -31,6 +31,7 @@ type NodesExtraData = { getAvailablePrevNodes: (isChatMode: boolean) => BlockEnum[] getAvailableNextNodes: (isChatMode: boolean) => BlockEnum[] checkValid: any + defaultRunInputData?: Record } export const NODES_EXTRA_DATA: Record = { [BlockEnum.Start]: { @@ -68,6 +69,7 @@ export const NODES_EXTRA_DATA: Record = { getAvailablePrevNodes: LLMDefault.getAvailablePrevNodes, getAvailableNextNodes: LLMDefault.getAvailableNextNodes, checkValid: LLMDefault.checkValid, + defaultRunInputData: LLMDefault.defaultRunInputData, }, [BlockEnum.KnowledgeRetrieval]: { author: 'Dify', diff --git a/web/app/components/workflow/header/header-in-normal.tsx b/web/app/components/workflow/header/header-in-normal.tsx index ec016b1b65..5768e6bc06 100644 --- a/web/app/components/workflow/header/header-in-normal.tsx +++ b/web/app/components/workflow/header/header-in-normal.tsx @@ -33,6 +33,8 @@ const HeaderInNormal = ({ const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) const setShowEnvPanel = useStore(s => s.setShowEnvPanel) const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) + const setShowVariableInspectPanel = useStore(s => s.setShowVariableInspectPanel) + const setShowChatVariablePanel = useStore(s => s.setShowChatVariablePanel) const nodes = useNodes() const selectedNode = nodes.find(node => node.data.selected) const { handleBackupDraft } = useWorkflowRun() @@ -46,8 +48,10 @@ const HeaderInNormal = ({ setShowWorkflowVersionHistoryPanel(true) setShowEnvPanel(false) setShowDebugAndPreviewPanel(false) + setShowVariableInspectPanel(false) + setShowChatVariablePanel(false) }, [handleBackupDraft, workflowStore, handleNodeSelect, selectedNode, - setShowWorkflowVersionHistoryPanel, setShowEnvPanel, setShowDebugAndPreviewPanel]) + setShowWorkflowVersionHistoryPanel, setShowEnvPanel, setShowDebugAndPreviewPanel, setShowVariableInspectPanel]) return ( <> diff --git a/web/app/components/workflow/header/header-in-restoring.tsx b/web/app/components/workflow/header/header-in-restoring.tsx index 4d1954587d..afa4e62099 100644 --- a/web/app/components/workflow/header/header-in-restoring.tsx +++ b/web/app/components/workflow/header/header-in-restoring.tsx @@ -17,6 +17,8 @@ import { import Toast from '../../base/toast' import RestoringTitle from './restoring-title' import Button from '@/app/components/base/button' +import { useStore as useAppStore } from '@/app/components/app/store' +import { useInvalidAllLastRun } from '@/service/use-workflow' export type HeaderInRestoringProps = { onRestoreSettled?: () => void @@ -26,6 +28,12 @@ const HeaderInRestoring = ({ }: HeaderInRestoringProps) => { const { t } = useTranslation() const workflowStore = useWorkflowStore() + const appDetail = useAppStore.getState().appDetail + + const invalidAllLastRun = useInvalidAllLastRun(appDetail!.id) + const { + deleteAllInspectVars, + } = workflowStore.getState() const currentVersion = useStore(s => s.currentVersion) const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) @@ -61,7 +69,9 @@ const HeaderInRestoring = ({ onRestoreSettled?.() }, }) - }, [handleSyncWorkflowDraft, 
workflowStore, setShowWorkflowVersionHistoryPanel, onRestoreSettled, t]) + deleteAllInspectVars() + invalidAllLastRun() + }, [setShowWorkflowVersionHistoryPanel, workflowStore, handleSyncWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, onRestoreSettled]) return ( <> diff --git a/web/app/components/workflow/header/index.tsx b/web/app/components/workflow/header/index.tsx index e5391afb09..7713753478 100644 --- a/web/app/components/workflow/header/index.tsx +++ b/web/app/components/workflow/header/index.tsx @@ -1,3 +1,4 @@ +import { usePathname } from 'next/navigation' import { useWorkflowMode, } from '../hooks' @@ -6,7 +7,7 @@ import HeaderInNormal from './header-in-normal' import HeaderInHistory from './header-in-view-history' import type { HeaderInRestoringProps } from './header-in-restoring' import HeaderInRestoring from './header-in-restoring' - +import { useStore } from '../store' export type HeaderProps = { normal?: HeaderInNormalProps restoring?: HeaderInRestoringProps @@ -15,16 +16,20 @@ const Header = ({ normal: normalProps, restoring: restoringProps, }: HeaderProps) => { + const pathname = usePathname() + const inWorkflowCanvas = pathname.endsWith('/workflow') const { normal, restoring, viewHistory, } = useWorkflowMode() + const maximizeCanvas = useStore(s => s.maximizeCanvas) return (
+ {inWorkflowCanvas && maximizeCanvas &&
} { normal && ( { const { t } = useTranslation() @@ -27,6 +29,16 @@ const RunMode = memo(() => { const workflowRunningData = useStore(s => s.workflowRunningData) const isRunning = workflowRunningData?.result.status === WorkflowRunningStatus.Running + const handleStop = () => { + handleStopRun(workflowRunningData?.task_id || '') + } + + const { eventEmitter } = useEventEmitterContextContext() + eventEmitter?.useSubscription((v: any) => { + if (v.type === EVENT_WORKFLOW_STOP) + handleStop() + }) + return ( <>
{ isRunning && (
handleStopRun(workflowRunningData?.task_id || '')} + onClick={handleStop} >
diff --git a/web/app/components/workflow/hooks/use-inspect-vars-crud.ts b/web/app/components/workflow/hooks/use-inspect-vars-crud.ts new file mode 100644 index 0000000000..59cc98a17b --- /dev/null +++ b/web/app/components/workflow/hooks/use-inspect-vars-crud.ts @@ -0,0 +1,241 @@ +import { fetchNodeInspectVars } from '@/service/workflow' +import { useStore, useWorkflowStore } from '../store' +import type { ValueSelector } from '../types' +import type { VarInInspect } from '@/types/workflow' +import { VarInInspectType } from '@/types/workflow' +import { + useConversationVarValues, + useDeleteAllInspectorVars, + useDeleteInspectVar, + useDeleteNodeInspectorVars, + useEditInspectorVar, + useInvalidateConversationVarValues, + useInvalidateSysVarValues, + useLastRun, + useResetConversationVar, + useResetToLastRunValue, + useSysVarValues, +} from '@/service/use-workflow' +import { useCallback, useEffect, useState } from 'react' +import { isConversationVar, isENV, isSystemVar } from '../nodes/_base/components/variable/utils' +import produce from 'immer' +import type { Node } from '@/app/components/workflow/types' +import { useNodesInteractionsWithoutSync } from './use-nodes-interactions-without-sync' +import { useEdgesInteractionsWithoutSync } from './use-edges-interactions-without-sync' + +const useInspectVarsCrud = () => { + const workflowStore = useWorkflowStore() + const nodesWithInspectVars = useStore(s => s.nodesWithInspectVars) + const { + appId, + setNodeInspectVars, + setInspectVarValue, + renameInspectVarName: renameInspectVarNameInStore, + deleteAllInspectVars: deleteAllInspectVarsInStore, + deleteNodeInspectVars: deleteNodeInspectVarsInStore, + deleteInspectVar: deleteInspectVarInStore, + setNodesWithInspectVars, + resetToLastRunVar: resetToLastRunVarInStore, + } = workflowStore.getState() + + const { data: conversationVars } = useConversationVarValues(appId) + const invalidateConversationVarValues = useInvalidateConversationVarValues(appId) + const { mutateAsync: doResetConversationVar } = useResetConversationVar(appId) + const { mutateAsync: doResetToLastRunValue } = useResetToLastRunValue(appId) + const { data: systemVars } = useSysVarValues(appId) + const invalidateSysVarValues = useInvalidateSysVarValues(appId) + + const { mutateAsync: doDeleteAllInspectorVars } = useDeleteAllInspectorVars(appId) + const { mutate: doDeleteNodeInspectorVars } = useDeleteNodeInspectorVars(appId) + const { mutate: doDeleteInspectVar } = useDeleteInspectVar(appId) + + const { mutateAsync: doEditInspectorVar } = useEditInspectorVar(appId) + const { handleCancelNodeSuccessStatus } = useNodesInteractionsWithoutSync() + const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync() + const getNodeInspectVars = useCallback((nodeId: string) => { + const node = nodesWithInspectVars.find(node => node.nodeId === nodeId) + return node + }, [nodesWithInspectVars]) + + const getVarId = useCallback((nodeId: string, varName: string) => { + const node = getNodeInspectVars(nodeId) + if (!node) + return undefined + const varId = node.vars.find((varItem) => { + return varItem.selector[1] === varName + })?.id + return varId + }, [getNodeInspectVars]) + + const getInspectVar = useCallback((nodeId: string, name: string): VarInInspect | undefined => { + const node = getNodeInspectVars(nodeId) + if (!node) + return undefined + + const variable = node.vars.find((varItem) => { + return varItem.name === name + }) + return variable + }, [getNodeInspectVars]) + + const hasSetInspectVar = useCallback((nodeId: string, 
name: string, sysVars: VarInInspect[], conversationVars: VarInInspect[]) => { + const isEnv = isENV([nodeId]) + if (isEnv) // always have value + return true + const isSys = isSystemVar([nodeId]) + if (isSys) + return sysVars.some(varItem => varItem.selector?.[1]?.replace('sys.', '') === name) + const isChatVar = isConversationVar([nodeId]) + if (isChatVar) + return conversationVars.some(varItem => varItem.selector?.[1] === name) + return getInspectVar(nodeId, name) !== undefined + }, [getInspectVar]) + + const hasNodeInspectVars = useCallback((nodeId: string) => { + return !!getNodeInspectVars(nodeId) + }, [getNodeInspectVars]) + + const fetchInspectVarValue = async (selector: ValueSelector) => { + const nodeId = selector[0] + const isSystemVar = nodeId === 'sys' + const isConversationVar = nodeId === 'conversation' + if (isSystemVar) { + invalidateSysVarValues() + return + } + if (isConversationVar) { + invalidateConversationVarValues() + return + } + const vars = await fetchNodeInspectVars(appId, nodeId) + setNodeInspectVars(nodeId, vars) + } + + // after last run would call this + const appendNodeInspectVars = (nodeId: string, payload: VarInInspect[], allNodes: Node[]) => { + const nodes = produce(nodesWithInspectVars, (draft) => { + const nodeInfo = allNodes.find(node => node.id === nodeId) + if (nodeInfo) { + const index = draft.findIndex(node => node.nodeId === nodeId) + if (index === -1) { + draft.push({ + nodeId, + nodeType: nodeInfo.data.type, + title: nodeInfo.data.title, + vars: payload, + }) + } + else { + draft[index].vars = payload + } + } + }) + setNodesWithInspectVars(nodes) + handleCancelNodeSuccessStatus(nodeId) + } + + const hasNodeInspectVar = (nodeId: string, varId: string) => { + const targetNode = nodesWithInspectVars.find(item => item.nodeId === nodeId) + if(!targetNode || !targetNode.vars) + return false + return targetNode.vars.some(item => item.id === varId) + } + + const deleteInspectVar = async (nodeId: string, varId: string) => { + if(hasNodeInspectVar(nodeId, varId)) { + await doDeleteInspectVar(varId) + deleteInspectVarInStore(nodeId, varId) + } + } + + const resetConversationVar = async (varId: string) => { + await doResetConversationVar(varId) + invalidateConversationVarValues() + } + + const deleteNodeInspectorVars = async (nodeId: string) => { + if (hasNodeInspectVars(nodeId)) { + await doDeleteNodeInspectorVars(nodeId) + deleteNodeInspectVarsInStore(nodeId) + } + } + + const deleteAllInspectorVars = async () => { + await doDeleteAllInspectorVars() + await invalidateConversationVarValues() + await invalidateSysVarValues() + deleteAllInspectVarsInStore() + handleEdgeCancelRunningStatus() + } + + const editInspectVarValue = useCallback(async (nodeId: string, varId: string, value: any) => { + await doEditInspectorVar({ + varId, + value, + }) + setInspectVarValue(nodeId, varId, value) + if (nodeId === VarInInspectType.conversation) + invalidateConversationVarValues() + if (nodeId === VarInInspectType.system) + invalidateSysVarValues() + }, [doEditInspectorVar, invalidateConversationVarValues, invalidateSysVarValues, setInspectVarValue]) + + const [currNodeId, setCurrNodeId] = useState(null) + const [currEditVarId, setCurrEditVarId] = useState(null) + const { data } = useLastRun(appId, currNodeId || '', !!currNodeId) + useEffect(() => { + if (data && currNodeId && currEditVarId) { + const inspectVar = getNodeInspectVars(currNodeId)?.vars?.find(item => item.id === currEditVarId) + resetToLastRunVarInStore(currNodeId, currEditVarId, 
data.outputs?.[inspectVar?.selector?.[1] || '']) + } + }, [data, currNodeId, currEditVarId, getNodeInspectVars, editInspectVarValue, resetToLastRunVarInStore]) + + const renameInspectVarName = async (nodeId: string, oldName: string, newName: string) => { + const varId = getVarId(nodeId, oldName) + if (!varId) + return + + const newSelector = [nodeId, newName] + await doEditInspectorVar({ + varId, + name: newName, + }) + renameInspectVarNameInStore(nodeId, varId, newSelector) + } + + const isInspectVarEdited = useCallback((nodeId: string, name: string) => { + const inspectVar = getInspectVar(nodeId, name) + if (!inspectVar) + return false + + return inspectVar.edited + }, [getInspectVar]) + + const resetToLastRunVar = async (nodeId: string, varId: string) => { + await doResetToLastRunValue(varId) + setCurrNodeId(nodeId) + setCurrEditVarId(varId) + } + + return { + conversationVars: conversationVars || [], + systemVars: systemVars || [], + nodesWithInspectVars, + hasNodeInspectVars, + hasSetInspectVar, + fetchInspectVarValue, + editInspectVarValue, + renameInspectVarName, + appendNodeInspectVars, + deleteInspectVar, + deleteNodeInspectorVars, + deleteAllInspectorVars, + isInspectVarEdited, + resetToLastRunVar, + invalidateSysVarValues, + resetConversationVar, + invalidateConversationVarValues, + } +} + +export default useInspectVarsCrud diff --git a/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts b/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts index 7fbf0ce868..e01609cdb6 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts @@ -1,6 +1,7 @@ import { useCallback } from 'react' import produce from 'immer' import { useStoreApi } from 'reactflow' +import { NodeRunningStatus } from '../types' export const useNodesInteractionsWithoutSync = () => { const store = useStoreApi() @@ -21,7 +22,41 @@ export const useNodesInteractionsWithoutSync = () => { setNodes(newNodes) }, [store]) + const handleCancelAllNodeSuccessStatus = useCallback(() => { + const { + getNodes, + setNodes, + } = store.getState() + + const nodes = getNodes() + const newNodes = produce(nodes, (draft) => { + draft.forEach((node) => { + if(node.data._runningStatus === NodeRunningStatus.Succeeded) + node.data._runningStatus = undefined + }) + }) + setNodes(newNodes) + }, [store]) + + const handleCancelNodeSuccessStatus = useCallback((nodeId: string) => { + const { + getNodes, + setNodes, + } = store.getState() + + const newNodes = produce(getNodes(), (draft) => { + const node = draft.find(n => n.id === nodeId) + if (node && node.data._runningStatus === NodeRunningStatus.Succeeded) { + node.data._runningStatus = undefined + node.data._waitingRun = false + } + }) + setNodes(newNodes) + }, [store]) + return { handleNodeCancelRunningStatus, + handleCancelAllNodeSuccessStatus, + handleCancelNodeSuccessStatus, } } diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 94b10c9929..b598951adb 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -60,6 +60,7 @@ import { useWorkflowReadOnly, } from './use-workflow' import { WorkflowHistoryEvent, useWorkflowHistory } from './use-workflow-history' +import useInspectVarsCrud from './use-inspect-vars-crud' export const useNodesInteractions = () => { const { t } = 
useTranslation() @@ -288,7 +289,9 @@ export const useNodesInteractions = () => { setEdges(newEdges) }, [store, workflowStore, getNodesReadOnly]) - const handleNodeSelect = useCallback((nodeId: string, cancelSelection?: boolean) => { + const handleNodeSelect = useCallback((nodeId: string, cancelSelection?: boolean, initShowLastRunTab?: boolean) => { + if(initShowLastRunTab) + workflowStore.setState({ initShowLastRunTab: true }) const { getNodes, setNodes, @@ -530,6 +533,8 @@ export const useNodesInteractions = () => { setEnteringNodePayload(undefined) }, [store, handleNodeConnect, getNodesReadOnly, workflowStore, reactflow]) + const { deleteNodeInspectorVars } = useInspectVarsCrud() + const handleNodeDelete = useCallback((nodeId: string) => { if (getNodesReadOnly()) return @@ -551,6 +556,7 @@ export const useNodesInteractions = () => { if (currentNode.data.type === BlockEnum.Start) return + deleteNodeInspectorVars(nodeId) if (currentNode.data.type === BlockEnum.Iteration) { const iterationChildren = nodes.filter(node => node.parentId === currentNode.id) @@ -655,7 +661,7 @@ export const useNodesInteractions = () => { else saveStateToHistory(WorkflowHistoryEvent.NodeDelete) - }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, t]) + }, [getNodesReadOnly, store, deleteNodeInspectorVars, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, t]) const handleNodeAdd = useCallback(( { diff --git a/web/app/components/workflow/hooks/use-shortcuts.ts b/web/app/components/workflow/hooks/use-shortcuts.ts index 8b1003e89c..118ec94058 100644 --- a/web/app/components/workflow/hooks/use-shortcuts.ts +++ b/web/app/components/workflow/hooks/use-shortcuts.ts @@ -11,6 +11,7 @@ import { useEdgesInteractions, useNodesInteractions, useNodesSyncDraft, + useWorkflowCanvasMaximize, useWorkflowMoveMode, useWorkflowOrganize, useWorkflowStartRun, @@ -35,6 +36,7 @@ export const useShortcuts = (): void => { handleModePointer, } = useWorkflowMoveMode() const { handleLayout } = useWorkflowOrganize() + const { handleToggleMaximizeCanvas } = useWorkflowCanvasMaximize() const { zoomTo, @@ -145,6 +147,16 @@ export const useShortcuts = (): void => { } }, { exactMatch: true, useCapture: true }) + useKeyPress('f', (e) => { + if (shouldHandleShortcut(e)) { + e.preventDefault() + handleToggleMaximizeCanvas() + } + }, { + exactMatch: true, + useCapture: true, + }) + useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.1`, (e) => { if (shouldHandleShortcut(e)) { e.preventDefault() diff --git a/web/app/components/workflow/hooks/use-workflow-interactions.ts b/web/app/components/workflow/hooks/use-workflow-interactions.ts index 636d3b94f9..d8653a5942 100644 --- a/web/app/components/workflow/hooks/use-workflow-interactions.ts +++ b/web/app/components/workflow/hooks/use-workflow-interactions.ts @@ -401,3 +401,29 @@ export const useDSL = () => { handleExportDSL, } } + +export const useWorkflowCanvasMaximize = () => { + const { eventEmitter } = useEventEmitterContextContext() + + const maximizeCanvas = useStore(s => s.maximizeCanvas) + const setMaximizeCanvas = useStore(s => s.setMaximizeCanvas) + const { + getNodesReadOnly, + } = useNodesReadOnly() + + const handleToggleMaximizeCanvas = useCallback(() => { + if (getNodesReadOnly()) + return + + setMaximizeCanvas(!maximizeCanvas) + localStorage.setItem('workflow-canvas-maximize', String(!maximizeCanvas)) + eventEmitter?.emit({ + type: 'workflow-canvas-maximize', + payload: !maximizeCanvas, + } as any) + }, [eventEmitter, getNodesReadOnly, 
maximizeCanvas, setMaximizeCanvas]) + + return { + handleToggleMaximizeCanvas, + } +} diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 99dce4dc15..1b98178152 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -59,10 +59,6 @@ export const useWorkflow = () => { const store = useStoreApi() const workflowStore = useWorkflowStore() const nodesExtraData = useNodesExtraData() - const setPanelWidth = useCallback((width: number) => { - localStorage.setItem('workflow-node-panel-width', `${width}`) - workflowStore.setState({ panelWidth: width }) - }, [workflowStore]) const getTreeLeafNodes = useCallback((nodeId: string) => { const { @@ -399,7 +395,6 @@ export const useWorkflow = () => { }, [store]) return { - setPanelWidth, getTreeLeafNodes, getBeforeNodesInSameBranch, getBeforeNodesInSameBranchIncludeParent, @@ -497,6 +492,8 @@ export const useToolIcon = (data: Node['data']) => { const customTools = useStore(s => s.customTools) const workflowTools = useStore(s => s.workflowTools) const toolIcon = useMemo(() => { + if(!data) + return '' if (data.type === BlockEnum.Tool) { let targetTools = buildInTools if (data.provider_type === CollectionType.builtIn) diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 549117faf7..429d07853d 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -5,6 +5,7 @@ import { memo, useCallback, useEffect, + useMemo, useRef, } from 'react' import { setAutoFreeze } from 'immer' @@ -56,6 +57,7 @@ import { CUSTOM_LOOP_START_NODE } from './nodes/loop-start/constants' import CustomSimpleNode from './simple-node' import { CUSTOM_SIMPLE_NODE } from './simple-node/constants' import Operator from './operator' +import Control from './operator/control' import CustomEdge from './custom-edge' import CustomConnectionLine from './custom-connection-line' import HelpLine from './help-line' @@ -80,6 +82,7 @@ import Confirm from '@/app/components/base/confirm' import DatasetsDetailProvider from './datasets-detail-store/provider' import { HooksStoreContextProvider } from './hooks-store' import type { Shape as HooksStoreShape } from './hooks-store' +import useSetWorkflowVarsWithValue from '../workflow-app/hooks/use-fetch-workflow-inspect-vars' const nodeTypes = { [CUSTOM_NODE]: CustomNode, @@ -114,6 +117,32 @@ export const Workflow: FC = memo(({ const controlMode = useStore(s => s.controlMode) const nodeAnimation = useStore(s => s.nodeAnimation) const showConfirm = useStore(s => s.showConfirm) + const workflowCanvasHeight = useStore(s => s.workflowCanvasHeight) + const bottomPanelHeight = useStore(s => s.bottomPanelHeight) + const setWorkflowCanvasWidth = useStore(s => s.setWorkflowCanvasWidth) + const setWorkflowCanvasHeight = useStore(s => s.setWorkflowCanvasHeight) + const controlHeight = useMemo(() => { + if (!workflowCanvasHeight) + return '100%' + return workflowCanvasHeight - bottomPanelHeight + }, [workflowCanvasHeight, bottomPanelHeight]) + + // update workflow Canvas width and height + useEffect(() => { + if (workflowContainerRef.current) { + const resizeContainerObserver = new ResizeObserver((entries) => { + for (const entry of entries) { + const { inlineSize, blockSize } = entry.borderBoxSize[0] + setWorkflowCanvasWidth(inlineSize) + setWorkflowCanvasHeight(blockSize) + } + }) + resizeContainerObserver.observe(workflowContainerRef.current) + return () => { + 
resizeContainerObserver.disconnect() + } + } + }, [setWorkflowCanvasHeight, setWorkflowCanvasWidth]) const { setShowConfirm, @@ -245,6 +274,11 @@ export const Workflow: FC = memo(({ }) useShortcuts() + const { fetchInspectVars } = useSetWorkflowVarsWithValue() + useEffect(() => { + fetchInspectVars() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) const store = useStoreApi() if (process.env.NODE_ENV === 'development') { @@ -267,6 +301,12 @@ export const Workflow: FC = memo(({ > +
+ +
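A note on the canvas-size tracking added to web/app/components/workflow/index.tsx above: the effect attaches a ResizeObserver to the canvas container, mirrors its border-box size into the workflow store, and disconnects on unmount. Below is a minimal standalone sketch of that pattern; the hook name, ref, and onResize callback are illustrative stand-ins, not the actual store API.

```ts
import { useEffect, useRef } from 'react'

// Sketch only: observe an element's border-box size and report it upward.
function useBorderBoxSize(onResize: (width: number, height: number) => void) {
  const containerRef = useRef<HTMLDivElement | null>(null)

  useEffect(() => {
    const el = containerRef.current
    if (!el)
      return

    const observer = new ResizeObserver((entries) => {
      for (const entry of entries) {
        // borderBoxSize is an array of ResizeObserverSize entries in modern browsers
        const { inlineSize, blockSize } = entry.borderBoxSize[0]
        onResize(inlineSize, blockSize)
      }
    })
    observer.observe(el)
    return () => observer.disconnect()
  }, [onResize])

  return containerRef
}
```

Debouncing the callback (as the node-panel resize handling does elsewhere in this change) keeps store writes cheap while the user drags.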
diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx index 9f415adda9..269f5e0a96 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/form-item.tsx @@ -95,6 +95,7 @@ const FormItem: FC = ({ const isArrayLikeType = [InputVarType.contexts, InputVarType.iterator].includes(type) const isContext = type === InputVarType.contexts const isIterator = type === InputVarType.iterator + const isIteratorItemFile = isIterator && payload.isFileItem const singleFileValue = useMemo(() => { if (payload.variable === '#files#') return value?.[0] || [] @@ -202,12 +203,12 @@ const FormItem: FC = ({ }} /> )} - {(type === InputVarType.multiFiles) && ( + {(type === InputVarType.multiFiles || isIteratorItemFile) && ( onChange(files)} fileConfig={{ - allowed_file_types: inStepRun + allowed_file_types: (inStepRun || isIteratorItemFile) ? [ SupportUploadFileTypes.image, SupportUploadFileTypes.document, @@ -215,7 +216,7 @@ const FormItem: FC = ({ SupportUploadFileTypes.video, ] : payload.allowed_file_types, - allowed_file_extensions: inStepRun + allowed_file_extensions: (inStepRun || isIteratorItemFile) ? [ ...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.document], @@ -223,8 +224,8 @@ const FormItem: FC = ({ ...FILE_EXTS[SupportUploadFileTypes.video], ] : payload.allowed_file_extensions, - allowed_file_upload_methods: inStepRun ? [TransferMethod.local_file, TransferMethod.remote_url] : payload.allowed_file_upload_methods, - number_limits: inStepRun ? 5 : payload.max_length, + allowed_file_upload_methods: (inStepRun || isIteratorItemFile) ? [TransferMethod.local_file, TransferMethod.remote_url] : payload.allowed_file_upload_methods, + number_limits: (inStepRun || isIteratorItemFile) ? 5 : payload.max_length, fileUploadConfig: fileSettings?.fileUploadConfig, }} /> @@ -272,7 +273,7 @@ const FormItem: FC = ({ } { - isIterator && ( + (isIterator && !isIteratorItemFile) && (
{(value || []).map((item: any, index: number) => ( = ({ } }, [valuesRef, onChange, mapKeysWithSameValueSelector]) const isArrayLikeType = [InputVarType.contexts, InputVarType.iterator].includes(inputs[0]?.type) + const isIteratorItemFile = inputs[0]?.type === InputVarType.iterator && inputs[0]?.isFileItem + const isContext = inputs[0]?.type === InputVarType.contexts const handleAddContext = useCallback(() => { const newValues = produce(values, (draft: any) => { const key = inputs[0].variable + if (!draft[key]) + draft[key] = [] draft[key].push(isContext ? RETRIEVAL_OUTPUT_STRUCT : '') }) onChange(newValues) @@ -75,7 +79,7 @@ const Form: FC = ({ {label && (
{label}
- {isArrayLikeType && ( + {isArrayLikeType && !isIteratorItemFile && ( )}
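One detail worth calling out in the form.tsx hunk above: handleAddContext now initializes the draft array before pushing, because calling push on an undefined key inside an immer producer throws. A reduced sketch of that guard follows; the FormValues shape and appendItem name are placeholders, not the component's real API.

```ts
import produce from 'immer'

type FormValues = Record<string, any>

// Sketch of the "initialize before push" guard used in handleAddContext.
function appendItem(values: FormValues, key: string, item: unknown): FormValues {
  return produce(values, (draft) => {
    if (!draft[key])
      draft[key] = [] // the first "add item" click on an uninitialized key would otherwise crash
    draft[key].push(item)
  })
}
```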
diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx index ad8d0b9c61..11bd5156ef 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx @@ -1,30 +1,23 @@ 'use client' import type { FC } from 'react' -import React, { useCallback } from 'react' +import React, { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { - RiCloseLine, - RiLoader2Line, -} from '@remixicon/react' import type { Props as FormProps } from './form' import Form from './form' import cn from '@/utils/classnames' import Button from '@/app/components/base/button' -import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' import Split from '@/app/components/workflow/nodes/_base/components/split' -import { InputVarType, NodeRunningStatus } from '@/app/components/workflow/types' -import ResultPanel from '@/app/components/workflow/run/result-panel' +import { InputVarType } from '@/app/components/workflow/types' import Toast from '@/app/components/base/toast' import { TransferMethod } from '@/types/app' import { getProcessedFiles } from '@/app/components/base/file-uploader/utils' -import type { BlockEnum } from '@/app/components/workflow/types' +import type { BlockEnum, NodeRunningStatus } from '@/app/components/workflow/types' import type { Emoji } from '@/app/components/tools/types' import type { SpecialResultPanelProps } from '@/app/components/workflow/run/special-result-panel' -import SpecialResultPanel from '@/app/components/workflow/run/special-result-panel' - +import PanelWrap from './panel-wrap' const i18nPrefix = 'workflow.singleRun' -type BeforeRunFormProps = { +export type BeforeRunFormProps = { nodeName: string nodeType?: BlockEnum toolIcon?: string | Emoji @@ -32,12 +25,15 @@ type BeforeRunFormProps = { onRun: (submitData: Record) => void onStop: () => void runningStatus: NodeRunningStatus - result?: React.JSX.Element forms: FormProps[] showSpecialResultPanel?: boolean + existVarValuesInForms: Record[] + filteredExistVarForms: FormProps[] } & Partial function formatValue(value: string | any, type: InputVarType) { + if(value === undefined || value === null) + return value if (type === InputVarType.number) return Number.parseFloat(value) if (type === InputVarType.json) @@ -53,6 +49,8 @@ function formatValue(value: string | any, type: InputVarType) { if (type === InputVarType.singleFile) { if (Array.isArray(value)) return getProcessedFiles(value) + if (!value) + return undefined return getProcessedFiles([value])[0] } @@ -60,22 +58,17 @@ function formatValue(value: string | any, type: InputVarType) { } const BeforeRunForm: FC = ({ nodeName, - nodeType, - toolIcon, onHide, onRun, - onStop, - runningStatus, - result, forms, - showSpecialResultPanel, - ...restResultPanelParams + filteredExistVarForms, + existVarValuesInForms, }) => { const { t } = useTranslation() - const isFinished = runningStatus === NodeRunningStatus.Succeeded || runningStatus === NodeRunningStatus.Failed || runningStatus === NodeRunningStatus.Exception - const isRunning = runningStatus === NodeRunningStatus.Running const isFileLoaded = (() => { + if (!forms || forms.length === 0) + return true // system files const filesForm = forms.find(item => !!item.values['#files#']) if (!filesForm) @@ -87,12 +80,14 @@ const BeforeRunForm: FC = ({ return true })() - 
const handleRun = useCallback(() => { + const handleRun = () => { let errMsg = '' - forms.forEach((form) => { + forms.forEach((form, i) => { + const existVarValuesInForm = existVarValuesInForms[i] + form.inputs.forEach((input) => { const value = form.values[input.variable] as any - if (!errMsg && input.required && (value === '' || value === undefined || value === null || (input.type === InputVarType.files && value.length === 0))) + if (!errMsg && input.required && !(input.variable in existVarValuesInForm) && (value === '' || value === undefined || value === null || (input.type === InputVarType.files && value.length === 0))) errMsg = t('workflow.errorMsg.fieldRequired', { field: typeof input.label === 'object' ? input.label.variable : input.label }) if (!errMsg && (input.type === InputVarType.singleFile || input.type === InputVarType.multiFiles) && value) { @@ -137,69 +132,45 @@ const BeforeRunForm: FC = ({ } onRun(submitData) - }, [forms, onRun, t]) + } + const hasRun = useRef(false) + useEffect(() => { + // React 18 run twice in dev mode + if(hasRun.current) + return + hasRun.current = true + if(filteredExistVarForms.length === 0) + onRun({}) + }, [filteredExistVarForms, onRun]) + + if(filteredExistVarForms.length === 0) + return null + return ( -
-
-
-
- {t(`${i18nPrefix}.testRun`)} {nodeName} -
-
{ - onHide() - }}> - -
+ +
+
+ {filteredExistVarForms.map((form, index) => ( +
+
+ {index < forms.length - 1 && } +
+ ))} +
+
+
- { - showSpecialResultPanel && ( -
- -
- ) - } - { - !showSpecialResultPanel && ( -
-
- {forms.map((form, index) => ( -
- - {index < forms.length - 1 && } -
- ))} -
-
- {isRunning && ( -
- -
- )} - -
- {isRunning && ( - - )} - {isFinished && ( - <> - {result} - - )} -
- ) - }
-
+ ) } export default React.memo(BeforeRunForm) diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/panel-wrap.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/panel-wrap.tsx new file mode 100644 index 0000000000..7312adf6c6 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/panel-wrap.tsx @@ -0,0 +1,41 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCloseLine, +} from '@remixicon/react' + +const i18nPrefix = 'workflow.singleRun' + +export type Props = { + nodeName: string + onHide: () => void + children: React.ReactNode +} + +const PanelWrap: FC = ({ + nodeName, + onHide, + children, +}) => { + const { t } = useTranslation() + return ( +
+
+
+
+ {t(`${i18nPrefix}.testRun`)} {nodeName} +
+
{ + onHide() + }}> + +
+
+ {children} +
+
+ ) +} +export default React.memo(PanelWrap) diff --git a/web/app/components/workflow/nodes/_base/components/memory-config.tsx b/web/app/components/workflow/nodes/_base/components/memory-config.tsx index 446fcfa8ae..168ad656b6 100644 --- a/web/app/components/workflow/nodes/_base/components/memory-config.tsx +++ b/web/app/components/workflow/nodes/_base/components/memory-config.tsx @@ -53,7 +53,7 @@ type Props = { const MEMORY_DEFAULT: Memory = { window: { enabled: false, size: WINDOW_SIZE_DEFAULT }, - query_prompt_template: '{{#sys.query#}}', + query_prompt_template: '{{#sys.query#}}\n\n{{#sys.files#}}', } const MemoryConfig: FC = ({ diff --git a/web/app/components/workflow/nodes/_base/components/node-control.tsx b/web/app/components/workflow/nodes/_base/components/node-control.tsx index a85c41741b..5b92b7b6b4 100644 --- a/web/app/components/workflow/nodes/_base/components/node-control.tsx +++ b/web/app/components/workflow/nodes/_base/components/node-control.tsx @@ -13,7 +13,7 @@ import { useNodesInteractions, useNodesSyncDraft, } from '../../../hooks' -import type { Node } from '../../../types' +import { type Node, NodeRunningStatus } from '../../../types' import { canRunBySingle } from '../../../utils' import PanelOperator from './panel-operator' import { @@ -31,11 +31,12 @@ const NodeControl: FC = ({ const { handleNodeDataUpdate } = useNodeDataUpdate() const { handleNodeSelect } = useNodesInteractions() const { handleSyncWorkflowDraft } = useNodesSyncDraft() - + const isSingleRunning = data._singleRunningStatus === NodeRunningStatus.Running const handleOpenChange = useCallback((newOpen: boolean) => { setOpen(newOpen) }, []) + const isChildNode = !!(data.isInIteration || data.isInLoop) return (
= ({ onClick={e => e.stopPropagation()} > { - canRunBySingle(data.type) && ( + canRunBySingle(data.type, isChildNode) && (
{ + const nextData: Record = { + _isSingleRun: !isSingleRunning, + } + if(isSingleRunning) + nextData._singleRunningStatus = undefined + handleNodeDataUpdate({ id, - data: { - _isSingleRun: !data._isSingleRun, - }, + data: nextData, }) handleNodeSelect(id) - if (!data._isSingleRun) - handleSyncWorkflowDraft(true) }} > { - data._isSingleRun + isSingleRunning ? : ( { - (showChangeBlock || canRunBySingle(data.type)) && ( + (showChangeBlock || canRunBySingle(data.type, isChildNode)) && ( <>
{ - canRunBySingle(data.type) && ( + canRunBySingle(data.type, isChildNode) && (
void + onOpenChange?: (open: boolean) => void + isLoading?: boolean } const DEFAULT_SCHEMA = {} as CredentialFormSchema @@ -22,6 +24,8 @@ const ConstantField: FC = ({ readonly, value, onChange, + onOpenChange, + isLoading, }) => { const language = useLanguage() const placeholder = (schema as CredentialFormSchemaSelect).placeholder @@ -36,7 +40,7 @@ const ConstantField: FC = ({ return ( <> - {schema.type === FormTypeEnum.select && ( + {(schema.type === FormTypeEnum.select || schema.type === FormTypeEnum.dynamicSelect) && ( = ({ items={(schema as CredentialFormSchemaSelect).options.map(option => ({ value: option.value, name: option.label[language] || option.label.en_US }))} onSelect={item => handleSelectChange(item.value)} placeholder={placeholder?.[language] || placeholder?.en_US} + onOpenChange={onOpenChange} + isLoading={isLoading} /> )} {schema.type === FormTypeEnum.textNumber && ( diff --git a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts index 428c204dd3..a69f9a51a7 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts @@ -1090,13 +1090,13 @@ export const getNodeUsedVarPassToServerKey = (node: Node, valueSelector: ValueSe break } case BlockEnum.Code: { - const targetVar = (data as CodeNodeType).variables?.find(v => v.value_selector.join('.') === valueSelector.join('.')) + const targetVar = (data as CodeNodeType).variables?.find(v => Array.isArray(v.value_selector) && v.value_selector && v.value_selector.join('.') === valueSelector.join('.')) if (targetVar) res = targetVar.variable break } case BlockEnum.TemplateTransform: { - const targetVar = (data as TemplateTransformNodeType).variables?.find(v => v.value_selector.join('.') === valueSelector.join('.')) + const targetVar = (data as TemplateTransformNodeType).variables?.find(v => Array.isArray(v.value_selector) && v.value_selector && v.value_selector.join('.') === valueSelector.join('.')) if (targetVar) res = targetVar.variable break diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx index e9825cd44a..2c89e722cd 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx @@ -6,6 +6,7 @@ import { RiArrowDownSLine, RiCloseLine, RiErrorWarningFill, + RiLoader4Line, RiMoreLine, } from '@remixicon/react' import produce from 'immer' @@ -16,8 +17,9 @@ import VarReferencePopup from './var-reference-popup' import { getNodeInfoById, isConversationVar, isENV, isSystemVar, varTypeToStructType } from './utils' import ConstantField from './constant-field' import cn from '@/utils/classnames' -import type { Node, NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types' -import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { Node, NodeOutPutVar, ToolWithProvider, ValueSelector, Var } from '@/app/components/workflow/types' +import type { CredentialFormSchemaSelect } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { type CredentialFormSchema, type FormOption, FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { BlockEnum 
} from '@/app/components/workflow/types' import { VarBlockIcon } from '@/app/components/workflow/block-icon' import { Line3 } from '@/app/components/base/icons/src/public/common' @@ -40,6 +42,8 @@ import Tooltip from '@/app/components/base/tooltip' import { isExceptionVariable } from '@/app/components/workflow/utils' import VarFullPathPanel from './var-full-path-panel' import { noop } from 'lodash-es' +import { useFetchDynamicOptions } from '@/service/use-plugins' +import type { Tool } from '@/app/components/tools/types' const TRIGGER_DEFAULT_WIDTH = 227 @@ -68,6 +72,8 @@ type Props = { minWidth?: number popupFor?: 'assigned' | 'toAssigned' zIndex?: number + currentTool?: Tool + currentProvider?: ToolWithProvider } const DEFAULT_VALUE_SELECTOR: Props['value'] = [] @@ -97,6 +103,8 @@ const VarReferencePicker: FC = ({ minWidth, popupFor, zIndex, + currentTool, + currentProvider, }) => { const { t } = useTranslation() const store = useStoreApi() @@ -316,6 +324,42 @@ const VarReferencePicker: FC = ({ return null }, [isValidVar, isShowAPart, hasValue, t, outputVarNode?.title, outputVarNode?.type, value, type]) + + const [dynamicOptions, setDynamicOptions] = useState(null) + const [isLoading, setIsLoading] = useState(false) + const { mutateAsync: fetchDynamicOptions } = useFetchDynamicOptions( + currentProvider?.plugin_id || '', currentProvider?.name || '', currentTool?.name || '', (schema as CredentialFormSchemaSelect)?.variable || '', + 'tool', + ) + const handleFetchDynamicOptions = async () => { + if (schema?.type !== FormTypeEnum.dynamicSelect || !currentTool || !currentProvider) + return + setIsLoading(true) + try { + const data = await fetchDynamicOptions() + setDynamicOptions(data?.options || []) + } + finally { + setIsLoading(false) + } + } + useEffect(() => { + handleFetchDynamicOptions() + }, [currentTool, currentProvider, schema]) + + const schemaWithDynamicSelect = useMemo(() => { + if (schema?.type !== FormTypeEnum.dynamicSelect) + return schema + // rewrite schema.options with dynamicOptions + if (dynamicOptions) { + return { + ...schema, + options: dynamicOptions, + } + } + return schema + }, [dynamicOptions]) + return (
= ({ void)} - schema={schema as CredentialFormSchema} + schema={schemaWithDynamicSelect as CredentialFormSchema} readonly={readonly} + isLoading={isLoading} /> ) : ( @@ -412,6 +457,7 @@ const VarReferencePicker: FC = ({ )}
{!hasValue && } + {isLoading && } {isEnv && } {isChatVar && }
= ({ {!isValidVar && } ) - :
{placeholder ?? t('workflow.common.setVarValuePlaceholder')}
} + :
+ {isLoading ? ( +
+ + {placeholder ?? t('workflow.common.setVarValuePlaceholder')} +
+ ) : ( + placeholder ?? t('workflow.common.setVarValuePlaceholder') + )} +
}
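The var-reference-picker changes above fetch options for FormTypeEnum.dynamicSelect fields at runtime, track a loading flag, and overlay the fetched options on the static schema. A condensed sketch of that flow is below; useFetchDynamicOptions is the project's service hook, so the generic fetchOptions callback and type shapes here are stand-ins, and the cancellation flag is an extra precaution rather than something the diff itself does.

```ts
import { useEffect, useMemo, useState } from 'react'

// Sketch only: fetch dynamic options, then overlay them on a static select schema.
type Option = { value: string; label: Record<string, string> }
type SelectSchema = { type: string; options: Option[] }

function useDynamicSchema(schema: SelectSchema, fetchOptions: () => Promise<Option[]>) {
  const [options, setOptions] = useState<Option[] | null>(null)
  const [isLoading, setIsLoading] = useState(false)

  useEffect(() => {
    let cancelled = false
    const load = async () => {
      setIsLoading(true)
      try {
        const res = await fetchOptions()
        if (!cancelled)
          setOptions(res)
      }
      finally {
        if (!cancelled)
          setIsLoading(false)
      }
    }
    load()
    return () => {
      cancelled = true
    }
  }, [fetchOptions])

  // Leave the static schema untouched; only overlay options once they arrive.
  const resolvedSchema = useMemo(
    () => (options ? { ...schema, options } : schema),
    [schema, options],
  )

  return { resolvedSchema, isLoading }
}
```

If the fetch fails, options stays null and the original schema is used unchanged, which mirrors the fallback behavior of schemaWithDynamicSelect in the diff.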
diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx new file mode 100644 index 0000000000..a47bb226b2 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx @@ -0,0 +1,429 @@ +import type { + FC, + ReactNode, +} from 'react' +import { + cloneElement, + memo, + useCallback, + useEffect, + useMemo, + useRef, + useState, +} from 'react' +import { + RiCloseLine, + RiPlayLargeLine, +} from '@remixicon/react' +import { useShallow } from 'zustand/react/shallow' +import { useTranslation } from 'react-i18next' +import NextStep from '../next-step' +import PanelOperator from '../panel-operator' +import NodePosition from '@/app/components/workflow/nodes/_base/components/node-position' +import HelpLink from '../help-link' +import { + DescriptionInput, + TitleInput, +} from '../title-description-input' +import ErrorHandleOnPanel from '../error-handle/error-handle-on-panel' +import RetryOnPanel from '../retry/retry-on-panel' +import { useResizePanel } from '../../hooks/use-resize-panel' +import cn from '@/utils/classnames' +import BlockIcon from '@/app/components/workflow/block-icon' +import Split from '@/app/components/workflow/nodes/_base/components/split' +import { + WorkflowHistoryEvent, + useAvailableBlocks, + useNodeDataUpdate, + useNodesInteractions, + useNodesReadOnly, + useToolIcon, + useWorkflowHistory, +} from '@/app/components/workflow/hooks' +import { + canRunBySingle, + hasErrorHandleNode, + hasRetryNode, +} from '@/app/components/workflow/utils' +import Tooltip from '@/app/components/base/tooltip' +import { BlockEnum, type Node, NodeRunningStatus } from '@/app/components/workflow/types' +import { useStore as useAppStore } from '@/app/components/app/store' +import { useStore } from '@/app/components/workflow/store' +import Tab, { TabType } from './tab' +import LastRun from './last-run' +import useLastRun from './last-run/use-last-run' +import BeforeRunForm from '../before-run-form' +import { debounce } from 'lodash-es' +import { NODES_EXTRA_DATA } from '@/app/components/workflow/constants' +import { useLogs } from '@/app/components/workflow/run/hooks' +import PanelWrap from '../before-run-form/panel-wrap' +import SpecialResultPanel from '@/app/components/workflow/run/special-result-panel' +import { Stop } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' + +type BasePanelProps = { + children: ReactNode +} & Node + +const BasePanel: FC = ({ + id, + data, + children, + position, + width, + height, +}) => { + const { t } = useTranslation() + const { showMessageLogModal } = useAppStore(useShallow(state => ({ + showMessageLogModal: state.showMessageLogModal, + }))) + const isSingleRunning = data._singleRunningStatus === NodeRunningStatus.Running + + const showSingleRunPanel = useStore(s => s.showSingleRunPanel) + const workflowCanvasWidth = useStore(s => s.workflowCanvasWidth) + const nodePanelWidth = useStore(s => s.nodePanelWidth) + const otherPanelWidth = useStore(s => s.otherPanelWidth) + const setNodePanelWidth = useStore(s => s.setNodePanelWidth) + + const maxNodePanelWidth = useMemo(() => { + if (!workflowCanvasWidth) + return 720 + if (!otherPanelWidth) + return workflowCanvasWidth - 400 + + return workflowCanvasWidth - otherPanelWidth - 400 + }, [workflowCanvasWidth, otherPanelWidth]) + + const updateNodePanelWidth = useCallback((width: number) => { + // Ensure the width is within the min and max range + const 
newValue = Math.min(Math.max(width, 400), maxNodePanelWidth) + localStorage.setItem('workflow-node-panel-width', `${newValue}`) + setNodePanelWidth(newValue) + }, [maxNodePanelWidth, setNodePanelWidth]) + + const handleResize = useCallback((width: number) => { + updateNodePanelWidth(width) + }, [updateNodePanelWidth]) + + const { + triggerRef, + containerRef, + } = useResizePanel({ + direction: 'horizontal', + triggerDirection: 'left', + minWidth: 400, + maxWidth: maxNodePanelWidth, + onResize: debounce(handleResize), + }) + + const debounceUpdate = debounce(updateNodePanelWidth) + useEffect(() => { + if (!workflowCanvasWidth) + return + if (workflowCanvasWidth - 400 <= nodePanelWidth + otherPanelWidth) + debounceUpdate(workflowCanvasWidth - 400 - otherPanelWidth) + }, [nodePanelWidth, otherPanelWidth, workflowCanvasWidth, updateNodePanelWidth]) + + const { handleNodeSelect } = useNodesInteractions() + const { nodesReadOnly } = useNodesReadOnly() + const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration, data.isInLoop) + const toolIcon = useToolIcon(data) + + const { saveStateToHistory } = useWorkflowHistory() + + const { + handleNodeDataUpdate, + handleNodeDataUpdateWithSyncDraft, + } = useNodeDataUpdate() + + const handleTitleBlur = useCallback((title: string) => { + handleNodeDataUpdateWithSyncDraft({ id, data: { title } }) + saveStateToHistory(WorkflowHistoryEvent.NodeTitleChange) + }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) + const handleDescriptionChange = useCallback((desc: string) => { + handleNodeDataUpdateWithSyncDraft({ id, data: { desc } }) + saveStateToHistory(WorkflowHistoryEvent.NodeDescriptionChange) + }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) + + const isChildNode = !!(data.isInIteration || data.isInLoop) + const isSupportSingleRun = canRunBySingle(data.type, isChildNode) + const appDetail = useAppStore(state => state.appDetail) + + const hasClickRunning = useRef(false) + const [isPaused, setIsPaused] = useState(false) + + useEffect(() => { + if(data._singleRunningStatus === NodeRunningStatus.Running) { + hasClickRunning.current = true + setIsPaused(false) + } + else if(data._isSingleRun && data._singleRunningStatus === undefined && hasClickRunning) { + setIsPaused(true) + hasClickRunning.current = false + } + }, [data]) + + const updateNodeRunningStatus = useCallback((status: NodeRunningStatus) => { + handleNodeDataUpdate({ + id, + data: { + ...data, + _singleRunningStatus: status, + }, + }) + }, [handleNodeDataUpdate, id, data]) + + useEffect(() => { + // console.log(`id changed: ${id}, hasClickRunning: ${hasClickRunning.current}`) + hasClickRunning.current = false + }, [id]) + + const { + isShowSingleRun, + hideSingleRun, + runningStatus, + handleStop, + runInputData, + runInputDataRef, + runResult, + getInputVars, + toVarInputs, + tabType, + isRunAfterSingleRun, + setTabType, + singleRunParams, + nodeInfo, + setRunInputData, + handleSingleRun, + handleRunWithParams, + getExistVarValuesInForms, + getFilteredExistVarForms, + } = useLastRun({ + id, + data, + defaultRunInputData: NODES_EXTRA_DATA[data.type]?.defaultRunInputData || {}, + isPaused, + }) + + useEffect(() => { + setIsPaused(false) + }, [tabType]) + + const logParams = useLogs() + const passedLogParams = (() => { + if ([BlockEnum.Tool, BlockEnum.Agent, BlockEnum.Iteration, BlockEnum.Loop].includes(data.type)) + return logParams + + return {} + })() + + if(logParams.showSpecialResultPanel) { + return ( +
+
+ +
+ +
+
+
+
+ ) + } + + if (isShowSingleRun) { + return ( +
+
+ +
+
+ ) + } + + return ( +
+
+
+
+
+
+
+ + +
+ { + isSupportSingleRun && !nodesReadOnly && ( + +
{ + if(isSingleRunning) { + handleNodeDataUpdate({ + id, + data: { + _isSingleRun: false, + _singleRunningStatus: undefined, + }, + }) + } + else { + handleSingleRun() + } + }} + > + { + isSingleRunning ? + : + } +
+
+ ) + } + + + +
+
handleNodeSelect(id, true)} + > + +
+
+
+
+ +
+
+ +
+ +
+ + {tabType === TabType.settings && ( + <> +
+ {cloneElement(children as any, { + id, + data, + panelProps: { + getInputVars, + toVarInputs, + runInputData, + setRunInputData, + runResult, + runInputDataRef, + }, + })} +
+ + { + hasRetryNode(data.type) && ( + + ) + } + { + hasErrorHandleNode(data.type) && ( + + ) + } + { + !!availableNextBlocks.length && ( +
+
+ {t('workflow.panel.nextStep').toLocaleUpperCase()} +
+
+ {t('workflow.panel.addNextStep')} +
+ +
+ ) + } + + )} + + {tabType === TabType.lastRun && ( + + )} +
+
+ ) +} + +export default memo(BasePanel) diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/index.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/index.tsx new file mode 100644 index 0000000000..a029987818 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/index.tsx @@ -0,0 +1,126 @@ +'use client' +import type { ResultPanelProps } from '@/app/components/workflow/run/result-panel' +import ResultPanel from '@/app/components/workflow/run/result-panel' +import { NodeRunningStatus } from '@/app/components/workflow/types' +import type { FC } from 'react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' +import NoData from './no-data' +import { useLastRun } from '@/service/use-workflow' +import { RiLoader2Line } from '@remixicon/react' +import type { NodeTracing } from '@/types/workflow' + +type Props = { + appId: string + nodeId: string + canSingleRun: boolean + isRunAfterSingleRun: boolean + updateNodeRunningStatus: (status: NodeRunningStatus) => void + nodeInfo?: NodeTracing + runningStatus?: NodeRunningStatus + onSingleRunClicked: () => void + singleRunResult?: NodeTracing + isPaused?: boolean +} & Partial + +const LastRun: FC = ({ + appId, + nodeId, + canSingleRun, + isRunAfterSingleRun, + updateNodeRunningStatus, + nodeInfo, + runningStatus: oneStepRunRunningStatus, + onSingleRunClicked, + singleRunResult, + isPaused, + ...otherResultPanelProps +}) => { + const isOneStepRunSucceed = oneStepRunRunningStatus === NodeRunningStatus.Succeeded + const isOneStepRunFailed = oneStepRunRunningStatus === NodeRunningStatus.Failed + // hide page and return to page would lost the oneStepRunRunningStatus + const [hidePageOneStepFinishedStatus, setHidePageOneStepFinishedStatus] = React.useState(null) + const [pageHasHide, setPageHasHide] = useState(false) + const [pageShowed, setPageShowed] = useState(false) + + const hidePageOneStepRunFinished = [NodeRunningStatus.Succeeded, NodeRunningStatus.Failed].includes(hidePageOneStepFinishedStatus!) + const canRunLastRun = !isRunAfterSingleRun || isOneStepRunSucceed || isOneStepRunFailed || (pageHasHide && hidePageOneStepRunFinished) + const { data: lastRunResult, isFetching, error } = useLastRun(appId, nodeId, canRunLastRun) + const isRunning = useMemo(() => { + if(isPaused) + return false + + if(!isRunAfterSingleRun) + return isFetching + return [NodeRunningStatus.Running, NodeRunningStatus.NotStart].includes(oneStepRunRunningStatus!) + }, [isFetching, isPaused, isRunAfterSingleRun, oneStepRunRunningStatus]) + + const noLastRun = (error as any)?.status === 404 + const runResult = (canRunLastRun ? lastRunResult : singleRunResult) || lastRunResult || {} + + const resetHidePageStatus = useCallback(() => { + setPageHasHide(false) + setPageShowed(false) + setHidePageOneStepFinishedStatus(null) + }, []) + useEffect(() => { + if (pageShowed && hidePageOneStepFinishedStatus && (!oneStepRunRunningStatus || oneStepRunRunningStatus === NodeRunningStatus.NotStart)) { + updateNodeRunningStatus(hidePageOneStepFinishedStatus) + resetHidePageStatus() + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isOneStepRunSucceed, isOneStepRunFailed, oneStepRunRunningStatus]) + + useEffect(() => { + if([NodeRunningStatus.Succeeded, NodeRunningStatus.Failed].includes(oneStepRunRunningStatus!)) + setHidePageOneStepFinishedStatus(oneStepRunRunningStatus!) 
+ }, [oneStepRunRunningStatus]) + + useEffect(() => { + resetHidePageStatus() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [nodeId]) + + const handlePageVisibilityChange = useCallback(() => { + if (document.visibilityState === 'hidden') + setPageHasHide(true) + else + setPageShowed(true) + }, []) + useEffect(() => { + document.addEventListener('visibilitychange', handlePageVisibilityChange) + + return () => { + document.removeEventListener('visibilitychange', handlePageVisibilityChange) + } + }, [handlePageVisibilityChange]) + + if (isFetching && !isRunAfterSingleRun) { + return ( +
+ +
) + } + + if (isRunning) + return + + if (!isPaused && (noLastRun || !runResult)) { + return ( + + ) + } + return ( +
+ +
+ ) +} +export default React.memo(LastRun) diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/no-data.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/no-data.tsx new file mode 100644 index 0000000000..ad0058efae --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/no-data.tsx @@ -0,0 +1,36 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { ClockPlay } from '@/app/components/base/icons/src/vender/line/time' +import Button from '@/app/components/base/button' +import { RiPlayLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' + +type Props = { + canSingleRun: boolean + onSingleRun: () => void +} + +const NoData: FC = ({ + canSingleRun, + onSingleRun, +}) => { + const { t } = useTranslation() + return ( +
+ +
{t('workflow.debug.noData.description')}
+ {canSingleRun && ( + + )} +
+ ) +} +export default React.memo(NoData) diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts new file mode 100644 index 0000000000..014707cdfb --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts @@ -0,0 +1,330 @@ +import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' +import type { Params as OneStepRunParams } from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' +import { useCallback, useEffect, useState } from 'react' +import { TabType } from '../tab' +import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' +import useStartSingleRunFormParams from '@/app/components/workflow/nodes/start/use-single-run-form-params' +import useLLMSingleRunFormParams from '@/app/components/workflow/nodes/llm/use-single-run-form-params' +import useKnowledgeRetrievalSingleRunFormParams from '@/app/components/workflow/nodes/knowledge-retrieval/use-single-run-form-params' +import useCodeSingleRunFormParams from '@/app/components/workflow/nodes/code/use-single-run-form-params' +import useTemplateTransformSingleRunFormParams from '@/app/components/workflow/nodes/template-transform/use-single-run-form-params' +import useQuestionClassifierSingleRunFormParams from '@/app/components/workflow/nodes/question-classifier/use-single-run-form-params' +import useParameterExtractorSingleRunFormParams from '@/app/components/workflow/nodes/parameter-extractor/use-single-run-form-params' +import useHttpRequestSingleRunFormParams from '@/app/components/workflow/nodes/http/use-single-run-form-params' +import useToolSingleRunFormParams from '@/app/components/workflow/nodes/tool/use-single-run-form-params' +import useIterationSingleRunFormParams from '@/app/components/workflow/nodes/iteration/use-single-run-form-params' +import useAgentSingleRunFormParams from '@/app/components/workflow/nodes/agent/use-single-run-form-params' +import useDocExtractorSingleRunFormParams from '@/app/components/workflow/nodes/document-extractor/use-single-run-form-params' +import useLoopSingleRunFormParams from '@/app/components/workflow/nodes/loop/use-single-run-form-params' +import useIfElseSingleRunFormParams from '@/app/components/workflow/nodes/if-else/use-single-run-form-params' +import useVariableAggregatorSingleRunFormParams from '@/app/components/workflow/nodes/variable-assigner/use-single-run-form-params' +import useVariableAssignerSingleRunFormParams from '@/app/components/workflow/nodes/assigner/use-single-run-form-params' + +import useToolGetDataForCheckMore from '@/app/components/workflow/nodes/tool/use-get-data-for-check-more' +import { VALUE_SELECTOR_DELIMITER as DELIMITER } from '@/config' + +// import +import type { CommonNodeType, ValueSelector } from '@/app/components/workflow/types' +import { BlockEnum } from '@/app/components/workflow/types' +import { + useNodesSyncDraft, +} from '@/app/components/workflow/hooks' +import useInspectVarsCrud from '@/app/components/workflow/hooks/use-inspect-vars-crud' +import { useInvalidLastRun } from '@/service/use-workflow' +import { useStore, useWorkflowStore } from '@/app/components/workflow/store' + +const singleRunFormParamsHooks: Record = { + [BlockEnum.LLM]: useLLMSingleRunFormParams, + [BlockEnum.KnowledgeRetrieval]: useKnowledgeRetrievalSingleRunFormParams, + [BlockEnum.Code]: 
useCodeSingleRunFormParams, + [BlockEnum.TemplateTransform]: useTemplateTransformSingleRunFormParams, + [BlockEnum.QuestionClassifier]: useQuestionClassifierSingleRunFormParams, + [BlockEnum.HttpRequest]: useHttpRequestSingleRunFormParams, + [BlockEnum.Tool]: useToolSingleRunFormParams, + [BlockEnum.ParameterExtractor]: useParameterExtractorSingleRunFormParams, + [BlockEnum.Iteration]: useIterationSingleRunFormParams, + [BlockEnum.Agent]: useAgentSingleRunFormParams, + [BlockEnum.DocExtractor]: useDocExtractorSingleRunFormParams, + [BlockEnum.Loop]: useLoopSingleRunFormParams, + [BlockEnum.Start]: useStartSingleRunFormParams, + [BlockEnum.IfElse]: useIfElseSingleRunFormParams, + [BlockEnum.VariableAggregator]: useVariableAggregatorSingleRunFormParams, + [BlockEnum.Assigner]: useVariableAssignerSingleRunFormParams, + [BlockEnum.VariableAssigner]: undefined, + [BlockEnum.End]: undefined, + [BlockEnum.Answer]: undefined, + [BlockEnum.ListFilter]: undefined, + [BlockEnum.IterationStart]: undefined, + [BlockEnum.LoopStart]: undefined, + [BlockEnum.LoopEnd]: undefined, +} + +const useSingleRunFormParamsHooks = (nodeType: BlockEnum) => { + return (params: any) => { + return singleRunFormParamsHooks[nodeType]?.(params) || {} + } +} + +const getDataForCheckMoreHooks: Record = { + [BlockEnum.Tool]: useToolGetDataForCheckMore, + [BlockEnum.LLM]: undefined, + [BlockEnum.KnowledgeRetrieval]: undefined, + [BlockEnum.Code]: undefined, + [BlockEnum.TemplateTransform]: undefined, + [BlockEnum.QuestionClassifier]: undefined, + [BlockEnum.HttpRequest]: undefined, + [BlockEnum.ParameterExtractor]: undefined, + [BlockEnum.Iteration]: undefined, + [BlockEnum.Agent]: undefined, + [BlockEnum.DocExtractor]: undefined, + [BlockEnum.Loop]: undefined, + [BlockEnum.Start]: undefined, + [BlockEnum.IfElse]: undefined, + [BlockEnum.VariableAggregator]: undefined, + [BlockEnum.End]: undefined, + [BlockEnum.Answer]: undefined, + [BlockEnum.VariableAssigner]: undefined, + [BlockEnum.ListFilter]: undefined, + [BlockEnum.IterationStart]: undefined, + [BlockEnum.Assigner]: undefined, + [BlockEnum.LoopStart]: undefined, + [BlockEnum.LoopEnd]: undefined, +} + +const useGetDataForCheckMoreHooks = (nodeType: BlockEnum) => { + return (id: string, payload: CommonNodeType) => { + return getDataForCheckMoreHooks[nodeType]?.({ id, payload }) || { + getData: () => { + return {} + }, + } + } +} + +type Params = Omit, 'isRunAfterSingleRun'> +const useLastRun = ({ + ...oneStepRunParams +}: Params) => { + const { conversationVars, systemVars, hasSetInspectVar } = useInspectVarsCrud() + const blockType = oneStepRunParams.data.type + const isStartNode = blockType === BlockEnum.Start + const isIterationNode = blockType === BlockEnum.Iteration + const isLoopNode = blockType === BlockEnum.Loop + const isAggregatorNode = blockType === BlockEnum.VariableAggregator + const { handleSyncWorkflowDraft } = useNodesSyncDraft() + const { + getData: getDataForCheckMore, + } = useGetDataForCheckMoreHooks(blockType)(oneStepRunParams.id, oneStepRunParams.data) + const [isRunAfterSingleRun, setIsRunAfterSingleRun] = useState(false) + + const { + id, + data, + } = oneStepRunParams + const oneStepRunRes = useOneStepRun({ + ...oneStepRunParams, + iteratorInputKey: blockType === BlockEnum.Iteration ? 
`${id}.input_selector` : '', + moreDataForCheckValid: getDataForCheckMore(), + isRunAfterSingleRun, + }) + + const { + appId, + hideSingleRun, + handleRun: doCallRunApi, + getInputVars, + toVarInputs, + varSelectorsToVarInputs, + runInputData, + runInputDataRef, + setRunInputData, + showSingleRun, + runResult, + iterationRunResult, + loopRunResult, + setNodeRunning, + checkValid, + } = oneStepRunRes + + const { + nodeInfo, + ...singleRunParams + } = useSingleRunFormParamsHooks(blockType)({ + id, + payload: data, + runInputData, + runInputDataRef, + getInputVars, + setRunInputData, + toVarInputs, + varSelectorsToVarInputs, + runResult, + iterationRunResult, + loopRunResult, + }) + + const toSubmitData = useCallback((data: Record) => { + if(!isIterationNode && !isLoopNode) + return data + + const allVarObject = singleRunParams?.allVarObject || {} + const formattedData: Record = {} + Object.keys(allVarObject).forEach((key) => { + const [varSectorStr, nodeId] = key.split(DELIMITER) + formattedData[`${nodeId}.${allVarObject[key].inSingleRunPassedKey}`] = data[varSectorStr] + }) + if(isIterationNode) { + const iteratorInputKey = `${id}.input_selector` + formattedData[iteratorInputKey] = data[iteratorInputKey] + } + return formattedData + }, [isIterationNode, isLoopNode, singleRunParams?.allVarObject, id]) + + const callRunApi = (data: Record, cb?: () => void) => { + handleSyncWorkflowDraft(true, true, { + onSuccess() { + doCallRunApi(toSubmitData(data)) + cb?.() + }, + }) + } + const workflowStore = useWorkflowStore() + const { setInitShowLastRunTab } = workflowStore.getState() + const initShowLastRunTab = useStore(s => s.initShowLastRunTab) + const [tabType, setTabType] = useState(initShowLastRunTab ? TabType.lastRun : TabType.settings) + useEffect(() => { + if(initShowLastRunTab) + setTabType(TabType.lastRun) + + setInitShowLastRunTab(false) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [initShowLastRunTab]) + const invalidLastRun = useInvalidLastRun(appId!, id) + + const handleRunWithParams = async (data: Record) => { + const { isValid } = checkValid() + if(!isValid) + return + setNodeRunning() + setIsRunAfterSingleRun(true) + setTabType(TabType.lastRun) + callRunApi(data, () => { + invalidLastRun() + }) + hideSingleRun() + } + + const handleTabClicked = useCallback((type: TabType) => { + setIsRunAfterSingleRun(false) + setTabType(type) + }, []) + + const getExistVarValuesInForms = (forms: FormProps[]) => { + if (!forms || forms.length === 0) + return [] + + const valuesArr = forms.map((form) => { + const values: Record = {} + form.inputs.forEach(({ variable, getVarValueFromDependent }) => { + const isGetValueFromDependent = getVarValueFromDependent || !variable.includes('.') + if(isGetValueFromDependent && !singleRunParams?.getDependentVar) + return + + const selector = isGetValueFromDependent ? 
(singleRunParams?.getDependentVar(variable) || []) : variable.slice(1, -1).split('.') + if(!selector || selector.length === 0) + return + const [nodeId, varName] = selector.slice(0, 2) + if(!isStartNode && nodeId === id) { // inner vars like loop vars + values[variable] = true + return + } + const inspectVarValue = hasSetInspectVar(nodeId, varName, systemVars, conversationVars) // also detect system var , env and conversation var + if (inspectVarValue) + values[variable] = true + }) + return values + }) + return valuesArr + } + + const isAllVarsHasValue = (vars?: ValueSelector[]) => { + if(!vars || vars.length === 0) + return true + return vars.every((varItem) => { + const [nodeId, varName] = varItem.slice(0, 2) + const inspectVarValue = hasSetInspectVar(nodeId, varName, systemVars, conversationVars) // also detect system var , env and conversation var + return inspectVarValue + }) + } + + const isSomeVarsHasValue = (vars?: ValueSelector[]) => { + if(!vars || vars.length === 0) + return true + return vars.some((varItem) => { + const [nodeId, varName] = varItem.slice(0, 2) + const inspectVarValue = hasSetInspectVar(nodeId, varName, systemVars, conversationVars) // also detect system var , env and conversation var + return inspectVarValue + }) + } + const getFilteredExistVarForms = (forms: FormProps[]) => { + if (!forms || forms.length === 0) + return [] + + const existVarValuesInForms = getExistVarValuesInForms(forms) + + const res = forms.map((form, i) => { + const existVarValuesInForm = existVarValuesInForms[i] + const newForm = { ...form } + const inputs = form.inputs.filter((input) => { + return !(input.variable in existVarValuesInForm) + }) + newForm.inputs = inputs + return newForm + }).filter(form => form.inputs.length > 0) + return res + } + + const checkAggregatorVarsSet = (vars: ValueSelector[][]) => { + if(!vars || vars.length === 0) + return true + // in each group, at last one set is ok + return vars.every((varItem) => { + return isSomeVarsHasValue(varItem) + }) + } + + const handleSingleRun = () => { + const { isValid } = checkValid() + if(!isValid) + return + const vars = singleRunParams?.getDependentVars?.() + // no need to input params + if (isAggregatorNode ? 
checkAggregatorVarsSet(vars) : isAllVarsHasValue(vars)) { + callRunApi({}, async () => { + setIsRunAfterSingleRun(true) + setNodeRunning() + invalidLastRun() + setTabType(TabType.lastRun) + }) + } + else { + showSingleRun() + } + } + + return { + ...oneStepRunRes, + tabType, + isRunAfterSingleRun, + setTabType: handleTabClicked, + singleRunParams, + nodeInfo, + setRunInputData, + handleSingleRun, + handleRunWithParams, + getExistVarValuesInForms, + getFilteredExistVarForms, + } +} + +export default useLastRun diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/tab.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/tab.tsx new file mode 100644 index 0000000000..09d7ed266d --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/tab.tsx @@ -0,0 +1,34 @@ +'use client' +import TabHeader from '@/app/components/base/tab-header' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' + +export enum TabType { + settings = 'settings', + lastRun = 'lastRun', +} + +type Props = { + value: TabType, + onChange: (value: TabType) => void +} + +const Tab: FC = ({ + value, + onChange, +}) => { + const { t } = useTranslation() + return ( + + ) +} +export default React.memo(Tab) diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index f23af5812c..769804a20b 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -13,7 +13,7 @@ import type { CommonNodeType, InputVar, ValueSelector, Var, Variable } from '@/a import { BlockEnum, InputVarType, NodeRunningStatus, VarType } from '@/app/components/workflow/types' import { useStore as useAppStore } from '@/app/components/app/store' import { useStore, useWorkflowStore } from '@/app/components/workflow/store' -import { getIterationSingleNodeRunUrl, getLoopSingleNodeRunUrl, singleNodeRun } from '@/service/workflow' +import { fetchNodeInspectVars, getIterationSingleNodeRunUrl, getLoopSingleNodeRunUrl, singleNodeRun } from '@/service/workflow' import Toast from '@/app/components/base/toast' import LLMDefault from '@/app/components/workflow/nodes/llm/default' import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' @@ -32,7 +32,7 @@ import LoopDefault from '@/app/components/workflow/nodes/loop/default' import { ssePost } from '@/service/base' import { noop } from 'lodash-es' import { getInputVars as doGetInputVars } from '@/app/components/base/prompt-editor/constants' -import type { NodeTracing } from '@/types/workflow' +import type { NodeRunResult, NodeTracing } from '@/types/workflow' const { checkValid: checkLLMValid } = LLMDefault const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault const { checkValid: checkIfElseValid } = IfElseDefault @@ -47,7 +47,11 @@ const { checkValid: checkParameterExtractorValid } = ParameterExtractorDefault const { checkValid: checkIterationValid } = IterationDefault const { checkValid: checkDocumentExtractorValid } = DocumentExtractorDefault const { checkValid: checkLoopValid } = LoopDefault - +import { + useStoreApi, +} from 'reactflow' +import { useInvalidLastRun } from '@/service/use-workflow' +import useInspectVarsCrud from '../../../hooks/use-inspect-vars-crud' // eslint-disable-next-line ts/no-unsafe-function-type const checkValidFns: Record = { 
[BlockEnum.LLM]: checkLLMValid, @@ -66,13 +70,15 @@ const checkValidFns: Record = { [BlockEnum.Loop]: checkLoopValid, } as any -type Params = { +export type Params = { id: string data: CommonNodeType defaultRunInputData: Record moreDataForCheckValid?: any iteratorInputKey?: string loopInputKey?: string + isRunAfterSingleRun: boolean + isPaused: boolean } const varTypeToInputVarType = (type: VarType, { @@ -105,6 +111,8 @@ const useOneStepRun = ({ moreDataForCheckValid, iteratorInputKey, loopInputKey, + isRunAfterSingleRun, + isPaused, }: Params) => { const { t } = useTranslation() const { getBeforeNodesInSameBranch, getBeforeNodesInSameBranchIncludeParent } = useWorkflow() as any @@ -112,6 +120,7 @@ const useOneStepRun = ({ const isChatMode = useIsChatMode() const isIteration = data.type === BlockEnum.Iteration const isLoop = data.type === BlockEnum.Loop + const isStartNode = data.type === BlockEnum.Start const availableNodes = getBeforeNodesInSameBranch(id) const availableNodesIncludeParent = getBeforeNodesInSameBranchIncludeParent(id) @@ -143,6 +152,7 @@ const useOneStepRun = ({ } const checkValid = checkValidFns[data.type] + const appId = useAppStore.getState().appDetail?.id const [runInputData, setRunInputData] = useState>(defaultRunInputData || {}) const runInputDataRef = useRef(runInputData) @@ -150,11 +160,82 @@ const useOneStepRun = ({ runInputDataRef.current = data setRunInputData(data) }, []) - const iterationTimes = iteratorInputKey ? runInputData[iteratorInputKey].length : 0 - const loopTimes = loopInputKey ? runInputData[loopInputKey].length : 0 - const [runResult, setRunResult] = useState(null) + const iterationTimes = iteratorInputKey ? runInputData[iteratorInputKey]?.length : 0 + const loopTimes = loopInputKey ? runInputData[loopInputKey]?.length : 0 + + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const { + setShowSingleRunPanel, + } = workflowStore.getState() + const invalidLastRun = useInvalidLastRun(appId!, id) + const [runResult, doSetRunResult] = useState(null) + const { + appendNodeInspectVars, + invalidateSysVarValues, + invalidateConversationVarValues, + } = useInspectVarsCrud() + const runningStatus = data._singleRunningStatus || NodeRunningStatus.NotStart + const isPausedRef = useRef(isPaused) + useEffect(() => { + isPausedRef.current = isPaused + }, [isPaused]) + + const setRunResult = useCallback(async (data: NodeRunResult | null) => { + const isPaused = isPausedRef.current + + // The backend doesn't support pausing a single run, so the frontend handles the pause state. + if(isPaused) + return + + const canRunLastRun = !isRunAfterSingleRun || runningStatus === NodeRunningStatus.Succeeded + if(!canRunLastRun) { + doSetRunResult(data) + return + } + + // A failed run may also update the inspect vars, when the node sets a default output on error. + const vars = await fetchNodeInspectVars(appId!, id) + const { getNodes } = store.getState() + const nodes = getNodes() + appendNodeInspectVars(id, vars, nodes) + if(data?.status === NodeRunningStatus.Succeeded) { + invalidLastRun() + if(isStartNode) + invalidateSysVarValues() + invalidateConversationVarValues() // loop, iteration and variable-assigner nodes can update conversation variables; to keep the logic simple (other nodes may gain this ability later), refresh after every node.
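// Illustrative sketch (not part of this change set): the hunk above keeps the latest
// `isPaused` prop in a ref so that the async `setRunResult` callback, created on an
// earlier render, can still read the current pause state when it resolves. The hook
// name below is hypothetical and assumes only standard React semantics.
import { useCallback, useEffect, useRef } from 'react'

function useLatestFlag(flag: boolean) {
  const flagRef = useRef(flag)
  useEffect(() => {
    // keep the ref in sync whenever the prop changes
    flagRef.current = flag
  }, [flag])
  // callers capture this getter once but always observe the newest value
  return useCallback(() => flagRef.current, [])
}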
+ } + }, [isRunAfterSingleRun, runningStatus, appId, id, store, appendNodeInspectVars, invalidLastRun, isStartNode, invalidateSysVarValues, invalidateConversationVarValues]) const { handleNodeDataUpdate }: { handleNodeDataUpdate: (data: any) => void } = useNodeDataUpdate() + const setNodeRunning = () => { + handleNodeDataUpdate({ + id, + data: { + ...data, + _singleRunningStatus: NodeRunningStatus.Running, + }, + }) + } + const checkValidWrap = () => { + if(!checkValid) + return { isValid: true, errorMessage: '' } + const res = checkValid(data, t, moreDataForCheckValid) + if(!res.isValid) { + handleNodeDataUpdate({ + id, + data: { + ...data, + _isSingleRun: false, + }, + }) + Toast.notify({ + type: 'error', + message: res.errorMessage, + }) + } + return res + } const [canShowSingleRun, setCanShowSingleRun] = useState(false) const isShowSingleRun = data._isSingleRun && canShowSingleRun const [iterationRunResult, setIterationRunResult] = useState([]) @@ -167,29 +248,15 @@ const useOneStepRun = ({ } if (data._isSingleRun) { - const { isValid, errorMessage } = checkValid(data, t, moreDataForCheckValid) + const { isValid } = checkValidWrap() setCanShowSingleRun(isValid) - if (!isValid) { - handleNodeDataUpdate({ - id, - data: { - ...data, - _isSingleRun: false, - }, - }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) - } } // eslint-disable-next-line react-hooks/exhaustive-deps }, [data._isSingleRun]) - const workflowStore = useWorkflowStore() useEffect(() => { - workflowStore.getState().setShowSingleRunPanel(!!isShowSingleRun) - }, [isShowSingleRun, workflowStore]) + setShowSingleRunPanel(!!isShowSingleRun) + }, [isShowSingleRun, setShowSingleRunPanel]) const hideSingleRun = () => { handleNodeDataUpdate({ @@ -209,7 +276,6 @@ const useOneStepRun = ({ }, }) } - const runningStatus = data._singleRunningStatus || NodeRunningStatus.NotStart const isCompleted = runningStatus === NodeRunningStatus.Succeeded || runningStatus === NodeRunningStatus.Failed const handleRun = async (submitData: Record) => { @@ -217,13 +283,29 @@ const useOneStepRun = ({ id, data: { ...data, + _isSingleRun: false, _singleRunningStatus: NodeRunningStatus.Running, }, }) let res: any + let hasError = false try { if (!isIteration && !isLoop) { - res = await singleNodeRun(appId!, id, { inputs: submitData }) as any + const isStartNode = data.type === BlockEnum.Start + const postData: Record = {} + if(isStartNode) { + const { '#sys.query#': query, '#sys.files#': files, ...inputs } = submitData + if(isChatMode) + postData.conversation_id = '' + + postData.inputs = inputs + postData.query = query + postData.files = files || [] + } + else { + postData.inputs = submitData + } + res = await singleNodeRun(appId!, id, postData) as any } else if (isIteration) { setIterationRunResult([]) @@ -235,10 +317,13 @@ const useOneStepRun = ({ { onWorkflowStarted: noop, onWorkflowFinished: (params) => { + if(isPausedRef.current) + return handleNodeDataUpdate({ id, data: { ...data, + _isSingleRun: false, _singleRunningStatus: NodeRunningStatus.Succeeded, }, }) @@ -311,10 +396,13 @@ const useOneStepRun = ({ setIterationRunResult(newIterationRunResult) }, onError: () => { + if(isPausedRef.current) + return handleNodeDataUpdate({ id, data: { ...data, + _isSingleRun: false, _singleRunningStatus: NodeRunningStatus.Failed, }, }) @@ -332,10 +420,13 @@ const useOneStepRun = ({ { onWorkflowStarted: noop, onWorkflowFinished: (params) => { + if(isPausedRef.current) + return handleNodeDataUpdate({ id, data: { ...data, + _isSingleRun: false, 
_singleRunningStatus: NodeRunningStatus.Succeeded, }, }) @@ -409,10 +500,13 @@ const useOneStepRun = ({ setLoopRunResult(newLoopRunResult) }, onError: () => { + if(isPausedRef.current) + return handleNodeDataUpdate({ id, data: { ...data, + _isSingleRun: false, _singleRunningStatus: NodeRunningStatus.Failed, }, }) @@ -425,11 +519,16 @@ const useOneStepRun = ({ } catch (e: any) { console.error(e) + hasError = true + invalidLastRun() if (!isIteration && !isLoop) { + if(isPausedRef.current) + return handleNodeDataUpdate({ id, data: { ...data, + _isSingleRun: false, _singleRunningStatus: NodeRunningStatus.Failed, }, }) @@ -437,7 +536,7 @@ const useOneStepRun = ({ } } finally { - if (!isIteration && !isLoop) { + if (!isPausedRef.current && !isIteration && !isLoop && res) { setRunResult({ ...res, total_tokens: res.execution_metadata?.total_tokens || 0, @@ -445,11 +544,17 @@ const useOneStepRun = ({ }) } } - if (!isIteration && !isLoop) { + if(isPausedRef.current) + return + + if (!isIteration && !isLoop && !hasError) { + if(isPausedRef.current) + return handleNodeDataUpdate({ id, data: { ...data, + _isSingleRun: false, _singleRunningStatus: NodeRunningStatus.Succeeded, }, }) @@ -521,11 +626,19 @@ const useOneStepRun = ({ return varInputs } + const varSelectorsToVarInputs = (valueSelectors: ValueSelector[] | string[]): InputVar[] => { + return valueSelectors.filter(item => !!item).map((item) => { + return getInputVars([`{{#${typeof item === 'string' ? item : item.join('.')}#}}`])[0] + }) + } + return { + appId, isShowSingleRun, hideSingleRun, showSingleRun, toVarInputs, + varSelectorsToVarInputs, getInputVars, runningStatus, isCompleted, @@ -537,6 +650,8 @@ const useOneStepRun = ({ runResult, iterationRunResult, loopRunResult, + setNodeRunning, + checkValid: checkValidWrap, } } diff --git a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts index 839cd14026..515f2c365b 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts @@ -1,6 +1,6 @@ -import { useCallback, useState } from 'react' +import { useCallback, useRef, useState } from 'react' import produce from 'immer' -import { useBoolean } from 'ahooks' +import { useBoolean, useDebounceFn } from 'ahooks' import type { CodeNodeType, OutputVar, @@ -17,6 +17,7 @@ import { } from '@/app/components/workflow/hooks' import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types' import { getDefaultValue } from '@/app/components/workflow/nodes/_base/components/error-handle/utils' +import useInspectVarsCrud from '../../../hooks/use-inspect-vars-crud' type Params = { id: string @@ -34,8 +35,27 @@ function useOutputVarList({ outputKeyOrders = [], onOutputKeyOrdersChange, }: Params) { + const { + renameInspectVarName, + deleteInspectVar, + nodesWithInspectVars, + } = useInspectVarsCrud() + const { handleOutVarRenameChange, isVarUsedInNodes, removeUsedVarInNodes } = useWorkflow() + // record the first old name value + const oldNameRecord = useRef>({}) + + const { + run: renameInspectNameWithDebounce, + } = useDebounceFn( + (id: string, newName: string) => { + const oldName = oldNameRecord.current[id] + renameInspectVarName(id, oldName, newName) + delete oldNameRecord.current[id] + }, + { wait: 500 }, + ) const handleVarsChange = useCallback((newVars: OutputVar, changedIndex?: number, newKey?: string) => { const newInputs = 
produce(inputs, (draft: any) => { draft[varKey] = newVars @@ -52,9 +72,20 @@ function useOutputVarList({ onOutputKeyOrdersChange(newOutputKeyOrders) } - if (newKey) + if (newKey) { handleOutVarRenameChange(id, [id, outputKeyOrders[changedIndex!]], [id, newKey]) - }, [inputs, setInputs, handleOutVarRenameChange, id, outputKeyOrders, varKey, onOutputKeyOrdersChange]) + if(!(id in oldNameRecord.current)) + oldNameRecord.current[id] = outputKeyOrders[changedIndex!] + renameInspectNameWithDebounce(id, newKey) + } + else if (changedIndex === undefined) { + const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { + return varItem.name === Object.keys(newVars)[0] + })?.id + if(varId) + deleteInspectVar(id, varId) + } + }, [inputs, setInputs, varKey, outputKeyOrders, onOutputKeyOrdersChange, handleOutVarRenameChange, id, renameInspectNameWithDebounce, nodesWithInspectVars, deleteInspectVar]) const generateNewKey = useCallback(() => { let keyIndex = Object.keys((inputs as any)[varKey]).length + 1 @@ -86,9 +117,14 @@ function useOutputVarList({ }] = useBoolean(false) const [removedVar, setRemovedVar] = useState([]) const removeVarInNode = useCallback(() => { + const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { + return varItem.name === removedVar[1] + })?.id + if(varId) + deleteInspectVar(id, varId) removeUsedVarInNodes(removedVar) hideRemoveVarConfirm() - }, [hideRemoveVarConfirm, removeUsedVarInNodes, removedVar]) + }, [deleteInspectVar, hideRemoveVarConfirm, id, nodesWithInspectVars, removeUsedVarInNodes, removedVar]) const handleRemoveVariable = useCallback((index: number) => { const key = outputKeyOrders[index] @@ -106,7 +142,12 @@ function useOutputVarList({ }) setInputs(newInputs) onOutputKeyOrdersChange(outputKeyOrders.filter((_, i) => i !== index)) - }, [outputKeyOrders, isVarUsedInNodes, id, inputs, setInputs, onOutputKeyOrdersChange, showRemoveVarConfirm, varKey]) + const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { + return varItem.name === key + })?.id + if(varId) + deleteInspectVar(id, varId) + }, [outputKeyOrders, isVarUsedInNodes, id, inputs, setInputs, onOutputKeyOrdersChange, nodesWithInspectVars, deleteInspectVar, showRemoveVarConfirm, varKey]) return { handleVarsChange, diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index 527b2f094d..27d6adc62b 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -44,6 +44,7 @@ import AddVariablePopupWithPosition from './components/add-variable-popup-with-p import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' import Tooltip from '@/app/components/base/tooltip' +import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud' type BaseNodeProps = { children: ReactElement @@ -89,6 +90,9 @@ const BaseNode: FC = ({ } }, [data.isInLoop, data.selected, id, handleNodeLoopChildSizeChange]) + const { hasNodeInspectVars } = useInspectVarsCrud() + const isLoading = data._runningStatus === NodeRunningStatus.Running || data._singleRunningStatus === NodeRunningStatus.Running + const hasVarValue = hasNodeInspectVars(id) const showSelectedBorder = data.selected || data._isBundled || data._isEntering const { showRunningBorder, @@ -98,11 +102,11 @@ const BaseNode: FC = ({ } = useMemo(() => { return { showRunningBorder: data._runningStatus === NodeRunningStatus.Running 
&& !showSelectedBorder, - showSuccessBorder: data._runningStatus === NodeRunningStatus.Succeeded && !showSelectedBorder, + showSuccessBorder: (data._runningStatus === NodeRunningStatus.Succeeded || hasVarValue) && !showSelectedBorder, showFailedBorder: data._runningStatus === NodeRunningStatus.Failed && !showSelectedBorder, showExceptionBorder: data._runningStatus === NodeRunningStatus.Exception && !showSelectedBorder, } - }, [data._runningStatus, showSelectedBorder]) + }, [data._runningStatus, hasVarValue, showSelectedBorder]) const LoopIndex = useMemo(() => { let text = '' @@ -260,12 +264,12 @@ const BaseNode: FC = ({ data.type === BlockEnum.Loop && data._loopIndex && LoopIndex } { - (data._runningStatus === NodeRunningStatus.Running || data._singleRunningStatus === NodeRunningStatus.Running) && ( + isLoading && ( ) } { - data._runningStatus === NodeRunningStatus.Succeeded && ( + (!isLoading && (data._runningStatus === NodeRunningStatus.Succeeded || hasVarValue)) && ( ) } diff --git a/web/app/components/workflow/nodes/_base/panel.tsx b/web/app/components/workflow/nodes/_base/panel.tsx deleted file mode 100644 index 49c61b3416..0000000000 --- a/web/app/components/workflow/nodes/_base/panel.tsx +++ /dev/null @@ -1,214 +0,0 @@ -import type { - FC, - ReactNode, -} from 'react' -import { - cloneElement, - memo, - useCallback, -} from 'react' -import { - RiCloseLine, - RiPlayLargeLine, -} from '@remixicon/react' -import { useShallow } from 'zustand/react/shallow' -import { useTranslation } from 'react-i18next' -import NextStep from './components/next-step' -import PanelOperator from './components/panel-operator' -import HelpLink from './components/help-link' -import NodePosition from './components/node-position' -import { - DescriptionInput, - TitleInput, -} from './components/title-description-input' -import ErrorHandleOnPanel from './components/error-handle/error-handle-on-panel' -import RetryOnPanel from './components/retry/retry-on-panel' -import { useResizePanel } from './hooks/use-resize-panel' -import cn from '@/utils/classnames' -import BlockIcon from '@/app/components/workflow/block-icon' -import Split from '@/app/components/workflow/nodes/_base/components/split' -import { - WorkflowHistoryEvent, - useAvailableBlocks, - useNodeDataUpdate, - useNodesInteractions, - useNodesReadOnly, - useNodesSyncDraft, - useToolIcon, - useWorkflow, - useWorkflowHistory, -} from '@/app/components/workflow/hooks' -import { - canRunBySingle, - hasErrorHandleNode, - hasRetryNode, -} from '@/app/components/workflow/utils' -import Tooltip from '@/app/components/base/tooltip' -import type { Node } from '@/app/components/workflow/types' -import { useStore as useAppStore } from '@/app/components/app/store' -import { useStore } from '@/app/components/workflow/store' - -type BasePanelProps = { - children: ReactNode -} & Node - -const BasePanel: FC = ({ - id, - data, - children, - position, - width, - height, -}) => { - const { t } = useTranslation() - const { showMessageLogModal } = useAppStore(useShallow(state => ({ - showMessageLogModal: state.showMessageLogModal, - }))) - const showSingleRunPanel = useStore(s => s.showSingleRunPanel) - const panelWidth = localStorage.getItem('workflow-node-panel-width') ? Number.parseFloat(localStorage.getItem('workflow-node-panel-width')!) 
: 420 - const { - setPanelWidth, - } = useWorkflow() - const { handleNodeSelect } = useNodesInteractions() - const { handleSyncWorkflowDraft } = useNodesSyncDraft() - const { nodesReadOnly } = useNodesReadOnly() - const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration, data.isInLoop) - const toolIcon = useToolIcon(data) - - const handleResize = useCallback((width: number) => { - setPanelWidth(width) - }, [setPanelWidth]) - - const { - triggerRef, - containerRef, - } = useResizePanel({ - direction: 'horizontal', - triggerDirection: 'left', - minWidth: 420, - maxWidth: 720, - onResize: handleResize, - }) - - const { saveStateToHistory } = useWorkflowHistory() - - const { - handleNodeDataUpdate, - handleNodeDataUpdateWithSyncDraft, - } = useNodeDataUpdate() - - const handleTitleBlur = useCallback((title: string) => { - handleNodeDataUpdateWithSyncDraft({ id, data: { title } }) - saveStateToHistory(WorkflowHistoryEvent.NodeTitleChange) - }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) - const handleDescriptionChange = useCallback((desc: string) => { - handleNodeDataUpdateWithSyncDraft({ id, data: { desc } }) - saveStateToHistory(WorkflowHistoryEvent.NodeDescriptionChange) - }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) - - return ( -
-
-
-
-
-
-
- - -
- { - canRunBySingle(data.type) && !nodesReadOnly && ( - -
{ - handleNodeDataUpdate({ id, data: { _isSingleRun: true } }) - handleSyncWorkflowDraft(true) - }} - > - -
-
- ) - } - - - -
-
handleNodeSelect(id, true)} - > - -
-
-
-
- -
-
-
- {cloneElement(children as any, { id, data })} -
- - { - hasRetryNode(data.type) && ( - - ) - } - { - hasErrorHandleNode(data.type) && ( - - ) - } - { - !!availableNextBlocks.length && ( -
-
- {t('workflow.panel.nextStep').toLocaleUpperCase()} -
-
- {t('workflow.panel.addNextStep')} -
- -
- ) - } -
-
- ) -} - -export default memo(BasePanel) diff --git a/web/app/components/workflow/nodes/agent/panel.tsx b/web/app/components/workflow/nodes/agent/panel.tsx index f92e92dbcb..391383031f 100644 --- a/web/app/components/workflow/nodes/agent/panel.tsx +++ b/web/app/components/workflow/nodes/agent/panel.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import { memo, useMemo } from 'react' +import { memo } from 'react' import type { NodePanelProps } from '../../types' import { AgentFeature, type AgentNodeType } from './types' import Field from '../_base/components/field' @@ -9,16 +9,10 @@ import { useTranslation } from 'react-i18next' import OutputVars, { VarItem } from '../_base/components/output-vars' import type { StrategyParamItem } from '@/app/components/plugins/types' import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' -import ResultPanel from '@/app/components/workflow/run/result-panel' -import formatTracing from '@/app/components/workflow/run/utils/format-log' -import { useLogs } from '@/app/components/workflow/run/hooks' -import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' import { toType } from '@/app/components/tools/utils/to-form-schema' import { useStore } from '../../store' import Split from '../_base/components/split' import MemoryConfig from '../_base/components/memory-config' - const i18nPrefix = 'workflow.nodes.agent' export function strategyParamToCredientialForm(param: StrategyParamItem): CredentialFormSchema { @@ -42,41 +36,10 @@ const AgentPanel: FC> = (props) => { availableNodesWithParent, availableVars, readOnly, - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - runInputData, - setRunInputData, - varInputs, outputSchema, handleMemoryChange, } = useConfig(props.id, props.data) const { t } = useTranslation() - const nodeInfo = useMemo(() => { - if (!runResult) - return - return formatTracing([runResult], t)[0] - }, [runResult, t]) - const logsParams = useLogs() - const singleRunForms = (() => { - const forms: FormProps[] = [] - - if (varInputs.length > 0) { - forms.push( - { - label: t(`${i18nPrefix}.singleRun.variable`)!, - inputs: varInputs, - values: runInputData, - onChange: setRunInputData, - }, - ) - } - - return forms - })() const resetEditor = useStore(s => s.setControlPromptEditorRerenderKey) @@ -154,21 +117,6 @@ const AgentPanel: FC> = (props) => { ))}
- { - isShowSingleRun && ( - } - /> - ) - }
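// Illustrative sketch (hypothetical node, not part of this change set): panels such as the
// agent panel above no longer render BeforeRunForm themselves; each node instead exposes a
// `use-single-run-form-params` hook that returns the forms to show before a single run and
// the variable selectors the run depends on, mirroring the variants added later in this diff.
import { useMemo } from 'react'
import type { InputVar, ValueSelector } from '@/app/components/workflow/types'

const useExampleSingleRunFormParams = (
  varInputs: InputVar[],
  runInputData: Record<string, any>,
  setRunInputData: (data: Record<string, any>) => void,
) => {
  const forms = useMemo(() => ([{
    inputs: varInputs,
    values: runInputData,
    onChange: setRunInputData,
  }]), [varInputs, runInputData, setRunInputData])

  // selectors like ['nodeId', 'varName'] that must already hold inspect values
  // for the run to skip the input form (see handleSingleRun earlier in this diff)
  const getDependentVars = (): ValueSelector[] =>
    varInputs.map(item => item.variable.slice(1, -1).split('.'))

  return { forms, getDependentVars }
}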
} diff --git a/web/app/components/workflow/nodes/agent/use-config.ts b/web/app/components/workflow/nodes/agent/use-config.ts index 8196caa3f5..c3e07e4e60 100644 --- a/web/app/components/workflow/nodes/agent/use-config.ts +++ b/web/app/components/workflow/nodes/agent/use-config.ts @@ -1,7 +1,6 @@ import { useStrategyProviderDetail } from '@/service/use-strategy' import useNodeCrud from '../_base/hooks/use-node-crud' import useVarList from '../_base/hooks/use-var-list' -import useOneStepRun from '../_base/hooks/use-one-step-run' import type { AgentNodeType } from './types' import { useIsChatMode, @@ -131,35 +130,6 @@ const useConfig = (id: string, payload: AgentNodeType) => { }) // single run - const { - isShowSingleRun, - showSingleRun, - hideSingleRun, - toVarInputs, - runningStatus, - handleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - getInputVars, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: {}, - }) - const allVarStrArr = (() => { - const arr = currentStrategy?.parameters.filter(item => item.type === 'string').map((item) => { - return formData[item.name] - }) || [] - - return arr - })() - const varInputs = (() => { - const vars = getInputVars(allVarStrArr) - - return vars - })() const outputSchema = useMemo(() => { const res: any[] = [] @@ -199,18 +169,6 @@ const useConfig = (id: string, payload: AgentNodeType) => { pluginDetail: pluginDetail.data?.plugins.at(0), availableVars, availableNodesWithParent, - - isShowSingleRun, - showSingleRun, - hideSingleRun, - toVarInputs, - runningStatus, - handleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - varInputs, outputSchema, handleMemoryChange, isChatMode, diff --git a/web/app/components/workflow/nodes/agent/use-single-run-form-params.ts b/web/app/components/workflow/nodes/agent/use-single-run-form-params.ts new file mode 100644 index 0000000000..5ddc24b2f2 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/use-single-run-form-params.ts @@ -0,0 +1,90 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, Variable } from '@/app/components/workflow/types' +import { useMemo } from 'react' +import useNodeCrud from '../_base/hooks/use-node-crud' +import type { AgentNodeType } from './types' +import { useTranslation } from 'react-i18next' +import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' +import { useStrategyInfo } from './use-config' +import type { NodeTracing } from '@/types/workflow' +import formatTracing from '@/app/components/workflow/run/utils/format-log' + +type Params = { + id: string, + payload: AgentNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] + runResult: NodeTracing +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + getInputVars, + setRunInputData, + runResult, +}: Params) => { + const { t } = useTranslation() + const { inputs } = useNodeCrud(id, payload) + + const formData = useMemo(() => { + return Object.fromEntries( + Object.entries(inputs.agent_parameters || {}).map(([key, value]) => { + return [key, value.value] + }), + ) + }, [inputs.agent_parameters]) + + const { + strategy: currentStrategy, + } = useStrategyInfo( + inputs.agent_strategy_provider_name, + inputs.agent_strategy_name, + ) + + const allVarStrArr = (() => { + const arr = currentStrategy?.parameters.filter(item => item.type 
=== 'string').map((item) => { + return formData[item.name] + }) || [] + return arr + })() + + const varInputs = getInputVars?.(allVarStrArr) + + const forms = useMemo(() => { + const forms: FormProps[] = [] + + if (varInputs!.length > 0) { + forms.push( + { + label: t('workflow.nodes.llm.singleRun.variable')!, + inputs: varInputs!, + values: runInputData, + onChange: setRunInputData, + }, + ) + } + return forms + }, [runInputData, setRunInputData, t, varInputs]) + + const nodeInfo = useMemo(() => { + if (!runResult) + return + return formatTracing([runResult], t)[0] + }, [runResult, t]) + + const getDependentVars = () => { + return varInputs.map(item => item.variable.slice(1, -1).split('.')) + } + + return { + forms, + nodeInfo, + getDependentVars, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx b/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx index f34a1435ad..b19d5903a6 100644 --- a/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx +++ b/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx @@ -52,6 +52,7 @@ const VarList: FC = ({ const newList = produce(list, (draft) => { draft[index].variable_selector = value as ValueSelector draft[index].operation = WriteMode.overwrite + draft[index].input_type = AssignerNodeInputType.variable draft[index].value = undefined }) onChange(newList, value as ValueSelector) diff --git a/web/app/components/workflow/nodes/assigner/types.ts b/web/app/components/workflow/nodes/assigner/types.ts index 85d2b2850f..22f37bb7cd 100644 --- a/web/app/components/workflow/nodes/assigner/types.ts +++ b/web/app/components/workflow/nodes/assigner/types.ts @@ -30,3 +30,5 @@ export type AssignerNodeType = CommonNodeType & { version?: '1' | '2' items: AssignerNodeOperation[] } + +export const writeModeTypesNum = [WriteMode.increment, WriteMode.decrement, WriteMode.multiply, WriteMode.divide] diff --git a/web/app/components/workflow/nodes/assigner/use-config.ts b/web/app/components/workflow/nodes/assigner/use-config.ts index cbd5475483..c42dd67b37 100644 --- a/web/app/components/workflow/nodes/assigner/use-config.ts +++ b/web/app/components/workflow/nodes/assigner/use-config.ts @@ -5,6 +5,7 @@ import { VarType } from '../../types' import type { ValueSelector, Var } from '../../types' import { WriteMode } from './types' import type { AssignerNodeOperation, AssignerNodeType } from './types' +import { writeModeTypesNum } from './types' import { useGetAvailableVars } from './hooks' import { convertV1ToV2 } from './utils' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' @@ -71,7 +72,6 @@ const useConfig = (id: string, rawPayload: AssignerNodeType) => { const writeModeTypesArr = [WriteMode.overwrite, WriteMode.clear, WriteMode.append, WriteMode.extend, WriteMode.removeFirst, WriteMode.removeLast] const writeModeTypes = [WriteMode.overwrite, WriteMode.clear, WriteMode.set] - const writeModeTypesNum = [WriteMode.increment, WriteMode.decrement, WriteMode.multiply, WriteMode.divide] const getToAssignedVarType = useCallback((assignedVarType: VarType, write_mode: WriteMode) => { if (write_mode === WriteMode.overwrite || write_mode === WriteMode.increment || write_mode === WriteMode.decrement diff --git a/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts b/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts new file mode 100644 index 0000000000..7ff31d91c7 --- 
/dev/null +++ b/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts @@ -0,0 +1,55 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, ValueSelector, Variable } from '@/app/components/workflow/types' +import { useMemo } from 'react' +import useNodeCrud from '../_base/hooks/use-node-crud' +import { type AssignerNodeType, WriteMode } from './types' +import { writeModeTypesNum } from './types' + +type Params = { + id: string, + payload: AssignerNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] + varSelectorsToVarInputs: (variables: ValueSelector[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + setRunInputData, + varSelectorsToVarInputs, +}: Params) => { + const { inputs } = useNodeCrud(id, payload) + + const vars = (inputs.items ?? []).filter((item) => { + return item.operation !== WriteMode.clear && item.operation !== WriteMode.set + && item.operation !== WriteMode.removeFirst && item.operation !== WriteMode.removeLast + && !writeModeTypesNum.includes(item.operation) + }).map(item => item.value as ValueSelector) + + const forms = useMemo(() => { + const varInputs = varSelectorsToVarInputs(vars) + + return [ + { + inputs: varInputs, + values: runInputData, + onChange: setRunInputData, + }, + ] + }, [runInputData, setRunInputData, varSelectorsToVarInputs, vars]) + + const getDependentVars = () => { + return vars + } + + return { + forms, + getDependentVars, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/code/panel.tsx b/web/app/components/workflow/nodes/code/panel.tsx index a0b7535f89..05d6cd7957 100644 --- a/web/app/components/workflow/nodes/code/panel.tsx +++ b/web/app/components/workflow/nodes/code/panel.tsx @@ -14,8 +14,6 @@ import Split from '@/app/components/workflow/nodes/_base/components/split' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' import type { NodePanelProps } from '@/app/components/workflow/types' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' -import ResultPanel from '@/app/components/workflow/run/result-panel' const i18nPrefix = 'workflow.nodes.code' const codeLanguages = [ @@ -50,16 +48,6 @@ const Panel: FC> = ({ isShowRemoveVarConfirm, hideRemoveVarConfirm, onRemoveVarConfirm, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - varInputs, - inputVarValues, - setInputVarValues, } = useConfig(id, data) const handleGeneratedCode = (value: string) => { @@ -128,25 +116,6 @@ const Panel: FC> = ({ />
- { - isShowSingleRun && ( - } - /> - ) - } { }) syncOutputKeyOrders(defaultConfig.outputs) } - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [defaultConfig]) const handleCodeChange = useCallback((code: string) => { @@ -104,38 +103,6 @@ const useConfig = (id: string, payload: CodeNodeType) => { return [VarType.string, VarType.number, VarType.secret, VarType.object, VarType.array, VarType.arrayNumber, VarType.arrayString, VarType.arrayObject, VarType.file, VarType.arrayFile].includes(varPayload.type) }, []) - // single run - const { - isShowSingleRun, - hideSingleRun, - toVarInputs, - runningStatus, - isCompleted, - handleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: {}, - }) - - const varInputs = toVarInputs(inputs.variables) - - const inputVarValues = (() => { - const vars: Record = {} - Object.keys(runInputData) - .forEach((key) => { - vars[key] = runInputData[key] - }) - return vars - })() - - const setInputVarValues = useCallback((newPayload: Record) => { - setRunInputData(newPayload) - }, [setRunInputData]) const handleCodeAndVarsChange = useCallback((code: string, inputVariables: Variable[], outputVariables: OutputVar) => { const newInputs = produce(inputs, (draft) => { draft.code = code @@ -160,17 +127,6 @@ const useConfig = (id: string, payload: CodeNodeType) => { isShowRemoveVarConfirm, hideRemoveVarConfirm, onRemoveVarConfirm, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - isCompleted, - handleRun, - handleStop, - varInputs, - inputVarValues, - setInputVarValues, - runResult, handleCodeAndVarsChange, } } diff --git a/web/app/components/workflow/nodes/code/use-single-run-form-params.ts b/web/app/components/workflow/nodes/code/use-single-run-form-params.ts new file mode 100644 index 0000000000..9714e55fff --- /dev/null +++ b/web/app/components/workflow/nodes/code/use-single-run-form-params.ts @@ -0,0 +1,65 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, Variable } from '@/app/components/workflow/types' +import { useCallback, useMemo } from 'react' +import useNodeCrud from '../_base/hooks/use-node-crud' +import type { CodeNodeType } from './types' + +type Params = { + id: string, + payload: CodeNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + toVarInputs, + setRunInputData, +}: Params) => { + const { inputs } = useNodeCrud(id, payload) + + const varInputs = toVarInputs(inputs.variables) + const setInputVarValues = useCallback((newPayload: Record) => { + setRunInputData(newPayload) + }, [setRunInputData]) + const inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const forms = useMemo(() => { + return [ + { + inputs: varInputs, + values: inputVarValues, + onChange: setInputVarValues, + }, + ] + }, [inputVarValues, setInputVarValues, varInputs]) + + const getDependentVars = () => { + return payload.variables.map(v => v.value_selector) + } + + const getDependentVar = (variable: string) => { + const varItem = payload.variables.find(v => v.variable === variable) + if (varItem) + return varItem.value_selector + } + + return { + forms, + 
getDependentVars, + getDependentVar, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/document-extractor/panel.tsx b/web/app/components/workflow/nodes/document-extractor/panel.tsx index 5ed1425778..a91608c717 100644 --- a/web/app/components/workflow/nodes/document-extractor/panel.tsx +++ b/web/app/components/workflow/nodes/document-extractor/panel.tsx @@ -11,11 +11,9 @@ import useConfig from './use-config' import type { DocExtractorNodeType } from './types' import { fetchSupportFileTypes } from '@/service/datasets' import Field from '@/app/components/workflow/nodes/_base/components/field' -import { BlockEnum, InputVarType, type NodePanelProps } from '@/app/components/workflow/types' +import { BlockEnum, type NodePanelProps } from '@/app/components/workflow/types' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n/language' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' -import ResultPanel from '@/app/components/workflow/run/result-panel' const i18nPrefix = 'workflow.nodes.docExtractor' @@ -48,15 +46,6 @@ const Panel: FC> = ({ inputs, handleVarChanges, filterVar, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - files, - setFiles, } = useConfig(id, data) return ( @@ -93,30 +82,6 @@ const Panel: FC> = ({ />
- { - isShowSingleRun && ( - setFiles(keyValue.files), - }, - ]} - runningStatus={runningStatus} - onRun={handleRun} - onStop={handleStop} - result={} - /> - ) - }
) } diff --git a/web/app/components/workflow/nodes/document-extractor/use-config.ts b/web/app/components/workflow/nodes/document-extractor/use-config.ts index 8ceb153874..43f3e71fa2 100644 --- a/web/app/components/workflow/nodes/document-extractor/use-config.ts +++ b/web/app/components/workflow/nodes/document-extractor/use-config.ts @@ -1,12 +1,10 @@ import { useCallback, useMemo } from 'react' import produce from 'immer' import { useStoreApi } from 'reactflow' - import type { ValueSelector, Var } from '../../types' -import { InputVarType, VarType } from '../../types' +import { VarType } from '../../types' import type { DocExtractorNodeType } from './types' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' -import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' import { useIsChatMode, useNodesReadOnly, @@ -58,53 +56,11 @@ const useConfig = (id: string, payload: DocExtractorNodeType) => { setInputs(newInputs) }, [getType, inputs, setInputs]) - // single run - const { - isShowSingleRun, - hideSingleRun, - runningStatus, - isCompleted, - handleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: { files: [] }, - }) - const varInputs = [{ - label: inputs.title, - variable: 'files', - type: InputVarType.multiFiles, - required: true, - }] - - const files = runInputData.files - const setFiles = useCallback((newFiles: []) => { - setRunInputData({ - ...runInputData, - files: newFiles, - }) - }, [runInputData, setRunInputData]) - return { readOnly, inputs, filterVar, handleVarChanges, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - isCompleted, - handleRun, - handleStop, - varInputs, - files, - setFiles, - runResult, } } diff --git a/web/app/components/workflow/nodes/document-extractor/use-single-run-form-params.ts b/web/app/components/workflow/nodes/document-extractor/use-single-run-form-params.ts new file mode 100644 index 0000000000..3b249cd210 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/use-single-run-form-params.ts @@ -0,0 +1,64 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, Variable } from '@/app/components/workflow/types' +import { useCallback, useMemo } from 'react' +import type { DocExtractorNodeType } from './types' +import { useTranslation } from 'react-i18next' +import { InputVarType } from '@/app/components/workflow/types' + +const i18nPrefix = 'workflow.nodes.docExtractor' + +type Params = { + id: string, + payload: DocExtractorNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + payload, + runInputData, + setRunInputData, +}: Params) => { + const { t } = useTranslation() + const files = runInputData.files + const setFiles = useCallback((newFiles: []) => { + setRunInputData({ + ...runInputData, + files: newFiles, + }) + }, [runInputData, setRunInputData]) + + const forms = useMemo(() => { + return [ + { + inputs: [{ + label: t(`${i18nPrefix}.inputVar`)!, + variable: 'files', + type: InputVarType.multiFiles, + required: true, + }], + values: { files }, + onChange: (keyValue: Record) => setFiles(keyValue.files), + }, + ] + }, [files, setFiles, t]) + + const getDependentVars = () => { + return [payload.variable_selector] + } + + const getDependentVar = (variable: 
string) => { + if(variable === 'files') + return payload.variable_selector + } + + return { + forms, + getDependentVars, + getDependentVar, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/http/panel.tsx b/web/app/components/workflow/nodes/http/panel.tsx index 60f3de81c0..9a07c0ad61 100644 --- a/web/app/components/workflow/nodes/http/panel.tsx +++ b/web/app/components/workflow/nodes/http/panel.tsx @@ -16,8 +16,6 @@ import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/compo import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' import { FileArrow01 } from '@/app/components/base/icons/src/vender/line/files' import type { NodePanelProps } from '@/app/components/workflow/types' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' -import ResultPanel from '@/app/components/workflow/run/result-panel' const i18nPrefix = 'workflow.nodes.http' @@ -45,16 +43,6 @@ const Panel: FC> = ({ hideAuthorization, setAuthorization, setTimeout, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - varInputs, - inputVarValues, - setInputVarValues, - runResult, isShowCurlPanel, showCurlPanel, hideCurlPanel, @@ -180,24 +168,6 @@ const Panel: FC> = ({
- {isShowSingleRun && ( - } - /> - )} {(isShowCurlPanel && !readOnly) && ( { return [VarType.string, VarType.number, VarType.secret].includes(varPayload.type) }, []) - // single run - const { - isShowSingleRun, - hideSingleRun, - getInputVars, - runningStatus, - handleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: {}, - }) - - const fileVarInputs = useMemo(() => { - if (!Array.isArray(inputs.body.data)) - return '' - - const res = inputs.body.data - .filter(item => item.file?.length) - .map(item => item.file ? `{{#${item.file.join('.')}#}}` : '') - .join(' ') - return res - }, [inputs.body.data]) - - const varInputs = getInputVars([ - inputs.url, - inputs.headers, - inputs.params, - typeof inputs.body.data === 'string' ? inputs.body.data : inputs.body.data?.map(item => item.value).join(''), - fileVarInputs, - ]) - - const inputVarValues = (() => { - const vars: Record = {} - Object.keys(runInputData) - .forEach((key) => { - vars[key] = runInputData[key] - }) - return vars - })() - - const setInputVarValues = useCallback((newPayload: Record) => { - setRunInputData(newPayload) - }, [setRunInputData]) - // curl import panel const [isShowCurlPanel, { setTrue: showCurlPanel, @@ -220,16 +170,6 @@ const useConfig = (id: string, payload: HttpNodeType) => { hideAuthorization, setAuthorization, setTimeout, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - varInputs, - inputVarValues, - setInputVarValues, - runResult, // curl import isShowCurlPanel, showCurlPanel, diff --git a/web/app/components/workflow/nodes/http/use-single-run-form-params.ts b/web/app/components/workflow/nodes/http/use-single-run-form-params.ts new file mode 100644 index 0000000000..c5d65634c4 --- /dev/null +++ b/web/app/components/workflow/nodes/http/use-single-run-form-params.ts @@ -0,0 +1,74 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, Variable } from '@/app/components/workflow/types' +import { useCallback, useMemo } from 'react' +import useNodeCrud from '../_base/hooks/use-node-crud' +import type { HttpNodeType } from './types' + +type Params = { + id: string, + payload: HttpNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + getInputVars, + setRunInputData, +}: Params) => { + const { inputs } = useNodeCrud(id, payload) + + const fileVarInputs = useMemo(() => { + if (!Array.isArray(inputs.body.data)) + return '' + + const res = inputs.body.data + .filter(item => item.file?.length) + .map(item => item.file ? `{{#${item.file.join('.')}#}}` : '') + .join(' ') + return res + }, [inputs.body.data]) + const varInputs = getInputVars([ + inputs.url, + inputs.headers, + inputs.params, + typeof inputs.body.data === 'string' ? 
inputs.body.data : inputs.body.data?.map(item => item.value).join(''), + fileVarInputs, + ]) + const setInputVarValues = useCallback((newPayload: Record) => { + setRunInputData(newPayload) + }, [setRunInputData]) + const inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const forms = useMemo(() => { + return [ + { + inputs: varInputs, + values: inputVarValues, + onChange: setInputVarValues, + }, + ] + }, [inputVarValues, setInputVarValues, varInputs]) + + const getDependentVars = () => { + return varInputs.map(item => item.variable.slice(1, -1).split('.')) + } + + return { + forms, + getDependentVars, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx index 9036e04d3b..a2b3cb7589 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx @@ -69,7 +69,7 @@ const ConditionOperator = ({ - +
{ options.map(option => ( diff --git a/web/app/components/workflow/nodes/if-else/use-single-run-form-params.ts b/web/app/components/workflow/nodes/if-else/use-single-run-form-params.ts new file mode 100644 index 0000000000..f61f2846c3 --- /dev/null +++ b/web/app/components/workflow/nodes/if-else/use-single-run-form-params.ts @@ -0,0 +1,166 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, ValueSelector, Variable } from '@/app/components/workflow/types' +import { useCallback } from 'react' +import type { CaseItem, Condition, IfElseNodeType } from './types' + +type Params = { + id: string, + payload: IfElseNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] + varSelectorsToVarInputs: (variables: ValueSelector[]) => InputVar[] +} +const useSingleRunFormParams = ({ + payload, + runInputData, + setRunInputData, + getInputVars, + varSelectorsToVarInputs, +}: Params) => { + const setInputVarValues = useCallback((newPayload: Record) => { + setRunInputData(newPayload) + }, [setRunInputData]) + const inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const getVarSelectorsFromCase = (caseItem: CaseItem): ValueSelector[] => { + const vars: ValueSelector[] = [] + if (caseItem.conditions && caseItem.conditions.length) { + caseItem.conditions.forEach((condition) => { + // eslint-disable-next-line ts/no-use-before-define + const conditionVars = getVarSelectorsFromCondition(condition) + vars.push(...conditionVars) + }) + } + return vars + } + + const getVarSelectorsFromCondition = (condition: Condition) => { + const vars: ValueSelector[] = [] + if (condition.variable_selector) + vars.push(condition.variable_selector) + + if (condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) + vars.push(...getVarSelectorsFromCase(condition.sub_variable_condition)) + return vars + } + + const getInputVarsFromCase = (caseItem: CaseItem): InputVar[] => { + const vars: InputVar[] = [] + if (caseItem.conditions && caseItem.conditions.length) { + caseItem.conditions.forEach((condition) => { + // eslint-disable-next-line ts/no-use-before-define + const conditionVars = getInputVarsFromConditionValue(condition) + vars.push(...conditionVars) + }) + } + return vars + } + + const getInputVarsFromConditionValue = (condition: Condition): InputVar[] => { + const vars: InputVar[] = [] + if (condition.value && typeof condition.value === 'string') { + const inputVars = getInputVars([condition.value]) + vars.push(...inputVars) + } + + if (condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) + vars.push(...getInputVarsFromCase(condition.sub_variable_condition)) + + return vars + } + + const forms = (() => { + const allInputs: ValueSelector[] = [] + const inputVarsFromValue: InputVar[] = [] + if (payload.cases && payload.cases.length) { + payload.cases.forEach((caseItem) => { + const caseVars = getVarSelectorsFromCase(caseItem) + allInputs.push(...caseVars) + inputVarsFromValue.push(...getInputVarsFromCase(caseItem)) + }) + } + + if (payload.conditions && payload.conditions.length) { + payload.conditions.forEach((condition) => { + const conditionVars = getVarSelectorsFromCondition(condition) + allInputs.push(...conditionVars) + 
inputVarsFromValue.push(...getInputVarsFromConditionValue(condition)) + }) + } + + const varInputs = [...varSelectorsToVarInputs(allInputs), ...inputVarsFromValue] + // remove duplicate inputs + const existVarsKey: Record = {} + const uniqueVarInputs: InputVar[] = [] + varInputs.forEach((input) => { + if(!input) + return + if (!existVarsKey[input.variable]) { + existVarsKey[input.variable] = true + uniqueVarInputs.push(input) + } + }) + return [ + { + inputs: uniqueVarInputs, + values: inputVarValues, + onChange: setInputVarValues, + }, + ] + })() + + const getVarFromCaseItem = (caseItem: CaseItem): ValueSelector[] => { + const vars: ValueSelector[] = [] + if (caseItem.conditions && caseItem.conditions.length) { + caseItem.conditions.forEach((condition) => { + // eslint-disable-next-line ts/no-use-before-define + const conditionVars = getVarFromCondition(condition) + vars.push(...conditionVars) + }) + } + return vars + } + const getVarFromCondition = (condition: Condition): ValueSelector[] => { + const vars: ValueSelector[] = [] + if (condition.variable_selector) + vars.push(condition.variable_selector) + + if(condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) + vars.push(...getVarFromCaseItem(condition.sub_variable_condition)) + return vars + } + + const getDependentVars = () => { + const vars: ValueSelector[] = [] + if (payload.cases && payload.cases.length) { + payload.cases.forEach((caseItem) => { + const caseVars = getVarFromCaseItem(caseItem) + vars.push(...caseVars) + }) + } + + if (payload.conditions && payload.conditions.length) { + payload.conditions.forEach((condition) => { + const conditionVars = getVarFromCondition(condition) + vars.push(...conditionVars) + }) + } + return vars + } + return { + forms, + getDependentVars, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/index.tsx b/web/app/components/workflow/nodes/index.tsx index bebc140414..d120ed8d37 100644 --- a/web/app/components/workflow/nodes/index.tsx +++ b/web/app/components/workflow/nodes/index.tsx @@ -10,7 +10,7 @@ import { PanelComponentMap, } from './constants' import BaseNode from './_base/node' -import BasePanel from './_base/panel' +import BasePanel from './_base/components/workflow-panel' const CustomNode = (props: NodeProps) => { const nodeData = props.data @@ -18,7 +18,7 @@ const CustomNode = (props: NodeProps) => { return ( <> - + diff --git a/web/app/components/workflow/nodes/iteration/panel.tsx b/web/app/components/workflow/nodes/iteration/panel.tsx index 1f29a07946..4b529f0785 100644 --- a/web/app/components/workflow/nodes/iteration/panel.tsx +++ b/web/app/components/workflow/nodes/iteration/panel.tsx @@ -3,20 +3,15 @@ import React from 'react' import { useTranslation } from 'react-i18next' import VarReferencePicker from '../_base/components/variable/var-reference-picker' import Split from '../_base/components/split' -import ResultPanel from '../../run/result-panel' import { MAX_ITERATION_PARALLEL_NUM, MIN_ITERATION_PARALLEL_NUM } from '../../constants' import type { IterationNodeType } from './types' import useConfig from './use-config' -import { ErrorHandleMode, InputVarType, type NodePanelProps } from '@/app/components/workflow/types' +import { ErrorHandleMode, type NodePanelProps } from '@/app/components/workflow/types' import Field from '@/app/components/workflow/nodes/_base/components/field' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import Switch from 
'@/app/components/base/switch' import Select from '@/app/components/base/select' import Slider from '@/app/components/base/slider' import Input from '@/app/components/base/input' -import formatTracing from '@/app/components/workflow/run/utils/format-log' - -import { useLogs } from '@/app/components/workflow/run/hooks' const i18nPrefix = 'workflow.nodes.iteration' @@ -47,27 +42,11 @@ const Panel: FC> = ({ childrenNodeVars, iterationChildrenNodes, handleOutputVarChange, - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - inputVarValues, - setInputVarValues, - usedOutVars, - iterator, - setIterator, - iteratorInputKey, - iterationRunResult, changeParallel, changeErrorResponseMode, changeParallelNums, } = useConfig(id, data) - const nodeInfo = formatTracing(iterationRunResult, t)[0] - const logsParams = useLogs() - return (
@@ -137,38 +116,6 @@ const Panel: FC> = ({
*/} - {isShowSingleRun && ( - - } - /> - )}
) } diff --git a/web/app/components/workflow/nodes/loop/use-config.ts b/web/app/components/workflow/nodes/loop/use-config.ts index fbd350c229..965fe2b395 100644 --- a/web/app/components/workflow/nodes/loop/use-config.ts +++ b/web/app/components/workflow/nodes/loop/use-config.ts @@ -3,7 +3,6 @@ import { useRef, } from 'react' import produce from 'immer' -import { useBoolean } from 'ahooks' import { v4 as uuid4 } from 'uuid' import { useIsChatMode, @@ -12,10 +11,9 @@ import { useWorkflow, } from '../../hooks' import { ValueType, VarType } from '../../types' -import type { ErrorHandleMode, ValueSelector, Var } from '../../types' +import type { ErrorHandleMode, Var } from '../../types' import useNodeCrud from '../_base/hooks/use-node-crud' -import { getNodeInfoById, getNodeUsedVarPassToServerKey, getNodeUsedVars, isSystemVar, toNodeOutputVars } from '../_base/components/variable/utils' -import useOneStepRun from '../_base/hooks/use-one-step-run' +import { toNodeOutputVars } from '../_base/components/variable/utils' import { getOperators } from './utils' import { LogicalOperator } from './types' import type { HandleAddCondition, HandleAddSubVariableCondition, HandleRemoveCondition, HandleToggleConditionLogicalOperator, HandleToggleSubVariableConditionLogicalOperator, HandleUpdateCondition, HandleUpdateSubVariableCondition, LoopNodeType } from './types' @@ -47,140 +45,12 @@ const useConfig = (id: string, payload: LoopNodeType) => { const canChooseVarNodes = [...beforeNodes, ...loopChildrenNodes] const childrenNodeVars = toNodeOutputVars(loopChildrenNodes, isChatMode, undefined, [], conversationVariables) - // single run - const loopInputKey = `${id}.input_selector` - const { - isShowSingleRun, - showSingleRun, - hideSingleRun, - toVarInputs, - runningStatus, - handleRun: doHandleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - loopRunResult, - } = useOneStepRun({ - id, - data: inputs, - loopInputKey, - defaultRunInputData: { - [loopInputKey]: [''], - }, - }) - - const [isShowLoopDetail, { - setTrue: doShowLoopDetail, - setFalse: doHideLoopDetail, - }] = useBoolean(false) - - const hideLoopDetail = useCallback(() => { - hideSingleRun() - doHideLoopDetail() - }, [doHideLoopDetail, hideSingleRun]) - - const showLoopDetail = useCallback(() => { - doShowLoopDetail() - }, [doShowLoopDetail]) - - const backToSingleRun = useCallback(() => { - hideLoopDetail() - showSingleRun() - }, [hideLoopDetail, showSingleRun]) - const { getIsVarFileAttribute, } = useIsVarFileAttribute({ nodeId: id, }) - const { usedOutVars, allVarObject } = (() => { - const vars: ValueSelector[] = [] - const varObjs: Record = {} - const allVarObject: Record = {} - loopChildrenNodes.forEach((node) => { - const nodeVars = getNodeUsedVars(node).filter(item => item && item.length > 0) - nodeVars.forEach((varSelector) => { - if (varSelector[0] === id) { // skip Loop node itself variable: item, index - return - } - const isInLoop = isNodeInLoop(varSelector[0]) - if (isInLoop) // not pass loop inner variable - return - - const varSectorStr = varSelector.join('.') - if (!varObjs[varSectorStr]) { - varObjs[varSectorStr] = true - vars.push(varSelector) - } - let passToServerKeys = getNodeUsedVarPassToServerKey(node, varSelector) - if (typeof passToServerKeys === 'string') - passToServerKeys = [passToServerKeys] - - passToServerKeys.forEach((key: string, index: number) => { - allVarObject[[varSectorStr, node.id, index].join(DELIMITER)] = { - inSingleRunPassedKey: key, - } - }) - }) - }) - const res = 
toVarInputs(vars.map((item) => { - const varInfo = getNodeInfoById(canChooseVarNodes, item[0]) - return { - label: { - nodeType: varInfo?.data.type, - nodeName: varInfo?.data.title || canChooseVarNodes[0]?.data.title, // default start node title - variable: isSystemVar(item) ? item.join('.') : item[item.length - 1], - }, - variable: `${item.join('.')}`, - value_selector: item, - } - })) - return { - usedOutVars: res, - allVarObject, - } - })() - - const handleRun = useCallback((data: Record) => { - const formattedData: Record = {} - Object.keys(allVarObject).forEach((key) => { - const [varSectorStr, nodeId] = key.split(DELIMITER) - formattedData[`${nodeId}.${allVarObject[key].inSingleRunPassedKey}`] = data[varSectorStr] - }) - formattedData[loopInputKey] = data[loopInputKey] - doHandleRun(formattedData) - }, [allVarObject, doHandleRun, loopInputKey]) - - const inputVarValues = (() => { - const vars: Record = {} - Object.keys(runInputData) - .filter(key => ![loopInputKey].includes(key)) - .forEach((key) => { - vars[key] = runInputData[key] - }) - return vars - })() - - const setInputVarValues = useCallback((newPayload: Record) => { - const newVars = { - ...newPayload, - [loopInputKey]: runInputData[loopInputKey], - } - setRunInputData(newVars) - }, [loopInputKey, runInputData, setRunInputData]) - - const loop = runInputData[loopInputKey] - const setLoop = useCallback((newLoop: string[]) => { - setRunInputData({ - ...runInputData, - [loopInputKey]: newLoop, - }) - }, [loopInputKey, runInputData, setRunInputData]) - const changeErrorResponseMode = useCallback((item: { value: unknown }) => { const newInputs = produce(inputs, (draft) => { draft.error_handle_mode = item.value as ErrorHandleMode @@ -342,24 +212,6 @@ const useConfig = (id: string, payload: LoopNodeType) => { filterInputVar, childrenNodeVars, loopChildrenNodes, - isShowSingleRun, - showSingleRun, - hideSingleRun, - isShowLoopDetail, - showLoopDetail, - hideLoopDetail, - backToSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - inputVarValues, - setInputVarValues, - usedOutVars, - loop, - setLoop, - loopInputKey, - loopRunResult, handleAddCondition, handleRemoveCondition, handleUpdateCondition, diff --git a/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts b/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts new file mode 100644 index 0000000000..394ab9b16f --- /dev/null +++ b/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts @@ -0,0 +1,221 @@ +import type { NodeTracing } from '@/types/workflow' +import { useCallback, useMemo } from 'react' +import formatTracing from '@/app/components/workflow/run/utils/format-log' +import { useTranslation } from 'react-i18next' +import { useIsNodeInLoop, useWorkflow } from '../../hooks' +import { getNodeInfoById, getNodeUsedVarPassToServerKey, getNodeUsedVars, isSystemVar } from '../_base/components/variable/utils' +import type { InputVar, ValueSelector, Variable } from '../../types' +import type { CaseItem, Condition, LoopNodeType } from './types' +import { ValueType } from '@/app/components/workflow/types' +import { VALUE_SELECTOR_DELIMITER as DELIMITER } from '@/config' + +type Params = { + id: string + payload: LoopNodeType + runInputData: Record + runResult: NodeTracing + loopRunResult: NodeTracing[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] + varSelectorsToVarInputs: (variables: ValueSelector[]) => InputVar[] +} + +const useSingleRunFormParams = ({ + id, + payload, + 
runInputData, + runResult, + loopRunResult, + setRunInputData, + toVarInputs, + varSelectorsToVarInputs, +}: Params) => { + const { t } = useTranslation() + + const { isNodeInLoop } = useIsNodeInLoop(id) + + const { getLoopNodeChildren, getBeforeNodesInSameBranch } = useWorkflow() + const loopChildrenNodes = getLoopNodeChildren(id) + const beforeNodes = getBeforeNodesInSameBranch(id) + const canChooseVarNodes = [...beforeNodes, ...loopChildrenNodes] + + const { usedOutVars, allVarObject } = (() => { + const vars: ValueSelector[] = [] + const varObjs: Record = {} + const allVarObject: Record = {} + loopChildrenNodes.forEach((node) => { + const nodeVars = getNodeUsedVars(node).filter(item => item && item.length > 0) + nodeVars.forEach((varSelector) => { + if (varSelector[0] === id) { // skip loop node itself variable: item, index + return + } + const isInLoop = isNodeInLoop(varSelector[0]) + if (isInLoop) // not pass loop inner variable + return + + const varSectorStr = varSelector.join('.') + if (!varObjs[varSectorStr]) { + varObjs[varSectorStr] = true + vars.push(varSelector) + } + let passToServerKeys = getNodeUsedVarPassToServerKey(node, varSelector) + if (typeof passToServerKeys === 'string') + passToServerKeys = [passToServerKeys] + + passToServerKeys.forEach((key: string, index: number) => { + allVarObject[[varSectorStr, node.id, index].join(DELIMITER)] = { + inSingleRunPassedKey: key, + } + }) + }) + }) + + const res = toVarInputs(vars.map((item) => { + const varInfo = getNodeInfoById(canChooseVarNodes, item[0]) + return { + label: { + nodeType: varInfo?.data.type, + nodeName: varInfo?.data.title || canChooseVarNodes[0]?.data.title, // default start node title + variable: isSystemVar(item) ? item.join('.') : item[item.length - 1], + }, + variable: `${item.join('.')}`, + value_selector: item, + } + })) + return { + usedOutVars: res, + allVarObject, + } + })() + + const nodeInfo = useMemo(() => { + const formattedNodeInfo = formatTracing(loopRunResult, t)[0] + + if (runResult && formattedNodeInfo) { + return { + ...formattedNodeInfo, + execution_metadata: { + ...runResult.execution_metadata, + ...formattedNodeInfo.execution_metadata, + }, + } + } + + return formattedNodeInfo + }, [runResult, loopRunResult, t]) + + const setInputVarValues = useCallback((newPayload: Record) => { + setRunInputData(newPayload) + }, [setRunInputData]) + + const inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const getVarSelectorsFromCase = (caseItem: CaseItem): ValueSelector[] => { + const vars: ValueSelector[] = [] + if (caseItem.conditions && caseItem.conditions.length) { + caseItem.conditions.forEach((condition) => { + // eslint-disable-next-line ts/no-use-before-define + const conditionVars = getVarSelectorsFromCondition(condition) + vars.push(...conditionVars) + }) + } + return vars + } + + const getVarSelectorsFromCondition = (condition: Condition) => { + const vars: ValueSelector[] = [] + if (condition.variable_selector) + vars.push(condition.variable_selector) + + if (condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) + vars.push(...getVarSelectorsFromCase(condition.sub_variable_condition)) + return vars + } + + const forms = (() => { + const allInputs: ValueSelector[] = [] + payload.break_conditions?.forEach((condition) => { + const vars = getVarSelectorsFromCondition(condition) + allInputs.push(...vars) + }) + + 
payload.loop_variables?.forEach((loopVariable) => { + if(loopVariable.value_type === ValueType.variable) + allInputs.push(loopVariable.value) + }) + const inputVarsFromValue: InputVar[] = [] + const varInputs = [...varSelectorsToVarInputs(allInputs), ...inputVarsFromValue] + + const existVarsKey: Record = {} + const uniqueVarInputs: InputVar[] = [] + varInputs.forEach((input) => { + if(!input) + return + if (!existVarsKey[input.variable]) { + existVarsKey[input.variable] = true + uniqueVarInputs.push(input) + } + }) + return [ + { + inputs: [...usedOutVars, ...uniqueVarInputs], + values: inputVarValues, + onChange: setInputVarValues, + }, + ] + })() + + const getVarFromCaseItem = (caseItem: CaseItem): ValueSelector[] => { + const vars: ValueSelector[] = [] + if (caseItem.conditions && caseItem.conditions.length) { + caseItem.conditions.forEach((condition) => { + // eslint-disable-next-line ts/no-use-before-define + const conditionVars = getVarFromCondition(condition) + vars.push(...conditionVars) + }) + } + return vars + } + + const getVarFromCondition = (condition: Condition): ValueSelector[] => { + const vars: ValueSelector[] = [] + if (condition.variable_selector) + vars.push(condition.variable_selector) + + if(condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) + vars.push(...getVarFromCaseItem(condition.sub_variable_condition)) + return vars + } + + const getDependentVars = () => { + const vars: ValueSelector[] = usedOutVars.map(item => item.variable.split('.')) + payload.break_conditions?.forEach((condition) => { + const conditionVars = getVarFromCondition(condition) + vars.push(...conditionVars) + }) + payload.loop_variables?.forEach((loopVariable) => { + if(loopVariable.value_type === ValueType.variable) + vars.push(loopVariable.value) + }) + const hasFilterLoopVars = vars.filter(item => item[0] !== id) + return hasFilterLoopVars + } + + return { + forms, + nodeInfo, + allVarObject, + getDependentVars, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/parameter-extractor/panel.tsx b/web/app/components/workflow/nodes/parameter-extractor/panel.tsx index d03f1d9ff3..e86a2e3764 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/panel.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/panel.tsx @@ -4,9 +4,7 @@ import { useTranslation } from 'react-i18next' import MemoryConfig from '../_base/components/memory-config' import VarReferencePicker from '../_base/components/variable/var-reference-picker' import Editor from '../_base/components/prompt/editor' -import ResultPanel from '../../run/result-panel' import ConfigVision from '../_base/components/config-vision' -import { findVariableWhenOnLLMVision } from '../utils' import useConfig from './use-config' import type { ParameterExtractorNodeType } from './types' import ExtractParameter from './components/extract-parameter/list' @@ -17,12 +15,10 @@ import Field from '@/app/components/workflow/nodes/_base/components/field' import Split from '@/app/components/workflow/nodes/_base/components/split' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' -import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types' +import type { NodePanelProps } from '@/app/components/workflow/types' import Tooltip from '@/app/components/base/tooltip' -import BeforeRunForm 
from '@/app/components/workflow/nodes/_base/components/before-run-form' import { VarType } from '@/app/components/workflow/types' import { FieldCollapse } from '@/app/components/workflow/nodes/_base/components/collapse' -import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' const i18nPrefix = 'workflow.nodes.parameterExtractor' const i18nCommonPrefix = 'workflow.common' @@ -53,63 +49,13 @@ const Panel: FC> = ({ handleReasoningModeChange, availableVars, availableNodesWithParent, - availableVisionVars, - inputVarValues, - varInputs, isVisionModel, handleVisionResolutionChange, handleVisionResolutionEnabledChange, - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - setInputVarValues, - visionFiles, - setVisionFiles, } = useConfig(id, data) const model = inputs.model - const singleRunForms = (() => { - const forms: FormProps[] = [] - - forms.push( - { - label: t('workflow.nodes.llm.singleRun.variable')!, - inputs: [{ - label: t(`${i18nPrefix}.inputVar`)!, - variable: 'query', - type: InputVarType.paragraph, - required: true, - }, ...varInputs], - values: inputVarValues, - onChange: setInputVarValues, - }, - ) - - if (isVisionModel && data.vision?.enabled && data.vision?.configs?.variable_selector) { - const currentVariable = findVariableWhenOnLLMVision(data.vision.configs.variable_selector, availableVisionVars) - - forms.push( - { - label: t('workflow.nodes.llm.vision')!, - inputs: [{ - label: currentVariable?.variable as any, - variable: '#files#', - type: currentVariable?.formType as any, - required: false, - }], - values: { '#files#': visionFiles }, - onChange: keyValue => setVisionFiles((keyValue as any)['#files#']), - }, - ) - } - - return forms - })() - return (
@@ -255,17 +201,6 @@ const Panel: FC<NodePanelProps<ParameterExtractorNodeType>> = ({
)} - {isShowSingleRun && ( - } - /> - )}
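The `use-config.ts` hunks just below wire output-variable edits into the new inspect-vars store, alongside the existing `handleOutVarRenameChange` call: a rename moves the cached inspect value under its new key, while any other structural change drops the node's cached values. A minimal sketch of that branching, assuming a store that exposes the `renameInspectVarName` and `deleteNodeInspectorVars` methods the imported `useInspectVarsCrud` hook appears to provide:

```ts
// Hypothetical store API mirroring what useInspectVarsCrud seems to expose;
// the method names are taken from the hunks below.
interface InspectVarsStore {
  renameInspectVarName(nodeId: string, before: string, after: string): void
  deleteNodeInspectorVars(nodeId: string): void
}

type MoreInfo = {
  type: 'changeVarName' | 'remove' | 'add'
  payload?: { beforeKey: string; afterKey?: string }
}

function onOutputVarEdited(store: InspectVarsStore, nodeId: string, moreInfo?: MoreInfo) {
  if (moreInfo?.type === 'changeVarName' && moreInfo.payload?.afterKey) {
    // A pure rename keeps the cached inspect value, just under the new key.
    store.renameInspectVarName(nodeId, moreInfo.payload.beforeKey, moreInfo.payload.afterKey)
    return
  }
  // Any other structural change invalidates what was cached for this node.
  store.deleteNodeInspectorVars(nodeId)
}

export { onOutputVarEdited }
```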
) } diff --git a/web/app/components/workflow/nodes/parameter-extractor/use-config.ts b/web/app/components/workflow/nodes/parameter-extractor/use-config.ts index 045737b230..3fe42b60cf 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/use-config.ts +++ b/web/app/components/workflow/nodes/parameter-extractor/use-config.ts @@ -8,7 +8,6 @@ import { useNodesReadOnly, useWorkflow, } from '../../hooks' -import useOneStepRun from '../_base/hooks/use-one-step-run' import useConfigVision from '../../hooks/use-config-vision' import type { Param, ParameterExtractorNodeType, ReasoningModeType } from './types' import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -17,8 +16,13 @@ import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-cr import { checkHasQueryBlock } from '@/app/components/base/prompt-editor/constants' import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list' import { supportFunctionCall } from '@/utils/tool-call' +import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud' const useConfig = (id: string, payload: ParameterExtractorNodeType) => { + const { + deleteNodeInspectorVars, + renameInspectVarName, + } = useInspectVarsCrud() const { nodesReadOnly: readOnly } = useNodesReadOnly() const { handleOutVarRenameChange } = useWorkflow() const isChatMode = useIsChatMode() @@ -59,9 +63,14 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { }) setInputs(newInputs) - if (moreInfo && moreInfo?.type === ChangeType.changeVarName && moreInfo.payload) + if (moreInfo && moreInfo?.type === ChangeType.changeVarName && moreInfo.payload) { handleOutVarRenameChange(id, [id, moreInfo.payload.beforeKey], [id, moreInfo.payload.afterKey!]) - }, [handleOutVarRenameChange, id, inputs, setInputs]) + renameInspectVarName(id, moreInfo.payload.beforeKey, moreInfo.payload.afterKey!) 
+ } + else { + deleteNodeInspectorVars(id) + } + }, [deleteNodeInspectorVars, handleOutVarRenameChange, id, inputs, renameInspectVarName, setInputs]) const addExtractParameter = useCallback((payload: Param) => { const newInputs = produce(inputs, (draft) => { @@ -70,7 +79,8 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { draft.parameters.push(payload) }) setInputs(newInputs) - }, [inputs, setInputs]) + deleteNodeInspectorVars(id) + }, [deleteNodeInspectorVars, id, inputs, setInputs]) // model const model = inputs.model || { @@ -145,7 +155,7 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { return setModelChanged(false) handleVisionConfigAfterModelChanged() - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [isVisionModel, modelChanged]) const { @@ -163,10 +173,6 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { return [VarType.number, VarType.string].includes(varPayload.type) }, []) - const filterVisionInputVar = useCallback((varPayload: Var) => { - return [VarType.file, VarType.arrayFile].includes(varPayload.type) - }, []) - const { availableVars, availableNodesWithParent, @@ -175,13 +181,6 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { filterVar: filterInputVar, }) - const { - availableVars: availableVisionVars, - } = useAvailableVarList(id, { - onlyLeafNodeVar: false, - filterVar: filterVisionInputVar, - }) - const handleCompletionParamsChange = useCallback((newParams: Record) => { const newInputs = produce(inputs, (draft) => { draft.model.completion_params = newParams @@ -223,49 +222,6 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) - // single run - const { - isShowSingleRun, - hideSingleRun, - getInputVars, - runningStatus, - handleRun, - handleStop, - runInputData, - runInputDataRef, - setRunInputData, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: { - 'query': '', - '#files#': [], - }, - }) - - const varInputs = getInputVars([inputs.instruction]) - const inputVarValues = (() => { - const vars: Record = {} - Object.keys(runInputData) - .forEach((key) => { - vars[key] = runInputData[key] - }) - return vars - })() - - const setInputVarValues = useCallback((newPayload: Record) => { - setRunInputData(newPayload) - }, [setRunInputData]) - - const visionFiles = runInputData['#files#'] - const setVisionFiles = useCallback((newFiles: any[]) => { - setRunInputData({ - ...runInputDataRef.current, - '#files#': newFiles, - }) - }, [runInputDataRef, setRunInputData]) - return { readOnly, handleInputVarChange, @@ -283,24 +239,12 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { hasSetBlockStatus, availableVars, availableNodesWithParent, - availableVisionVars, isSupportFunctionCall, handleReasoningModeChange, handleMemoryChange, - varInputs, - inputVarValues, isVisionModel, handleVisionResolutionEnabledChange, handleVisionResolutionChange, - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, - setInputVarValues, - visionFiles, - setVisionFiles, } } diff --git a/web/app/components/workflow/nodes/parameter-extractor/use-single-run-form-params.ts b/web/app/components/workflow/nodes/parameter-extractor/use-single-run-form-params.ts new file mode 100644 index 0000000000..178f9e3ed8 --- /dev/null +++ 
b/web/app/components/workflow/nodes/parameter-extractor/use-single-run-form-params.ts @@ -0,0 +1,148 @@ +import type { MutableRefObject } from 'react' +import { useTranslation } from 'react-i18next' +import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' +import type { InputVar, Var, Variable } from '@/app/components/workflow/types' +import { InputVarType, VarType } from '@/app/components/workflow/types' +import type { ParameterExtractorNodeType } from './types' +import useNodeCrud from '../_base/hooks/use-node-crud' +import { useCallback } from 'react' +import useConfigVision from '../../hooks/use-config-vision' +import { noop } from 'lodash-es' +import { findVariableWhenOnLLMVision } from '../utils' +import useAvailableVarList from '../_base/hooks/use-available-var-list' + +const i18nPrefix = 'workflow.nodes.parameterExtractor' + +type Params = { + id: string, + payload: ParameterExtractorNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + runInputDataRef, + getInputVars, + setRunInputData, +}: Params) => { + const { t } = useTranslation() + const { inputs } = useNodeCrud(id, payload) + + const model = inputs.model + + const { + isVisionModel, + } = useConfigVision(model, { + payload: inputs.vision, + onChange: noop, + }) + + const visionFiles = runInputData['#files#'] + const setVisionFiles = useCallback((newFiles: any[]) => { + setRunInputData?.({ + ...runInputDataRef.current, + '#files#': newFiles, + }) + }, [runInputDataRef, setRunInputData]) + + const varInputs = getInputVars([inputs.instruction]) + + const inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .filter(key => !['#context#', '#files#'].includes(key)) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const setInputVarValues = useCallback((newPayload: Record) => { + const newVars = { + ...newPayload, + '#context#': runInputDataRef.current['#context#'], + '#files#': runInputDataRef.current['#files#'], + } + setRunInputData?.(newVars) + }, [runInputDataRef, setRunInputData]) + + const filterVisionInputVar = useCallback((varPayload: Var) => { + return [VarType.file, VarType.arrayFile].includes(varPayload.type) + }, []) + const { + availableVars: availableVisionVars, + } = useAvailableVarList(id, { + onlyLeafNodeVar: false, + filterVar: filterVisionInputVar, + }) + + const forms = (() => { + const forms: FormProps[] = [] + + forms.push( + { + label: t('workflow.nodes.llm.singleRun.variable')!, + inputs: [{ + label: t(`${i18nPrefix}.inputVar`)!, + variable: 'query', + type: InputVarType.paragraph, + required: true, + }, ...varInputs], + values: inputVarValues, + onChange: setInputVarValues, + }, + ) + + if (isVisionModel && payload.vision?.enabled && payload.vision?.configs?.variable_selector) { + const currentVariable = findVariableWhenOnLLMVision(payload.vision.configs.variable_selector, availableVisionVars) + + forms.push( + { + label: t('workflow.nodes.llm.vision')!, + inputs: [{ + label: currentVariable?.variable as any, + variable: '#files#', + type: currentVariable?.formType as any, + required: false, + }], + values: { '#files#': visionFiles }, + onChange: keyValue => setVisionFiles((keyValue as any)['#files#']), + }, + ) + } + + return forms + })() + + const 
getDependentVars = () => { + const promptVars = varInputs.map(item => item.variable.slice(1, -1).split('.')) + const vars = [payload.query, ...promptVars] + if (isVisionModel && payload.vision?.enabled && payload.vision?.configs?.variable_selector) { + const visionVar = payload.vision.configs.variable_selector + vars.push(visionVar) + } + return vars + } + + const getDependentVar = (variable: string) => { + if(variable === 'query') + return payload.query + if(variable === '#files#') + return payload.vision.configs?.variable_selector + + return false + } + + return { + forms, + getDependentVars, + getDependentVar, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/question-classifier/panel.tsx b/web/app/components/workflow/nodes/question-classifier/panel.tsx index d2e0fb060a..8f6f5eb76d 100644 --- a/web/app/components/workflow/nodes/question-classifier/panel.tsx +++ b/web/app/components/workflow/nodes/question-classifier/panel.tsx @@ -3,20 +3,16 @@ import React from 'react' import { useTranslation } from 'react-i18next' import VarReferencePicker from '../_base/components/variable/var-reference-picker' import ConfigVision from '../_base/components/config-vision' -import { findVariableWhenOnLLMVision } from '../utils' import useConfig from './use-config' import ClassList from './components/class-list' import AdvancedSetting from './components/advanced-setting' import type { QuestionClassifierNodeType } from './types' import Field from '@/app/components/workflow/nodes/_base/components/field' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' -import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' -import ResultPanel from '@/app/components/workflow/run/result-panel' +import type { NodePanelProps } from '@/app/components/workflow/types' import Split from '@/app/components/workflow/nodes/_base/components/split' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import { FieldCollapse } from '@/app/components/workflow/nodes/_base/components/collapse' -import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' const i18nPrefix = 'workflow.nodes.questionClassifiers' @@ -38,66 +34,16 @@ const Panel: FC> = ({ hasSetBlockStatus, availableVars, availableNodesWithParent, - availableVisionVars, handleInstructionChange, - inputVarValues, - varInputs, - setInputVarValues, handleMemoryChange, isVisionModel, handleVisionResolutionChange, handleVisionResolutionEnabledChange, - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - runResult, filterVar, - visionFiles, - setVisionFiles, } = useConfig(id, data) const model = inputs.model - const singleRunForms = (() => { - const forms: FormProps[] = [] - - forms.push( - { - label: t('workflow.nodes.llm.singleRun.variable')!, - inputs: [{ - label: t(`${i18nPrefix}.inputVars`)!, - variable: 'query', - type: InputVarType.paragraph, - required: true, - }, ...varInputs], - values: inputVarValues, - onChange: setInputVarValues, - }, - ) - - if (isVisionModel && data.vision?.enabled && data.vision?.configs?.variable_selector) { - const currentVariable = findVariableWhenOnLLMVision(data.vision.configs.variable_selector, availableVisionVars) - - forms.push( - { - label: t('workflow.nodes.llm.vision')!, - inputs: [{ - 
label: currentVariable?.variable as any, - variable: '#files#', - type: currentVariable?.formType as any, - required: false, - }], - values: { '#files#': visionFiles }, - onChange: keyValue => setVisionFiles(keyValue['#files#']), - }, - ) - } - - return forms - })() - return (
@@ -186,17 +132,6 @@ const Panel: FC<NodePanelProps<QuestionClassifierNodeType>> = ({
- {isShowSingleRun && ( - } - /> - )}
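In the question-classifier single-run hook added further down, the `'#files#'` entry is filtered out of the values given to the text form and re-attached on every change, so edits to the other inputs never drop the vision file list. A small sketch of that split-and-merge, with the Dify-specific types replaced by plain records (the helper names here are illustrative only):

```ts
const FILES_KEY = '#files#'

// Values handed to the editable text form: everything except the files entry.
function toFormValues(runInputData: Record<string, any>): Record<string, any> {
  const values: Record<string, any> = {}
  Object.keys(runInputData)
    .filter(key => key !== FILES_KEY)
    .forEach((key) => { values[key] = runInputData[key] })
  return values
}

// When the form reports a change, re-attach the untouched files entry.
function mergeFormValues(newValues: Record<string, any>, current: Record<string, any>) {
  return { ...newValues, [FILES_KEY]: current[FILES_KEY] }
}

export { toFormValues, mergeFormValues }
```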
) } diff --git a/web/app/components/workflow/nodes/question-classifier/use-config.ts b/web/app/components/workflow/nodes/question-classifier/use-config.ts index 7df8293b40..8eacf5b43f 100644 --- a/web/app/components/workflow/nodes/question-classifier/use-config.ts +++ b/web/app/components/workflow/nodes/question-classifier/use-config.ts @@ -11,7 +11,6 @@ import useAvailableVarList from '../_base/hooks/use-available-var-list' import useConfigVision from '../../hooks/use-config-vision' import type { QuestionClassifierNodeType } from './types' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' -import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { checkHasQueryBlock } from '@/app/components/base/prompt-editor/constants' @@ -87,7 +86,7 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { return setModelChanged(false) handleVisionConfigAfterModelChanged() - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [isVisionModel, modelChanged]) const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { @@ -109,7 +108,7 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { query_variable_selector: inputs.query_variable_selector.length > 0 ? inputs.query_variable_selector : query_variable_selector, }) } - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [defaultConfig]) const handleClassesChange = useCallback((newClasses: any) => { @@ -163,59 +162,6 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) - // single run - const { - isShowSingleRun, - hideSingleRun, - getInputVars, - runningStatus, - handleRun, - handleStop, - runInputData, - runInputDataRef, - setRunInputData, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: { - 'query': '', - '#files#': [], - }, - }) - - const query = runInputData.query - const setQuery = useCallback((newQuery: string) => { - setRunInputData({ - ...runInputData, - query: newQuery, - }) - }, [runInputData, setRunInputData]) - - const varInputs = getInputVars([inputs.instruction]) - const inputVarValues = (() => { - const vars: Record = { - query, - } - Object.keys(runInputData) - .forEach((key) => { - vars[key] = runInputData[key] - }) - return vars - })() - - const setInputVarValues = useCallback((newPayload: Record) => { - setRunInputData(newPayload) - }, [setRunInputData]) - - const visionFiles = runInputData['#files#'] - const setVisionFiles = useCallback((newFiles: any[]) => { - setRunInputData({ - ...runInputDataRef.current, - '#files#': newFiles, - }) - }, [runInputDataRef, setRunInputData]) - const filterVar = useCallback((varPayload: Var) => { return varPayload.type === VarType.string }, []) @@ -235,23 +181,10 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { availableNodesWithParent, availableVisionVars, handleInstructionChange, - varInputs, - inputVarValues, - setInputVarValues, handleMemoryChange, isVisionModel, handleVisionResolutionEnabledChange, handleVisionResolutionChange, - isShowSingleRun, - hideSingleRun, - runningStatus, 
- handleRun, - handleStop, - query, - setQuery, - runResult, - visionFiles, - setVisionFiles, } } diff --git a/web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts b/web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts new file mode 100644 index 0000000000..66755abb6e --- /dev/null +++ b/web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts @@ -0,0 +1,146 @@ +import type { MutableRefObject } from 'react' +import { useTranslation } from 'react-i18next' +import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' +import type { InputVar, Var, Variable } from '@/app/components/workflow/types' +import { InputVarType, VarType } from '@/app/components/workflow/types' +import type { QuestionClassifierNodeType } from './types' +import useNodeCrud from '../_base/hooks/use-node-crud' +import { useCallback } from 'react' +import useConfigVision from '../../hooks/use-config-vision' +import { noop } from 'lodash-es' +import { findVariableWhenOnLLMVision } from '../utils' +import useAvailableVarList from '../_base/hooks/use-available-var-list' + +const i18nPrefix = 'workflow.nodes.questionClassifiers' + +type Params = { + id: string, + payload: QuestionClassifierNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + runInputDataRef, + getInputVars, + setRunInputData, +}: Params) => { + const { t } = useTranslation() + const { inputs } = useNodeCrud(id, payload) + + const model = inputs.model + + const { + isVisionModel, + } = useConfigVision(model, { + payload: inputs.vision, + onChange: noop, + }) + + const visionFiles = runInputData['#files#'] + const setVisionFiles = useCallback((newFiles: any[]) => { + setRunInputData?.({ + ...runInputDataRef.current, + '#files#': newFiles, + }) + }, [runInputDataRef, setRunInputData]) + + const varInputs = getInputVars([inputs.instruction]) + + const inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .filter(key => !['#files#'].includes(key)) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const setInputVarValues = useCallback((newPayload: Record) => { + const newVars = { + ...newPayload, + '#files#': runInputDataRef.current['#files#'], + } + setRunInputData?.(newVars) + }, [runInputDataRef, setRunInputData]) + + const filterVisionInputVar = useCallback((varPayload: Var) => { + return [VarType.file, VarType.arrayFile].includes(varPayload.type) + }, []) + const { + availableVars: availableVisionVars, + } = useAvailableVarList(id, { + onlyLeafNodeVar: false, + filterVar: filterVisionInputVar, + }) + + const forms = (() => { + const forms: FormProps[] = [] + + forms.push( + { + label: t('workflow.nodes.llm.singleRun.variable')!, + inputs: [{ + label: t(`${i18nPrefix}.inputVars`)!, + variable: 'query', + type: InputVarType.paragraph, + required: true, + }, ...varInputs], + values: inputVarValues, + onChange: setInputVarValues, + }, + ) + + if (isVisionModel && payload.vision?.enabled && payload.vision?.configs?.variable_selector) { + const currentVariable = findVariableWhenOnLLMVision(payload.vision.configs.variable_selector, availableVisionVars) + + forms.push( + { + label: t('workflow.nodes.llm.vision')!, + inputs: [{ + label: 
currentVariable?.variable as any, + variable: '#files#', + type: currentVariable?.formType as any, + required: false, + }], + values: { '#files#': visionFiles }, + onChange: keyValue => setVisionFiles(keyValue['#files#']), + }, + ) + } + return forms + })() + + const getDependentVars = () => { + const promptVars = varInputs.map(item => item.variable.slice(1, -1).split('.')) + const vars = [payload.query_variable_selector, ...promptVars] + if (isVisionModel && payload.vision?.enabled && payload.vision?.configs?.variable_selector) { + const visionVar = payload.vision.configs.variable_selector + vars.push(visionVar) + } + return vars + } + + const getDependentVar = (variable: string) => { + if(variable === 'query') + return payload.query_variable_selector + if(variable === '#files#') + return payload.vision.configs?.variable_selector + + return false + } + + return { + forms, + getDependentVars, + getDependentVar, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/start/use-config.ts b/web/app/components/workflow/nodes/start/use-config.ts index e30e8c2838..c0ade614e0 100644 --- a/web/app/components/workflow/nodes/start/use-config.ts +++ b/web/app/components/workflow/nodes/start/use-config.ts @@ -10,6 +10,7 @@ import { useNodesReadOnly, useWorkflow, } from '@/app/components/workflow/hooks' +import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud' const useConfig = (id: string, payload: StartNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() @@ -18,6 +19,13 @@ const useConfig = (id: string, payload: StartNodeType) => { const { inputs, setInputs } = useNodeCrud(id, payload) + const { + deleteNodeInspectorVars, + renameInspectVarName, + nodesWithInspectVars, + deleteInspectVar, + } = useInspectVarsCrud() + const [isShowAddVarModal, { setTrue: showAddVarModal, setFalse: hideAddVarModal, @@ -31,6 +39,12 @@ const useConfig = (id: string, payload: StartNodeType) => { const [removedIndex, setRemoveIndex] = useState(0) const handleVarListChange = useCallback((newList: InputVar[], moreInfo?: { index: number; payload: MoreInfo }) => { if (moreInfo?.payload?.type === ChangeType.remove) { + const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { + return varItem.name === moreInfo?.payload?.payload?.beforeKey + })?.id + if(varId) + deleteInspectVar(id, varId) + if (isVarUsedInNodes([id, moreInfo?.payload?.payload?.beforeKey || ''])) { showRemoveVarConfirm() setRemovedVar([id, moreInfo?.payload?.payload?.beforeKey || '']) @@ -46,8 +60,12 @@ const useConfig = (id: string, payload: StartNodeType) => { if (moreInfo?.payload?.type === ChangeType.changeVarName) { const changedVar = newList[moreInfo.index] handleOutVarRenameChange(id, [id, inputs.variables[moreInfo.index].variable], [id, changedVar.variable]) + renameInspectVarName(id, inputs.variables[moreInfo.index].variable, changedVar.variable) } - }, [handleOutVarRenameChange, id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm]) + else if(moreInfo?.payload?.type !== ChangeType.remove) { // edit var type + deleteNodeInspectorVars(id) + } + }, [deleteInspectVar, deleteNodeInspectorVars, handleOutVarRenameChange, id, inputs, isVarUsedInNodes, nodesWithInspectVars, renameInspectVarName, setInputs, showRemoveVarConfirm]) const removeVarInNode = useCallback(() => { const newInputs = produce(inputs, (draft) => { diff --git a/web/app/components/workflow/nodes/start/use-single-run-form-params.ts 
b/web/app/components/workflow/nodes/start/use-single-run-form-params.ts new file mode 100644 index 0000000000..38abbf2a63 --- /dev/null +++ b/web/app/components/workflow/nodes/start/use-single-run-form-params.ts @@ -0,0 +1,87 @@ +import type { MutableRefObject } from 'react' +import { useTranslation } from 'react-i18next' +import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' +import type { ValueSelector } from '@/app/components/workflow/types' +import { type InputVar, InputVarType, type Variable } from '@/app/components/workflow/types' +import type { StartNodeType } from './types' +import { useIsChatMode } from '../../hooks' + +type Params = { + id: string, + payload: StartNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + setRunInputData, +}: Params) => { + const { t } = useTranslation() + const isChatMode = useIsChatMode() + + const forms = (() => { + const forms: FormProps[] = [] + const inputs: InputVar[] = payload.variables.map((item) => { + return { + ...item, + getVarValueFromDependent: true, + } + }) + + if (isChatMode) { + inputs.push({ + label: 'sys.query', + variable: '#sys.query#', + type: InputVarType.textInput, + required: true, + }) + } + + inputs.push({ + label: 'sys.files', + variable: '#sys.files#', + type: InputVarType.multiFiles, + required: false, + }) + + forms.push( + { + label: t('workflow.nodes.llm.singleRun.variable')!, + inputs, + values: runInputData, + onChange: setRunInputData, + }, + ) + + return forms + })() + + const getDependentVars = () => { + const inputVars = payload.variables.map((item) => { + return [id, item.variable] + }) + const vars: ValueSelector[] = [...inputVars, ['sys', 'files']] + + if (isChatMode) + vars.push(['sys', 'query']) + + return vars + } + + const getDependentVar = (variable: string) => { + return [id, variable] + } + + return { + forms, + getDependentVars, + getDependentVar, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/template-transform/panel.tsx b/web/app/components/workflow/nodes/template-transform/panel.tsx index e120482925..29c34ee663 100644 --- a/web/app/components/workflow/nodes/template-transform/panel.tsx +++ b/web/app/components/workflow/nodes/template-transform/panel.tsx @@ -14,8 +14,6 @@ import Split from '@/app/components/workflow/nodes/_base/components/split' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import type { NodePanelProps } from '@/app/components/workflow/types' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' -import ResultPanel from '@/app/components/workflow/run/result-panel' const i18nPrefix = 'workflow.nodes.templateTransform' @@ -35,16 +33,6 @@ const Panel: FC> = ({ handleAddEmptyVariable, handleCodeChange, filterVar, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - varInputs, - inputVarValues, - setInputVarValues, - runResult, } = useConfig(id, data) return ( @@ -106,23 +94,6 @@ const Panel: FC> = ({
- {isShowSingleRun && ( - } - /> - )}
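The template-transform hook introduced below resolves a single-run input back to the selector it mirrors by looking the variable up in the node's variable list, and reports all selectors as dependencies. Roughly, with a trimmed-down variable type standing in for the real one:

```ts
type ValueSelector = string[]
type NodeVariable = { variable: string; value_selector: ValueSelector }

// Every selector the single run depends on.
function getDependentVars(variables: NodeVariable[]): ValueSelector[] {
  return variables.map(v => v.value_selector)
}

// Map a single form field name back to the upstream selector it mirrors.
function getDependentVar(variables: NodeVariable[], name: string): ValueSelector | undefined {
  return variables.find(v => v.variable === name)?.value_selector
}

export { getDependentVars, getDependentVar }
```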
) } diff --git a/web/app/components/workflow/nodes/template-transform/use-config.ts b/web/app/components/workflow/nodes/template-transform/use-config.ts index e0c41ac2d6..8be93abdf8 100644 --- a/web/app/components/workflow/nodes/template-transform/use-config.ts +++ b/web/app/components/workflow/nodes/template-transform/use-config.ts @@ -6,7 +6,6 @@ import { VarType } from '../../types' import { useStore } from '../../store' import type { TemplateTransformNodeType } from './types' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' -import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' import { useNodesReadOnly, } from '@/app/components/workflow/hooks' @@ -66,7 +65,7 @@ const useConfig = (id: string, payload: TemplateTransformNodeType) => { ...defaultConfig, }) } - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [defaultConfig]) const handleCodeChange = useCallback((template: string) => { @@ -76,37 +75,6 @@ const useConfig = (id: string, payload: TemplateTransformNodeType) => { setInputs(newInputs) }, [setInputs]) - // single run - const { - isShowSingleRun, - hideSingleRun, - toVarInputs, - runningStatus, - handleRun, - handleStop, - runInputData, - setRunInputData, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: {}, - }) - const varInputs = toVarInputs(inputs.variables) - - const inputVarValues = (() => { - const vars: Record = {} - Object.keys(runInputData) - .forEach((key) => { - vars[key] = runInputData[key] - }) - return vars - })() - - const setInputVarValues = useCallback((newPayload: Record) => { - setRunInputData(newPayload) - }, [setRunInputData]) - const filterVar = useCallback((varPayload: Var) => { return [VarType.string, VarType.number, VarType.object, VarType.array, VarType.arrayNumber, VarType.arrayString, VarType.arrayObject].includes(varPayload.type) }, []) @@ -121,16 +89,6 @@ const useConfig = (id: string, payload: TemplateTransformNodeType) => { handleAddEmptyVariable, handleCodeChange, filterVar, - // single run - isShowSingleRun, - hideSingleRun, - runningStatus, - handleRun, - handleStop, - varInputs, - inputVarValues, - setInputVarValues, - runResult, } } diff --git a/web/app/components/workflow/nodes/template-transform/use-single-run-form-params.ts b/web/app/components/workflow/nodes/template-transform/use-single-run-form-params.ts new file mode 100644 index 0000000000..ab1cfe731d --- /dev/null +++ b/web/app/components/workflow/nodes/template-transform/use-single-run-form-params.ts @@ -0,0 +1,65 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, Variable } from '@/app/components/workflow/types' +import { useCallback, useMemo } from 'react' +import useNodeCrud from '../_base/hooks/use-node-crud' +import type { TemplateTransformNodeType } from './types' + +type Params = { + id: string, + payload: TemplateTransformNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] +} +const useSingleRunFormParams = ({ + id, + payload, + runInputData, + toVarInputs, + setRunInputData, +}: Params) => { + const { inputs } = useNodeCrud(id, payload) + + const varInputs = toVarInputs(inputs.variables) + const setInputVarValues = useCallback((newPayload: Record) => { + setRunInputData(newPayload) + }, [setRunInputData]) + const 
inputVarValues = (() => { + const vars: Record = {} + Object.keys(runInputData) + .forEach((key) => { + vars[key] = runInputData[key] + }) + return vars + })() + + const forms = useMemo(() => { + return [ + { + inputs: varInputs, + values: inputVarValues, + onChange: setInputVarValues, + }, + ] + }, [inputVarValues, setInputVarValues, varInputs]) + + const getDependentVars = () => { + return payload.variables.map(v => v.value_selector) + } + + const getDependentVar = (variable: string) => { + const varItem = payload.variables.find(v => v.variable === variable) + if (varItem) + return varItem.value_selector + } + + return { + forms, + getDependentVars, + getDependentVar, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx index 1a609c58f5..244e54a5f2 100644 --- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx +++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next' import type { ToolVarInputs } from '../types' import { VarType as VarKindType } from '../types' import cn from '@/utils/classnames' -import type { ValueSelector, Var } from '@/app/components/workflow/types' +import type { ToolWithProvider, ValueSelector, Var } from '@/app/components/workflow/types' import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations' import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -17,6 +17,7 @@ import { VarType } from '@/app/components/workflow/types' import AppSelector from '@/app/components/plugins/plugin-detail-panel/app-selector' import ModelParameterModal from '@/app/components/plugins/plugin-detail-panel/model-selector' import { noop } from 'lodash-es' +import type { Tool } from '@/app/components/tools/types' type Props = { readOnly: boolean @@ -27,6 +28,8 @@ type Props = { onOpen?: (index: number) => void isSupportConstantValue?: boolean filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean + currentTool?: Tool + currentProvider?: ToolWithProvider } const InputVarList: FC = ({ @@ -38,6 +41,8 @@ const InputVarList: FC = ({ onOpen = noop, isSupportConstantValue, filterVar, + currentTool, + currentProvider, }) => { const language = useLanguage() const { t } = useTranslation() @@ -58,6 +63,8 @@ const InputVarList: FC = ({ return 'ModelSelector' else if (type === FormTypeEnum.toolSelector) return 'ToolSelector' + else if (type === FormTypeEnum.dynamicSelect || type === FormTypeEnum.select) + return 'Select' else return 'String' } @@ -149,6 +156,7 @@ const InputVarList: FC = ({ const handleOpen = useCallback((index: number) => { return () => onOpen(index) }, [onOpen]) + return (
{ @@ -163,7 +171,8 @@ const InputVarList: FC = ({ } = schema const varInput = value[variable] const isNumber = type === FormTypeEnum.textNumber - const isSelect = type === FormTypeEnum.select + const isDynamicSelect = type === FormTypeEnum.dynamicSelect + const isSelect = type === FormTypeEnum.select || type === FormTypeEnum.dynamicSelect const isFile = type === FormTypeEnum.file || type === FormTypeEnum.files const isAppSelector = type === FormTypeEnum.appSelector const isModelSelector = type === FormTypeEnum.modelSelector @@ -198,11 +207,13 @@ const InputVarList: FC = ({ value={varInput?.type === VarKindType.constant ? (varInput?.value ?? '') : (varInput?.value ?? [])} onChange={handleNotMixedTypeChange(variable)} onOpen={handleOpen(index)} - defaultVarKindType={varInput?.type || (isNumber ? VarKindType.constant : VarKindType.variable)} + defaultVarKindType={varInput?.type || ((isNumber || isDynamicSelect) ? VarKindType.constant : VarKindType.variable)} isSupportConstantValue={isSupportConstantValue} filterVar={isNumber ? filterVar : undefined} availableVars={isSelect ? availableVars : undefined} schema={schema} + currentTool={currentTool} + currentProvider={currentProvider} /> )} {isFile && ( diff --git a/web/app/components/workflow/nodes/tool/panel.tsx b/web/app/components/workflow/nodes/tool/panel.tsx index 393a11c1e8..038159870e 100644 --- a/web/app/components/workflow/nodes/tool/panel.tsx +++ b/web/app/components/workflow/nodes/tool/panel.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import React, { useMemo } from 'react' +import React from 'react' import { useTranslation } from 'react-i18next' import Split from '../_base/components/split' import type { ToolNodeType } from './types' @@ -11,12 +11,7 @@ import type { NodePanelProps } from '@/app/components/workflow/types' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials' import Loading from '@/app/components/base/loading' -import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' -import ResultPanel from '@/app/components/workflow/run/result-panel' -import { useToolIcon } from '@/app/components/workflow/hooks' -import { useLogs } from '@/app/components/workflow/run/hooks' -import formatToTracingNodeList from '@/app/components/workflow/run/utils/format-log' import StructureOutputItem from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show' import { Type } from '../llm/types' @@ -45,23 +40,10 @@ const Panel: FC> = ({ hideSetAuthModal, handleSaveAuth, isLoading, - isShowSingleRun, - hideSingleRun, - singleRunForms, - runningStatus, - handleRun, - handleStop, - runResult, outputSchema, hasObjectOutput, + currTool, } = useConfig(id, data) - const toolIcon = useToolIcon(data) - const logsParams = useLogs() - const nodeInfo = useMemo(() => { - if (!runResult) - return null - return formatToTracingNodeList([runResult], t)[0] - }, [runResult, t]) if (isLoading) { return
@@ -99,6 +81,8 @@ const Panel: FC> = ({ filterVar={filterVar} isSupportConstantValue onOpen={handleOnVarOpen} + currentProvider={currCollection} + currentTool={currTool} /> )} @@ -180,21 +164,6 @@ const Panel: FC> = ({
- - {isShowSingleRun && ( - } - /> - )}
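The tool hook further below rebuilds its single-run inputs from the non-constant `tool_parameters`, rendering variable selectors as `{{#node.var#}}` references so the generic `getInputVars` parser can pick them up; malformed legacy values (a plain string where a selector array is expected) are wrapped as-is rather than crashing the page. A hedged sketch of that conversion, with invented local type and function names standing in for the real `ToolVarInputs` machinery:

```ts
type ValueSelector = string[]
type ToolVarInput =
  | { type: 'constant'; value: string | number }
  | { type: 'variable'; value: ValueSelector | string } // string covers old, malformed data
  | { type: 'mixed'; value: string }

// Render every non-constant parameter as the template text the generic
// single-run input parser expects, guarding against legacy string values.
function toTemplateTexts(params: Record<string, ToolVarInput>): string[] {
  return Object.values(params)
    .filter(p => p.type !== 'constant')
    .map((p) => {
      if (p.type === 'variable') {
        if (!Array.isArray(p.value))
          return `{{#${p.value}#}}` // old wrong value: don't crash the page
        return `{{#${p.value.join('.')}#}}`
      }
      return p.value as string
    })
}

export { toTemplateTexts }
```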
) } diff --git a/web/app/components/workflow/nodes/tool/use-config.ts b/web/app/components/workflow/nodes/tool/use-config.ts index 38ca5b5195..b83ae8a07f 100644 --- a/web/app/components/workflow/nodes/tool/use-config.ts +++ b/web/app/components/workflow/nodes/tool/use-config.ts @@ -3,17 +3,15 @@ import { useTranslation } from 'react-i18next' import produce from 'immer' import { useBoolean } from 'ahooks' import { useStore } from '../../store' -import { type ToolNodeType, type ToolVarInputs, VarType } from './types' +import type { ToolNodeType, ToolVarInputs } from './types' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import { CollectionType } from '@/app/components/tools/types' import { updateBuiltInToolCredential } from '@/service/tools' import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import Toast from '@/app/components/base/toast' -import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' import { VarType as VarVarType } from '@/app/components/workflow/types' -import type { InputVar, ValueSelector, Var } from '@/app/components/workflow/types' -import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' +import type { InputVar, Var } from '@/app/components/workflow/types' import { useFetchToolsData, useNodesReadOnly, @@ -160,39 +158,8 @@ const useConfig = (id: string, payload: ToolNodeType) => { const isLoading = currTool && (isBuiltIn ? !currCollection : false) - // single run - const [inputVarValues, doSetInputVarValues] = useState>({}) - const setInputVarValues = (value: Record) => { - doSetInputVarValues(value) - // eslint-disable-next-line ts/no-use-before-define - setRunInputData(value) - } - // fill single run form variable with constant value first time - const inputVarValuesWithConstantValue = () => { - const res = produce(inputVarValues, (draft) => { - Object.keys(inputs.tool_parameters).forEach((key: string) => { - const { type, value } = inputs.tool_parameters[key] - if (type === VarType.constant && (value === undefined || value === null)) - draft.tool_parameters[key].value = value - }) - }) - return res - } - - const { - isShowSingleRun, - hideSingleRun, - getInputVars, - runningStatus, - setRunInputData, - handleRun: doHandleRun, - handleStop, - runResult, - } = useOneStepRun({ - id, - data: inputs, - defaultRunInputData: {}, - moreDataForCheckValid: { + const getMoreDataForCheckValid = () => { + return { toolInputsSchema: (() => { const formInputs: InputVar[] = [] toolInputVarSchema.forEach((item: any) => { @@ -208,52 +175,7 @@ const useConfig = (id: string, payload: ToolNodeType) => { notAuthed: isShowAuthBtn, toolSettingSchema, language, - }, - }) - - const hadVarParams = Object.keys(inputs.tool_parameters) - .filter(key => inputs.tool_parameters[key].type !== VarType.constant) - .map(k => inputs.tool_parameters[k]) - - const varInputs = getInputVars(hadVarParams.map((p) => { - if (p.type === VarType.variable) { - // handle the old wrong value not crash the page - if (!(p.value as any).join) - return `{{#${p.value}#}}` - - return `{{#${(p.value as ValueSelector).join('.')}#}}` } - - return p.value as string - })) - - const singleRunForms = (() => { - const forms: FormProps[] = [{ - inputs: varInputs, - values: inputVarValuesWithConstantValue(), - onChange: setInputVarValues, - }] - return forms - 
})() - - const handleRun = (submitData: Record) => { - const varTypeInputKeys = Object.keys(inputs.tool_parameters) - .filter(key => inputs.tool_parameters[key].type === VarType.variable) - const shouldAdd = varTypeInputKeys.length > 0 - if (!shouldAdd) { - doHandleRun(submitData) - return - } - const addMissedVarData = { ...submitData } - Object.keys(submitData).forEach((key) => { - const value = submitData[key] - varTypeInputKeys.forEach((inputKey) => { - const inputValue = inputs.tool_parameters[inputKey].value as ValueSelector - if (`#${inputValue.join('.')}#` === key) - addMissedVarData[inputKey] = value - }) - }) - doHandleRun(addMissedVarData) } const outputSchema = useMemo(() => { @@ -307,18 +229,9 @@ const useConfig = (id: string, payload: ToolNodeType) => { hideSetAuthModal, handleSaveAuth, isLoading, - isShowSingleRun, - hideSingleRun, - inputVarValues, - varInputs, - setInputVarValues, - singleRunForms, - runningStatus, - handleRun, - handleStop, - runResult, outputSchema, hasObjectOutput, + getMoreDataForCheckValid, } } diff --git a/web/app/components/workflow/nodes/tool/use-get-data-for-check-more.ts b/web/app/components/workflow/nodes/tool/use-get-data-for-check-more.ts new file mode 100644 index 0000000000..a68f12fc37 --- /dev/null +++ b/web/app/components/workflow/nodes/tool/use-get-data-for-check-more.ts @@ -0,0 +1,20 @@ +import type { ToolNodeType } from './types' +import useConfig from './use-config' + +type Params = { + id: string + payload: ToolNodeType, +} + +const useGetDataForCheckMore = ({ + id, + payload, +}: Params) => { + const { getMoreDataForCheckValid } = useConfig(id, payload) + + return { + getData: getMoreDataForCheckValid, + } +} + +export default useGetDataForCheckMore diff --git a/web/app/components/workflow/nodes/tool/use-single-run-form-params.ts b/web/app/components/workflow/nodes/tool/use-single-run-form-params.ts new file mode 100644 index 0000000000..295cf02639 --- /dev/null +++ b/web/app/components/workflow/nodes/tool/use-single-run-form-params.ts @@ -0,0 +1,94 @@ +import type { MutableRefObject } from 'react' +import type { InputVar, Variable } from '@/app/components/workflow/types' +import { useCallback, useMemo, useState } from 'react' +import useNodeCrud from '../_base/hooks/use-node-crud' +import { type ToolNodeType, VarType } from './types' +import type { ValueSelector } from '@/app/components/workflow/types' +import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' +import produce from 'immer' +import type { NodeTracing } from '@/types/workflow' +import { useTranslation } from 'react-i18next' +import formatToTracingNodeList from '@/app/components/workflow/run/utils/format-log' +import { useToolIcon } from '../../hooks' + +type Params = { + id: string, + payload: ToolNodeType, + runInputData: Record + runInputDataRef: MutableRefObject> + getInputVars: (textList: string[]) => InputVar[] + setRunInputData: (data: Record) => void + toVarInputs: (variables: Variable[]) => InputVar[] + runResult: NodeTracing +} +const useSingleRunFormParams = ({ + id, + payload, + getInputVars, + setRunInputData, + runResult, +}: Params) => { + const { t } = useTranslation() + const { inputs } = useNodeCrud(id, payload) + + const hadVarParams = Object.keys(inputs.tool_parameters) + .filter(key => inputs.tool_parameters[key].type !== VarType.constant) + .map(k => inputs.tool_parameters[k]) + const varInputs = getInputVars(hadVarParams.map((p) => { + if (p.type === VarType.variable) { + // handle the old 
wrong value not crash the page + if (!(p.value as any).join) + return `{{#${p.value}#}}` + + return `{{#${(p.value as ValueSelector).join('.')}#}}` + } + + return p.value as string + })) + const [inputVarValues, doSetInputVarValues] = useState>({}) + const setInputVarValues = useCallback((value: Record) => { + doSetInputVarValues(value) + setRunInputData(value) + }, [setRunInputData]) + + const inputVarValuesWithConstantValue = useCallback(() => { + const res = produce(inputVarValues, (draft) => { + Object.keys(inputs.tool_parameters).forEach((key: string) => { + const { type, value } = inputs.tool_parameters[key] + if (type === VarType.constant && (value === undefined || value === null)) + draft[key] = value + }) + }) + return res + }, [inputs.tool_parameters, inputVarValues]) + + const forms = useMemo(() => { + const forms: FormProps[] = [{ + inputs: varInputs, + values: inputVarValuesWithConstantValue(), + onChange: setInputVarValues, + }] + return forms + }, [inputVarValuesWithConstantValue, setInputVarValues, varInputs]) + + const nodeInfo = useMemo(() => { + if (!runResult) + return null + return formatToTracingNodeList([runResult], t)[0] + }, [runResult, t]) + + const toolIcon = useToolIcon(payload) + + const getDependentVars = () => { + return varInputs.map(item => item.variable.slice(1, -1).split('.')) + } + + return { + forms, + nodeInfo, + toolIcon, + getDependentVars, + } +} + +export default useSingleRunFormParams diff --git a/web/app/components/workflow/nodes/variable-assigner/use-config.ts b/web/app/components/workflow/nodes/variable-assigner/use-config.ts index f5a7a092b3..c65941e32d 100644 --- a/web/app/components/workflow/nodes/variable-assigner/use-config.ts +++ b/web/app/components/workflow/nodes/variable-assigner/use-config.ts @@ -1,6 +1,6 @@ -import { useCallback, useState } from 'react' +import { useCallback, useRef, useState } from 'react' import produce from 'immer' -import { useBoolean } from 'ahooks' +import { useBoolean, useDebounceFn } from 'ahooks' import { v4 as uuid4 } from 'uuid' import type { ValueSelector, Var } from '../../types' import { VarType } from '../../types' @@ -12,8 +12,13 @@ import { useNodesReadOnly, useWorkflow, } from '@/app/components/workflow/hooks' +import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud' const useConfig = (id: string, payload: VariableAssignerNodeType) => { + const { + deleteNodeInspectorVars, + renameInspectVarName, + } = useInspectVarsCrud() const { nodesReadOnly: readOnly } = useNodesReadOnly() const { handleOutVarRenameChange, isVarUsedInNodes, removeUsedVarInNodes } = useWorkflow() @@ -113,7 +118,8 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { draft.advanced_settings.group_enabled = enabled }) setInputs(newInputs) - }, [handleOutVarRenameChange, id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm]) + deleteNodeInspectorVars(id) + }, [deleteNodeInspectorVars, handleOutVarRenameChange, id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm]) const handleAddGroup = useCallback(() => { let maxInGroupName = 1 @@ -134,7 +140,22 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { }) }) setInputs(newInputs) - }, [inputs, setInputs]) + deleteNodeInspectorVars(id) + }, [deleteNodeInspectorVars, id, inputs, setInputs]) + + // record the first old name value + const oldNameRecord = useRef>({}) + + const { + run: renameInspectNameWithDebounce, + } = useDebounceFn( + (id: string, newName: string) => { + const oldName = oldNameRecord.current[id] + 
diff --git a/web/app/components/workflow/nodes/variable-assigner/use-config.ts b/web/app/components/workflow/nodes/variable-assigner/use-config.ts
index f5a7a092b3..c65941e32d 100644
--- a/web/app/components/workflow/nodes/variable-assigner/use-config.ts
+++ b/web/app/components/workflow/nodes/variable-assigner/use-config.ts
@@ -1,6 +1,6 @@
-import { useCallback, useState } from 'react'
+import { useCallback, useRef, useState } from 'react'
 import produce from 'immer'
-import { useBoolean } from 'ahooks'
+import { useBoolean, useDebounceFn } from 'ahooks'
 import { v4 as uuid4 } from 'uuid'
 import type { ValueSelector, Var } from '../../types'
 import { VarType } from '../../types'
@@ -12,8 +12,13 @@ import {
   useNodesReadOnly,
   useWorkflow,
 } from '@/app/components/workflow/hooks'
+import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud'

 const useConfig = (id: string, payload: VariableAssignerNodeType) => {
+  const {
+    deleteNodeInspectorVars,
+    renameInspectVarName,
+  } = useInspectVarsCrud()
   const { nodesReadOnly: readOnly } = useNodesReadOnly()
   const { handleOutVarRenameChange, isVarUsedInNodes, removeUsedVarInNodes } = useWorkflow()
@@ -113,7 +118,8 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
       draft.advanced_settings.group_enabled = enabled
     })
     setInputs(newInputs)
-  }, [handleOutVarRenameChange, id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm])
+    deleteNodeInspectorVars(id)
+  }, [deleteNodeInspectorVars, handleOutVarRenameChange, id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm])

   const handleAddGroup = useCallback(() => {
     let maxInGroupName = 1
@@ -134,7 +140,22 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
       })
     })
     setInputs(newInputs)
-  }, [inputs, setInputs])
+    deleteNodeInspectorVars(id)
+  }, [deleteNodeInspectorVars, id, inputs, setInputs])
+
+  // record the original name before the first rename
+  const oldNameRecord = useRef<Record<string, string>>({})
+
+  const {
+    run: renameInspectNameWithDebounce,
+  } = useDebounceFn(
+    (id: string, newName: string) => {
+      const oldName = oldNameRecord.current[id]
+      renameInspectVarName(id, oldName, newName)
+      delete oldNameRecord.current[id]
+    },
+    { wait: 500 },
+  )

   const handleVarGroupNameChange = useCallback((groupId: string) => {
     return (name: string) => {
@@ -144,8 +165,11 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
       })
       handleOutVarRenameChange(id, [id, inputs.advanced_settings.groups[index].group_name, 'output'], [id, name, 'output'])
       setInputs(newInputs)
+      if (!(id in oldNameRecord.current))
+        oldNameRecord.current[id] = inputs.advanced_settings.groups[index].group_name
+      renameInspectNameWithDebounce(id, name)
     }
-  }, [handleOutVarRenameChange, id, inputs, setInputs])
+  }, [handleOutVarRenameChange, id, inputs, renameInspectNameWithDebounce, setInputs])

   const onRemoveVarConfirm = useCallback(() => {
     removedVars.forEach((v) => {
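The rename wiring above follows a small, reusable pattern: capture the name a group had before the user starts typing, debounce the actual rename, and clear the record once it has been applied. A self-contained sketch of that pattern, assuming only the ahooks useDebounceFn API shown in the diff; applyRename stands in for renameInspectVarName and is illustrative:

import { useRef } from 'react'
import { useDebounceFn } from 'ahooks'

const useDebouncedRename = (applyRename: (id: string, oldName: string, newName: string) => void) => {
  // first pre-rename name per id, captured once per editing burst
  const oldNameRecord = useRef<Record<string, string>>({})

  const { run: renameWithDebounce } = useDebounceFn(
    (id: string, newName: string) => {
      const oldName = oldNameRecord.current[id]
      applyRename(id, oldName, newName)
      delete oldNameRecord.current[id]
    },
    { wait: 500 },
  )

  // call on every keystroke; the expensive rename only fires after typing settles
  return (id: string, currentName: string, newName: string) => {
    if (!(id in oldNameRecord.current))
      oldNameRecord.current[id] = currentName
    renameWithDebounce(id, newName)
  }
}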
diff --git a/web/app/components/workflow/nodes/variable-assigner/use-single-run-form-params.ts b/web/app/components/workflow/nodes/variable-assigner/use-single-run-form-params.ts
new file mode 100644
index 0000000000..0d6d737c21
--- /dev/null
+++ b/web/app/components/workflow/nodes/variable-assigner/use-single-run-form-params.ts
@@ -0,0 +1,92 @@
+import type { MutableRefObject } from 'react'
+import type { InputVar, ValueSelector, Variable } from '@/app/components/workflow/types'
+import { useCallback } from 'react'
+import type { VariableAssignerNodeType } from './types'
+
+type Params = {
+  id: string,
+  payload: VariableAssignerNodeType,
+  runInputData: Record<string, any>
+  runInputDataRef: MutableRefObject<Record<string, any>>
+  getInputVars: (textList: string[]) => InputVar[]
+  setRunInputData: (data: Record<string, any>) => void
+  toVarInputs: (variables: Variable[]) => InputVar[]
+  varSelectorsToVarInputs: (variables: ValueSelector[]) => InputVar[]
+}
+const useSingleRunFormParams = ({
+  payload,
+  runInputData,
+  setRunInputData,
+  varSelectorsToVarInputs,
+}: Params) => {
+  const setInputVarValues = useCallback((newPayload: Record<string, any>) => {
+    setRunInputData(newPayload)
+  }, [setRunInputData])
+  const inputVarValues = (() => {
+    const vars: Record<string, any> = {}
+    Object.keys(runInputData)
+      .forEach((key) => {
+        vars[key] = runInputData[key]
+      })
+    return vars
+  })()
+
+  const forms = (() => {
+    const allInputs: ValueSelector[] = []
+    const isGroupEnabled = !!payload.advanced_settings?.group_enabled
+    if (!isGroupEnabled && payload.variables && payload.variables.length) {
+      payload.variables.forEach((varSelector) => {
+        allInputs.push(varSelector)
+      })
+    }
+    if (isGroupEnabled && payload.advanced_settings && payload.advanced_settings.groups && payload.advanced_settings.groups.length) {
+      payload.advanced_settings.groups.forEach((group) => {
+        group.variables?.forEach((varSelector) => {
+          allInputs.push(varSelector)
+        })
+      })
+    }
+
+    const varInputs = varSelectorsToVarInputs(allInputs)
+    // remove duplicate inputs
+    const existVarsKey: Record<string, boolean> = {}
+    const uniqueVarInputs: InputVar[] = []
+    varInputs.forEach((input) => {
+      if (!input)
+        return
+      if (!existVarsKey[input.variable]) {
+        existVarsKey[input.variable] = true
+        uniqueVarInputs.push({
+          ...input,
+          required: false, // just one of the inputs is required
+        })
+      }
+    })
+    return [
+      {
+        inputs: uniqueVarInputs,
+        values: inputVarValues,
+        onChange: setInputVarValues,
+      },
+    ]
+  })()
+
+  const getDependentVars = () => {
+    if (payload.advanced_settings?.group_enabled) {
+      const vars: ValueSelector[][] = []
+      payload.advanced_settings.groups.forEach((group) => {
+        if (group.variables)
+          vars.push([...group.variables])
+      })
+      return vars
+    }
+    return [payload.variables]
+  }
+
+  return {
+    forms,
+    getDependentVars,
+  }
+}
+
+export default useSingleRunFormParams
diff --git a/web/app/components/workflow/operator/add-block.tsx b/web/app/components/workflow/operator/add-block.tsx
index d35a5be8b4..5bc541a45a 100644
--- a/web/app/components/workflow/operator/add-block.tsx
+++ b/web/app/components/workflow/operator/add-block.tsx
@@ -96,7 +96,7 @@ const AddBlock = ({
       onOpenChange={handleOpenChange}
       disabled={nodesReadOnly}
       onSelect={handleSelect}
-      placement='top-start'
+      placement='right-start'
       offset={offset ?? {
         mainAxis: 4,
         crossAxis: -8,
diff --git a/web/app/components/workflow/operator/control.tsx b/web/app/components/workflow/operator/control.tsx
index 5f7d19a17f..7967bf0a6c 100644
--- a/web/app/components/workflow/operator/control.tsx
+++ b/web/app/components/workflow/operator/control.tsx
@@ -4,6 +4,8 @@ import {
 } from 'react'
 import { useTranslation } from 'react-i18next'
 import {
+  RiAspectRatioFill,
+  RiAspectRatioLine,
   RiCursorLine,
   RiFunctionAddLine,
   RiHand,
@@ -11,6 +13,7 @@ import {
 } from '@remixicon/react'
 import {
   useNodesReadOnly,
+  useWorkflowCanvasMaximize,
   useWorkflowMoveMode,
   useWorkflowOrganize,
 } from '../hooks'
@@ -28,6 +31,7 @@ import cn from '@/utils/classnames'
 const Control = () => {
   const { t } = useTranslation()
   const controlMode = useStore(s => s.controlMode)
+  const maximizeCanvas = useStore(s => s.maximizeCanvas)
   const { handleModePointer, handleModeHand } = useWorkflowMoveMode()
   const { handleLayout } = useWorkflowOrganize()
   const { handleAddNote } = useOperator()
@@ -35,6 +39,7 @@ const Control = () => {
     nodesReadOnly,
     getNodesReadOnly,
   } = useNodesReadOnly()
+  const { handleToggleMaximizeCanvas } = useWorkflowCanvasMaximize()

   const addNote = (e: MouseEvent) => {
     if (getNodesReadOnly())
@@ -45,7 +50,7 @@ const Control = () => {
   }

   return (
-    [old root toolbar <div>: class list stripped from this copy of the diff]
+    [new root toolbar <div>: class list stripped from this copy of the diff]
       [one or more further hunks in Control's JSX were lost to markup stripping; the recoverable addition is the maximize toggle below, wired to onClick={handleToggleMaximizeCanvas}]
+          {maximizeCanvas && <RiAspectRatioFill />}
+          {!maximizeCanvas && <RiAspectRatioLine />}
   )
 }
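Control consumes useWorkflowCanvasMaximize, but the hook itself is not part of this diff. A rough sketch of the shape it appears to have, based only on how Control uses it; the store setter call and import path are assumptions, not the real implementation:

import { useCallback } from 'react'
import { useStore, useWorkflowStore } from '../store'

// guessed shape: read the maximizeCanvas flag Control renders, expose a toggle
export const useWorkflowCanvasMaximize = () => {
  const workflowStore = useWorkflowStore()
  const maximizeCanvas = useStore(s => s.maximizeCanvas)

  const handleToggleMaximizeCanvas = useCallback(() => {
    workflowStore.setState({ maximizeCanvas: !maximizeCanvas })
  }, [maximizeCanvas, workflowStore])

  return { maximizeCanvas, handleToggleMaximizeCanvas }
}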
diff --git a/web/app/components/workflow/operator/index.tsx b/web/app/components/workflow/operator/index.tsx
index 94ea8143e7..4a472a755f 100644
--- a/web/app/components/workflow/operator/index.tsx
+++ b/web/app/components/workflow/operator/index.tsx
@@ -1,8 +1,10 @@
-import { memo } from 'react'
+import { memo, useEffect, useMemo, useRef } from 'react'
 import { MiniMap } from 'reactflow'
 import UndoRedo from '../header/undo-redo'
 import ZoomInOut from './zoom-in-out'
-import Control from './control'
+import VariableTrigger from '../variable-inspect/trigger'
+import VariableInspectPanel from '../variable-inspect'
+import { useStore } from '../store'

 export type OperatorProps = {
   handleUndo: () => void
@@ -10,25 +12,65 @@
 }

 const Operator = ({ handleUndo, handleRedo }: OperatorProps) => {
+  const bottomPanelRef = useRef<HTMLDivElement>(null)
+  const workflowCanvasWidth = useStore(s => s.workflowCanvasWidth)
+  const rightPanelWidth = useStore(s => s.rightPanelWidth)
+  const setBottomPanelWidth = useStore(s => s.setBottomPanelWidth)
+  const setBottomPanelHeight = useStore(s => s.setBottomPanelHeight)
+
+  const bottomPanelWidth = useMemo(() => {
+    if (!workflowCanvasWidth || !rightPanelWidth)
+      return 'auto'
+    return Math.max((workflowCanvasWidth - rightPanelWidth), 400)
+  }, [workflowCanvasWidth, rightPanelWidth])
+
+  // update bottom panel width and height
+  useEffect(() => {
+    if (bottomPanelRef.current) {
+      const resizeContainerObserver = new ResizeObserver((entries) => {
+        for (const entry of entries) {
+          const { inlineSize, blockSize } = entry.borderBoxSize[0]
+          setBottomPanelWidth(inlineSize)
+          setBottomPanelHeight(blockSize)
+        }
+      })
+      resizeContainerObserver.observe(bottomPanelRef.current)
+      return () => {
+        resizeContainerObserver.disconnect()
+      }
+    }
+  }, [setBottomPanelHeight, setBottomPanelWidth])
+
   return (
-    <>
-      [previous MiniMap / controls layout: JSX stripped from this copy of the diff]
+    [new bottom-panel layout attached to bottomPanelRef, rendering the MiniMap, zoom controls, UndoRedo, VariableTrigger and VariableInspectPanel; JSX stripped from this copy of the diff]
   )
 }
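The ResizeObserver wiring added to Operator above generalizes to a small hook. A minimal sketch under the assumption that the two callbacks simply receive the border-box size; the hook and parameter names are illustrative, not Dify APIs:

import { useEffect, useRef } from 'react'

// report an element's border-box size whenever it changes,
// the same pattern Operator uses for the bottom panel
const useBorderBoxSize = (
  onWidth: (width: number) => void,
  onHeight: (height: number) => void,
) => {
  const ref = useRef<HTMLDivElement>(null)

  useEffect(() => {
    const el = ref.current
    if (!el)
      return
    const observer = new ResizeObserver((entries) => {
      for (const entry of entries) {
        // borderBoxSize[0] describes the element's single box fragment
        const { inlineSize, blockSize } = entry.borderBoxSize[0]
        onWidth(inlineSize)
        onHeight(blockSize)
      }
    })
    observer.observe(el)
    return () => observer.disconnect()
  }, [onWidth, onHeight])

  return ref
}

export default useBorderBoxSize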
diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
index d8da0e69a3..3240496b62 100644
--- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
+++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
@@ -324,7 +324,7 @@ const ChatVariableModal = ({
         {type === ChatVarType.String && (
           // Input will remove \n\r, so use Textarea just like description area