diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index f08befefb8..76e5c04deb 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -83,9 +83,15 @@ jobs: compose-file: | docker/docker-compose.middleware.yaml services: | + db + redis sandbox ssrf_proxy + - name: setup test config + run: | + cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env + - name: Run Workflow run: uv run --project api bash dev/pytest/pytest_workflow.sh diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 7d0a873ebd..912267094b 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -84,6 +84,12 @@ jobs: elasticsearch oceanbase + - name: setup test config + run: | + echo $(pwd) + ls -lah . + cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env + - name: Check VDB Ready (TiDB) run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py diff --git a/README.md b/README.md index ca09adec08..1dc7e2dd98 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,15 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Using Alibaba Cloud Computing Nest + +Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Using Alibaba Cloud Data Management + +One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_AR.md b/README_AR.md index df288fd33c..d93bca8646 100644 --- a/README_AR.md +++ b/README_AR.md @@ -209,6 +209,14 @@ docker compose up -d - [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### استخدام Alibaba Cloud للنشر + [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### استخدام Alibaba Cloud Data Management للنشر + +انشر ​​Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## المساهمة لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا. 
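Note on the `api-tests.yml` / `vdb-tests.yml` changes at the top of this section: the new "setup test config" step copies `api/tests/integration_tests/.env.example` to `api/tests/integration_tests/.env` before the suites run, and `api-tests.yml` now also brings up the `db` and `redis` services from the middleware compose file. A minimal sketch of how a test bootstrap could pick up that copied file is shown below; it assumes `python-dotenv` is available and that the integration tests read their settings from process environment variables — the actual Dify test fixtures may wire this differently.

```python
# Hypothetical helper for api/tests/integration_tests/conftest.py -- a sketch only,
# not the project's actual bootstrap code.
from pathlib import Path

from dotenv import load_dotenv  # assumes python-dotenv is installed in the test environment

# The .env file created by the CI "setup test config" step (copied from .env.example).
INTEGRATION_ENV_FILE = Path(__file__).parent / ".env"


def load_integration_test_config() -> None:
    """Load integration-test settings without clobbering variables set by the workflow."""
    if INTEGRATION_ENV_FILE.exists():
        # override=False keeps values already exported by the CI job (e.g. service hosts).
        load_dotenv(dotenv_path=INTEGRATION_ENV_FILE, override=False)
```

Keeping `override=False` lets the checked-in example file supply defaults while the workflow environment stays authoritative.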
diff --git a/README_BN.md b/README_BN.md index 4a5b5f3928..3efee3684d 100644 --- a/README_BN.md +++ b/README_BN.md @@ -225,6 +225,15 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud ব্যবহার করে ডিপ্লয় + + [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয় + + [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contributing যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)। diff --git a/README_CN.md b/README_CN.md index ba7ee0006d..21e27429ec 100644 --- a/README_CN.md +++ b/README_CN.md @@ -221,6 +221,15 @@ docker compose up -d ##### AWS - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### 使用 阿里云计算巢 部署 + +使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云 + +#### 使用 阿里云数据管理DMS 部署 + +使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云 + + ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/README_DE.md b/README_DE.md index f6023a3935..20c313035e 100644 --- a/README_DE.md +++ b/README_DE.md @@ -221,6 +221,15 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contributing Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. diff --git a/README_ES.md b/README_ES.md index 12f2ce8c11..e4b7df6686 100644 --- a/README_ES.md +++ b/README_ES.md @@ -221,6 +221,15 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contribuir Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). 
diff --git a/README_FR.md b/README_FR.md index b106615b31..8fd17fb7c3 100644 --- a/README_FR.md +++ b/README_FR.md @@ -219,6 +219,15 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contribuer Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_JA.md b/README_JA.md index 26703f3958..a3ee81e1f2 100644 --- a/README_JA.md +++ b/README_JA.md @@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ [こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。 - **Dify Community Editionのセルフホスティング
** -この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。 +この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。 詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 - **企業/組織向けのDify
** @@ -220,6 +220,13 @@ docker compose up -d ##### AWS - [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます + + ## 貢献 コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 diff --git a/README_KL.md b/README_KL.md index ea91baa5aa..3e5ab1a74f 100644 --- a/README_KL.md +++ b/README_KL.md @@ -219,6 +219,15 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo ##### AWS - [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_KR.md b/README_KR.md index 89301e8b2c..3c504900e1 100644 --- a/README_KR.md +++ b/README_KR.md @@ -213,6 +213,15 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ##### AWS - [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다 + + ## 기여 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. diff --git a/README_PT.md b/README_PT.md index 157772d528..fb5f3662ae 100644 --- a/README_PT.md +++ b/README_PT.md @@ -218,6 +218,15 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Contribuindo Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). 
diff --git a/README_SI.md b/README_SI.md index 14de1ea792..647069a220 100644 --- a/README_SI.md +++ b/README_SI.md @@ -219,6 +219,15 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Prispevam Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. diff --git a/README_TR.md b/README_TR.md index 563a05af3c..f52335646a 100644 --- a/README_TR.md +++ b/README_TR.md @@ -212,6 +212,15 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter ##### AWS - [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın + + ## Katkıda Bulunma Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. diff --git a/README_TW.md b/README_TW.md index f4a76ac109..71082ff893 100644 --- a/README_TW.md +++ b/README_TW.md @@ -224,6 +224,15 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify - [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### 使用 阿里云计算巢進行部署 + +[阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### 使用 阿里雲數據管理DMS 進行部署 + +透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲 + + ## 貢獻 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 diff --git a/README_VI.md b/README_VI.md index 4e1e05cbf3..58d8434fff 100644 --- a/README_VI.md +++ b/README_VI.md @@ -214,6 +214,16 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + +#### Alibaba Cloud Data Management + +Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + + ## Đóng góp Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. 
diff --git a/api/.ruff.toml b/api/.ruff.toml index facb0d5419..0169613bf8 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,6 +1,4 @@ -exclude = [ - "migrations/*", -] +exclude = ["migrations/*"] line-length = 120 [format] @@ -9,14 +7,14 @@ quote-style = "double" [lint] preview = false select = [ - "B", # flake8-bugbear rules - "C4", # flake8-comprehensions - "E", # pycodestyle E rules - "F", # pyflakes rules - "FURB", # refurb rules - "I", # isort rules - "N", # pep8-naming - "PT", # flake8-pytest-style rules + "B", # flake8-bugbear rules + "C4", # flake8-comprehensions + "E", # pycodestyle E rules + "F", # pyflakes rules + "FURB", # refurb rules + "I", # isort rules + "N", # pep8-naming + "PT", # flake8-pytest-style rules "PLC0208", # iteration-over-set "PLC0414", # useless-import-alias "PLE0604", # invalid-all-object @@ -24,19 +22,19 @@ select = [ "PLR0402", # manual-from-import "PLR1711", # useless-return "PLR1714", # repeated-equality-comparison - "RUF013", # implicit-optional - "RUF019", # unnecessary-key-check - "RUF100", # unused-noqa - "RUF101", # redirected-noqa - "RUF200", # invalid-pyproject-toml - "RUF022", # unsorted-dunder-all - "S506", # unsafe-yaml-load - "SIM", # flake8-simplify rules - "TRY400", # error-instead-of-exception - "TRY401", # verbose-log-message - "UP", # pyupgrade rules - "W191", # tab-indentation - "W605", # invalid-escape-sequence + "RUF013", # implicit-optional + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "RUF200", # invalid-pyproject-toml + "RUF022", # unsorted-dunder-all + "S506", # unsafe-yaml-load + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "TRY401", # verbose-log-message + "UP", # pyupgrade rules + "W191", # tab-indentation + "W605", # invalid-escape-sequence # security related linting rules # RCE proctection (sort of) "S102", # exec-builtin, disallow use of `exec` @@ -47,36 +45,37 @@ select = [ ] ignore = [ - "E402", # module-import-not-at-top-of-file - "E711", # none-comparison - "E712", # true-false-comparison - "E721", # type-comparison - "E722", # bare-except - "F821", # undefined-name - "F841", # unused-variable + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "F821", # undefined-name + "F841", # unused-variable "FURB113", # repeated-append "FURB152", # math-constant - "UP007", # non-pep604-annotation - "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B026", # star-arg-unpacking-after-keyword-arg - "B903", # class-as-data-structure - "B904", # raise-without-from-inside-except - "B905", # zip-without-explicit-strict - "N806", # non-lowercase-variable-in-function - "N815", # mixed-case-variable-in-class-scope - "PT011", # pytest-raises-too-broad - "SIM102", # collapsible-if - "SIM103", # needless-bool - "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally - "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # enumerate-for-loop - "SIM117", # multiple-with-statements - "SIM210", # if-expr-with-true-false + "UP007", # non-pep604-annotation + "UP032", # f-string + "UP045", # non-pep604-annotation-optional + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B903", # 
class-as-data-structure + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index f17c28dcd4..28ee7395d6 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -63,6 +63,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_draft_variable, workflow_run, workflow_statistic, ) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8cb7ad9f5b..f5257fae79 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -56,8 +56,7 @@ class InsertExploreAppListApi(Resource): parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - with Session(db.engine) as session: - app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f"App '{args['app_id']}' is not found") @@ -78,38 +77,38 @@ class InsertExploreAppListApi(Resource): select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) ).scalar_one_or_none() - if not recommended_app: - recommended_app = RecommendedApp( - app_id=app.id, - description=desc, - copyright=copy_right, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - language=args["language"], - category=args["category"], - position=args["position"], - ) + if not recommended_app: + recommended_app = RecommendedApp( + app_id=app.id, + description=desc, + copyright=copy_right, + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + language=args["language"], + category=args["category"], + position=args["position"], + ) - db.session.add(recommended_app) + db.session.add(recommended_app) - app.is_public = True - db.session.commit() + app.is_public = True + db.session.commit() - return {"result": "success"}, 201 - else: - recommended_app.description = desc - recommended_app.copyright = copy_right - recommended_app.privacy_policy = privacy_policy - recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args["language"] - recommended_app.category = args["category"] - recommended_app.position = args["position"] + return {"result": "success"}, 201 + else: + recommended_app.description = desc + recommended_app.copyright = copy_right + recommended_app.privacy_policy = privacy_policy + recommended_app.custom_disclaimer = custom_disclaimer + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] - app.is_public = True + app.is_public = True - db.session.commit() + db.session.commit() - return {"result": "success"}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): diff --git a/api/controllers/console/app/app_import.py 
b/api/controllers/console/app/app_import.py index 5dc6515ce0..9ffb94e9f9 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -17,6 +17,8 @@ from libs.login import login_required from models import Account from models.model import App from services.app_dsl_service import AppDslService, ImportStatus +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService class AppImportApi(Resource): @@ -60,7 +62,9 @@ class AppImportApi(Resource): app_id=args.get("app_id"), ) session.commit() - + if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: + # update web app setting as private + EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status if status == ImportStatus.FAILED.value: diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index cbbdd324ba..a9f088a276 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Sequence from typing import cast from flask import abort, request @@ -18,10 +19,12 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File from extensions.ext_database import db -from factories import variable_factory +from factories import file_factory, variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper @@ -30,6 +33,7 @@ from libs.login import current_user, login_required from models import App from models.account import Account from models.model import AppMode +from models.workflow import Workflow from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError @@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE logger = logging.getLogger(__name__) +# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing +# at the controller level rather than in the workflow logic. This would improve separation +# of concerns and make the code more maintainable. 
+def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]: + files = files or [] + + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + file_objs: Sequence[File] = [] + if file_extra_config is None: + return file_objs + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=workflow.tenant_id, + config=file_extra_config, + ) + return file_objs + + class DraftWorkflowApi(Resource): @setup_required @login_required @@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("query", type=str, required=False, location="json", default="") + parser.add_argument("files", type=list, location="json", default=[]) args = parser.parse_args() - inputs = args.get("inputs") - if inputs == None: + user_inputs = args.get("inputs") + if user_inputs is None: raise ValueError("missing inputs") + workflow_srv = WorkflowService() + # fetch draft workflow by app_model + draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + files = _parse_file(draft_workflow, args.get("files")) workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user + app_model=app_model, + draft_workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + account=current_user, + query=args.get("query", ""), + files=files, ) return workflow_node_execution @@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource): return None, 204 +class DraftWorkflowNodeLastRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_fields) + def get(self, app_model: App, node_id: str): + srv = WorkflowService() + workflow = srv.get_draft_workflow(app_model) + if not workflow: + raise NotFound("Workflow not found") + node_exec = srv.get_node_last_run( + app_model=app_model, + workflow=workflow, + node_id=node_id, + ) + if node_exec is None: + raise NotFound("last run not found") + return node_exec + + api.add_resource( DraftWorkflowApi, "/apps//workflows/draft", @@ -795,3 +853,7 @@ api.add_resource( WorkflowByIdApi, "/apps//workflows/", ) +api.add_resource( + DraftWorkflowNodeLastRunApi, + "/apps//workflows/draft/nodes//last-run", +) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py new file mode 100644 index 0000000000..00d6fa3cbf --- /dev/null +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -0,0 +1,421 @@ +import logging +from typing import Any, NoReturn + +from flask import Response +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.app.error import ( + DraftWorkflowNotExist, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.variables.segment_group import SegmentGroup +from core.variables.segments import ArrayFileSegment, 
FileSegment, Segment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.file_factory import build_from_mapping, build_from_mappings +from factories.variable_factory import build_segment_with_type +from libs.login import current_user, login_required +from models import App, AppMode, db +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + + +def _convert_values_to_json_serializable_object(value: Segment) -> Any: + if isinstance(value, FileSegment): + return value.value.model_dump() + elif isinstance(value, ArrayFileSegment): + return [i.model_dump() for i in value.value] + elif isinstance(value, SegmentGroup): + return [_convert_values_to_json_serializable_object(i) for i in value.value] + else: + return value.value + + +def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: + value = variable.get_value() + # create a copy of the value to avoid affecting the model cache. + value = value.model_copy(deep=True) + # Refresh the url signature before returning it to client. + if isinstance(value, FileSegment): + file = value.value + file.remote_url = file.generate_url() + elif isinstance(value, ArrayFileSegment): + files = value.value + for file in files: + file.remote_url = file.generate_url() + return _convert_values_to_json_serializable_object(value) + + +def _create_pagination_parser(): + parser = reqparse.RequestParser() + parser.add_argument( + "page", + type=inputs.int_range(1, 100_000), + required=False, + default=1, + location="args", + help="the page of data requested", + ) + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + return parser + + +_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda model: model.get_variable_type()), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, + value=fields.Raw(attribute=_serialize_var_value), +) + +_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { + "id": fields.String, + "type": fields.String(attribute=lambda _: "env"), + "name": fields.String, + "description": fields.String, + "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), + "value_type": fields.String, + "edited": fields.Boolean(attribute=lambda model: model.edited), + "visible": fields.Boolean, +} + +_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), +} + + +def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: + return var_list.variables + + +_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), + "total": fields.Raw(), +} + +_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { + "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), +} + + +def _api_prerequisite(f): + """Common prerequisites for 
all draft workflow variable APIs. + + It ensures the following conditions are satisfied: + + - Dify has been property setup. + - The request user has logged in and initialized. + - The requested app is a workflow or a chat flow. + - The request user has the edit permission for the app. + """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def wrapper(*args, **kwargs): + if not current_user.is_editor: + raise Forbidden() + return f(*args, **kwargs) + + return wrapper + + +class WorkflowVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) + def get(self, app_model: App): + """ + Get draft workflow + """ + parser = _create_pagination_parser() + args = parser.parse_args() + + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow_exist = workflow_service.is_workflow_exist(app_model=app_model) + if not workflow_exist: + raise DraftWorkflowNotExist() + + # fetch draft workflow by app_model + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=app_model.id, + page=args.page, + limit=args.limit, + ) + + return workflow_vars + + @_api_prerequisite + def delete(self, app_model: App): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + draft_var_srv.delete_workflow_variables(app_model.id) + db.session.commit() + return Response("", 204) + + +def validate_node_id(node_id: str) -> NoReturn | None: + if node_id in [ + CONVERSATION_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, + ]: + # NOTE(QuantumGhost): While we store the system and conversation variables as node variables + # with specific `node_id` in database, we still want to make the API separated. By disallowing + # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`, + # we mitigate the risk that user of the API depending on the implementation detail of the API. 
+ # + # ref: [Hyrum's Law](https://www.hyrumslaw.com/) + + raise InvalidArgumentError( + f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", + ) + return None + + +class NodeVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App, node_id: str): + validate_node_id(node_id) + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + node_vars = draft_var_srv.list_node_variables(app_model.id, node_id) + + return node_vars + + @_api_prerequisite + def delete(self, app_model: App, node_id: str): + validate_node_id(node_id) + srv = WorkflowDraftVariableService(db.session()) + srv.delete_node_variables(app_model.id, node_id) + db.session.commit() + return Response("", 204) + + +class VariableApi(Resource): + _PATCH_NAME_FIELD = "name" + _PATCH_VALUE_FIELD = "value" + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def get(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + return variable + + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def patch(self, app_model: App, variable_id: str): + # Request payload for file types: + # + # Local File: + # + # { + # "type": "image", + # "transfer_method": "local_file", + # "url": "", + # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190" + # } + # + # Remote File: + # + # + # { + # "type": "image", + # "transfer_method": "remote_url", + # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", + # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" + # } + + parser = reqparse.RequestParser() + parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") + # Parse 'value' field as-is to maintain its original data structure + parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") + + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + args = parser.parse_args(strict=True) + + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + new_name = args.get(self._PATCH_NAME_FIELD, None) + raw_value = args.get(self._PATCH_VALUE_FIELD, None) + if new_name is None and raw_value is None: + return variable + + new_value = None + if raw_value is not None: + if variable.value_type == SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + elif variable.value_type == SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + 
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + new_value = build_segment_with_type(variable.value_type, raw_value) + draft_var_srv.update_variable(variable, name=new_name, value=new_value) + db.session.commit() + return variable + + @_api_prerequisite + def delete(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + draft_var_srv.delete_variable(variable) + db.session.commit() + return Response("", 204) + + +class VariableResetApi(Resource): + @_api_prerequisite + def put(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + + workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError( + f"Draft workflow not found, app_id={app_model.id}", + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + resetted = draft_var_srv.reset_variable(draft_workflow, variable) + db.session.commit() + if resetted is None: + return Response("", 204) + else: + return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) + + +def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: + with Session(bind=db.engine, expire_on_commit=False) as session: + draft_var_srv = WorkflowDraftVariableService( + session=session, + ) + if node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_conversation_variables(app_model.id) + elif node_id == SYSTEM_VARIABLE_NODE_ID: + draft_vars = draft_var_srv.list_system_variables(app_model.id) + else: + draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id) + return draft_vars + + +class ConversationVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App): + # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table + # so their IDs can be returned to the caller. 
+ workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError(description=f"draft workflow not found, id={app_model.id}") + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + db.session.commit() + return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + + +class SystemVariableCollectionApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + def get(self, app_model: App): + return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) + + +class EnvironmentVariableCollectionApi(Resource): + @_api_prerequisite + def get(self, app_model: App): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + if workflow is None: + raise DraftWorkflowNotExist() + + env_vars = workflow.environment_variables + env_vars_list = [] + for v in env_vars: + env_vars_list.append( + { + "id": v.id, + "type": "env", + "name": v.name, + "description": v.description, + "selector": v.selector, + "value_type": v.value_type.value, + "value": v.value, + # Do not track edited for env vars. + "edited": False, + "visible": True, + "editable": True, + } + ) + + return {"items": env_vars_list} + + +api.add_resource( + WorkflowVariableCollectionApi, + "/apps//workflows/draft/variables", +) +api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") +api.add_resource(VariableApi, "/apps//workflows/draft/variables/") +api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") + +api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") +api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") +api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 9ad8c15847..03b60610aa 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,6 +8,15 @@ from libs.login import current_user from models import App, AppMode +def _load_app_model(app_id: str) -> Optional[App]: + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) + return app_model + + def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func): @wraps(view_func) @@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ del kwargs["app_id"] - app_model = ( - db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") - .first() - ) + app_model = _load_app_model(app_id) if not app_model: raise AppNotFoundError() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 8d302ca05e..42592c6c9a 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -6,7 +6,7 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from 
flask_restful import Resource, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -44,7 +44,6 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db -from extensions.ext_redis import redis_client from fields.document_fields import ( dataset_and_document_fields, document_fields, @@ -56,8 +55,6 @@ from models import Dataset, DatasetProcessRule, Document, DocumentSegment, Uploa from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig -from tasks.add_document_to_index_task import add_document_to_index_task -from tasks.remove_document_from_index_task import remove_document_from_index_task class DocumentResource(Resource): @@ -244,12 +241,10 @@ class DatasetDocumentListApi(Resource): return response - documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} - @setup_required @login_required @account_initialization_required - @marshal_with(documents_and_batch_fields) + @marshal_with(dataset_and_document_fields) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): @@ -295,6 +290,8 @@ class DatasetDocumentListApi(Resource): try: documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) + dataset = DatasetService.get_dataset(dataset_id) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -302,7 +299,7 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"documents": documents, "batch": batch} + return {"dataset": dataset, "documents": documents, "batch": batch} @setup_required @login_required @@ -864,77 +861,16 @@ class DocumentStatusApi(DocumentResource): DatasetService.check_dataset_permission(dataset, current_user) document_ids = request.args.getlist("document_id") - for document_id in document_ids: - document = self.get_document(dataset_id, document_id) - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") + try: + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + except services.errors.document.DocumentIndexingError as e: + raise InvalidActionError(str(e)) + except ValueError as e: + raise InvalidActionError(str(e)) + except NotFound as e: + raise NotFound(str(e)) - if action == "enable": - if document.enabled: - continue - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - add_document_to_index_task.delay(document_id) - - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError(f"Document: {document.name} is not 
completed.") - if not document.enabled: - continue - - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - elif action == "archive": - if document.archived: - continue - - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - if document.enabled: - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - elif action == "un_archive": - if not document.archived: - continue - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - add_document_to_index_task.delay(document_id) - - else: - raise InvalidActionError() return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index ba74e2c074..b4eb5e246b 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id @@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str, config_id: str): - if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 1467dfb6b3..839afdb9fd 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service from controllers.service_api import api -from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError +from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, @@ -17,7 +17,7 @@ from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum -from services.dataset_service import DatasetPermissionService, DatasetService +from services.dataset_service import DatasetPermissionService, DatasetService, 
DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.tag_service import TagService @@ -329,6 +329,56 @@ class DatasetApi(DatasetApiResource): raise DatasetInUseError() +class DocumentStatusApi(DatasetApiResource): + """Resource for batch document status operations.""" + + def patch(self, tenant_id, dataset_id, action): + """ + Batch update document status. + + Args: + tenant_id: tenant id + dataset_id: dataset id + action: action to perform (enable, disable, archive, un_archive) + + Returns: + dict: A dictionary with a key 'result' and a value 'success' + int: HTTP status code 200 indicating that the operation was successful. + + Raises: + NotFound: If the dataset with the given ID does not exist. + Forbidden: If the user does not have permission. + InvalidActionError: If the action is invalid or cannot be performed. + """ + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + + if dataset is None: + raise NotFound("Dataset not found.") + + # Check user's permission + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Check dataset model setting + DatasetService.check_dataset_model_setting(dataset) + + # Get document IDs from request body + data = request.get_json() + document_ids = data.get("document_ids", []) + + try: + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + except services.errors.document.DocumentIndexingError as e: + raise InvalidActionError(str(e)) + except ValueError as e: + raise InvalidActionError(str(e)) + + return {"result": "success"}, 200 + + class DatasetTagsApi(DatasetApiResource): @validate_dataset_token @marshal_with(tag_fields) @@ -457,6 +507,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DocumentStatusApi, "/datasets//documents/status/") api.add_resource(DatasetTagsApi, "/datasets/tags") api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 4371e679db..036e11d5c5 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException): error_code = "rate_limit_error" description = "Rate Limit Error" code = 429 + + +class NotFoundError(BaseHTTPException): + error_code = "not_found" + code = 404 + + +class InvalidArgumentError(BaseHTTPException): + error_code = "invalid_param" + code = 400 diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index d317af15a4..5308339871 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -104,6 +104,7 @@ class VariableEntity(BaseModel): Variable Entity. """ + # `variable` records the name of the variable in user inputs. 
variable: str label: str description: str = "" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a8848b9534..afecd99978 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,6 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -36,6 +37,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -116,6 +118,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: @@ -261,6 +268,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -271,6 +285,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -336,6 +351,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -346,6 +368,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def _generate( @@ -359,6 +382,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, 
conversation: Optional[Conversation] = None, stream: bool = True, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. @@ -410,6 +434,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "conversation_id": conversation.id, "message_id": message.id, "context": context, + "variable_loader": variable_loader, }, ) @@ -438,6 +463,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id: str, message_id: str, context: contextvars.Context, + variable_loader: VariableLoader, ) -> None: """ Generate worker in a new thread. @@ -464,6 +490,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, dialogue_count=self._dialogue_count, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d9b3833862..840a3c9d3b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,6 +19,7 @@ from core.moderation.base import ModerationError from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation: Conversation, message: Message, dialogue_count: int, + variable_loader: VariableLoader, ) -> None: - super().__init__(queue_manager) - + super().__init__(queue_manager, variable_loader) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message self._dialogue_count = dialogue_count + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + def run(self) -> None: app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index a448bf8a94..75a0b00424 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -124,6 +124,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a1329cb938..76fae879f2 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -115,6 +115,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. 
+ # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index aa74f8c318..5bdf937767 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -50,6 +50,7 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.nodes import NodeType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -127,7 +128,7 @@ class WorkflowResponseConverter: id=workflow_execution.id_, workflow_id=workflow_execution.workflow_id, status=workflow_execution.status, - outputs=workflow_execution.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), error=workflow_execution.error_message, elapsed_time=workflow_execution.elapsed_time, total_tokens=workflow_execution.total_tokens, @@ -212,6 +213,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -224,7 +227,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -255,6 +258,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeRetryStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -267,7 +272,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -386,6 +391,7 @@ class WorkflowResponseConverter: workflow_execution_id: str, event: QueueIterationCompletedEvent, ) -> IterationNodeCompletedStreamResponse: + json_converter = WorkflowRuntimeTypeConverter() return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_execution_id, @@ -394,7 +400,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=json_converter.to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -473,7 +479,7 @@ class WorkflowResponseConverter: node_id=event.node_id, 
node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index adcbaad3ec..7bc4a0a5c0 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index fd15bd9f50..369fa0e48c 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator): files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) system_files = file_factory.build_from_mappings( mappings=files, @@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. 
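The `variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER` default added to `_generate` above is a null-object default: regular runs keep a shared no-op loader, and only the single-step debugging paths swap in a `DraftVarLoader` backed by stored draft variables. A self-contained sketch of that pattern, assuming a hypothetical loader interface (class, method, and selector names below are illustrative stand-ins, not the actual `core.workflow.variable_loader` API):

from typing import Any, Protocol, Sequence


class VariableLoaderSketch(Protocol):
    """Hypothetical stand-in for the VariableLoader interface."""

    def load_variables(self, selectors: Sequence[Sequence[str]]) -> dict[tuple[str, ...], Any]: ...


class NoOpVariableLoader:
    """Null object: callers with nothing to preload share this do-nothing loader."""

    def load_variables(self, selectors: Sequence[Sequence[str]]) -> dict[tuple[str, ...], Any]:
        return {}


class DraftVariableLoaderSketch:
    """Debug-time loader that resolves selectors from previously saved draft values."""

    def __init__(self, saved: dict[tuple[str, ...], Any]) -> None:
        self._saved = saved

    def load_variables(self, selectors: Sequence[Sequence[str]]) -> dict[tuple[str, ...], Any]:
        return {tuple(s): self._saved[tuple(s)] for s in selectors if tuple(s) in self._saved}


DUMMY_LOADER = NoOpVariableLoader()


def generate(user_inputs: dict[str, Any], variable_loader: VariableLoaderSketch = DUMMY_LOADER) -> dict[str, Any]:
    # Existing call sites keep working unchanged; only single-step debugging needs to
    # pass a loader that can fill in upstream variables missing from user inputs.
    preloaded = variable_loader.load_variables([["node_1", "result"]])
    return {"inputs": user_inputs, "preloaded": preloaded}


print(generate({"query": "hi"}))
print(generate({"query": "hi"}, DraftVariableLoaderSketch({("node_1", "result"): "cached value"})))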
@@ -219,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator): "queue_manager": queue_manager, "context": context, "workflow_thread_pool_id": workflow_thread_pool_id, + "variable_loader": variable_loader, }, ) @@ -303,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) return self._generate( app_model=app_model, @@ -313,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -379,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - + draft_var_srv = WorkflowDraftVariableService(db.session()) + draft_var_srv.prefill_conversation_variable_default_values(workflow) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, + ) return self._generate( app_model=app_model, workflow=workflow, @@ -389,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def _generate_worker( @@ -397,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, context: contextvars.Context, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -415,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b59e34e222..07aeb57fa3 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import ( from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): :param queue_manager: application queue manager :param workflow_thread_pool_id: workflow thread pool id """ + super().__init__(queue_manager, variable_loader) 
self.application_generate_entity = application_generate_entity - self.queue_manager = queue_manager self.workflow_thread_pool_id = workflow_thread_pool_id + def _get_app_id(self) -> str: + return self.application_generate_entity.app_config.app_id + def run(self) -> None: """ Run application diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index facc24b4ca..dc6c381e86 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,6 +1,8 @@ from collections.abc import Mapping from typing import Any, Optional, cast +from sqlalchemy.orm import Session + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.queue_entities import ( @@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.event import ( AgentLogEvent, + BaseNodeEvent, GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App from models.workflow import Workflow +from services.workflow_draft_variable_service import ( + DraftVariableSaver, +) class WorkflowBasedAppRunner(AppRunner): - def __init__(self, queue_manager: AppQueueManager): + def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: self.queue_manager = queue_manager + self._variable_loader = variable_loader + + def _get_app_id(self) -> str: + raise NotImplementedError("not implemented") def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: """ @@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner): except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner): ) except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) + self._save_draft_var_for_event(event) + elif isinstance(event, NodeRunFailedEvent): self._publish_event( QueueNodeFailedEvent( @@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) + self._save_draft_var_for_event(event) + elif isinstance(event, NodeInIterationFailedEvent): self._publish_event( QueueNodeInIterationFailedEvent( @@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner): def _publish_event(self, event: AppQueueEvent) -> None: self.queue_manager.publish(event, 
PublishFrom.APPLICATION_MANAGER) + + def _save_draft_var_for_event(self, event: BaseNodeEvent): + run_result = event.route_node_state.node_run_result + if run_result is None: + return + process_data = run_result.process_data + outputs = run_result.outputs + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=self._get_app_id(), + node_id=event.node_id, + node_type=event.node_type, + # FIXME(QuantumGhost): relying on the private state of queue_manager is not ideal. + invoke_from=self.queue_manager._invoke_from, + node_execution_id=event.id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id or None, + ) + draft_var_saver.save(process_data=process_data, outputs=outputs) + + +def _remove_first_element_from_variable_string(key: str) -> str: + """ + Remove the first element (the node ID) from a dotted variable key. + """ + prefix, remaining = key.split(".", maxsplit=1) + return remaining diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 75693be5ea..4947861ef0 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -17,9 +17,24 @@ class InvokeFrom(Enum): Invoke From. """ + # SERVICE_API indicates that this invocation is from an API call to a Dify app. + # + # Description of the service API in the Dify docs: + # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis SERVICE_API = "service-api" + + # WEB_APP indicates that this invocation is from + # the web app of the workflow (or chatflow). + # + # Description of the web app in the Dify docs: + # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README WEB_APP = "web-app" + + # EXPLORE indicates that this invocation is from + # the workflow (or chatflow) explore page. EXPLORE = "explore" + # DEBUGGER indicates that this invocation is from + # the workflow (or chatflow) edit page. DEBUGGER = "debugger" PUBLISHED = "published" diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ce1d238e93..0665ed7e0d 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -1 +1,11 @@ +from typing import Any + +# TODO(QuantumGhost): Refactor variable type identification. Instead of directly +# comparing `dify_model_identity` with constants throughout the codebase, extract +# this logic into a dedicated function. This would encapsulate the implementation +# details of how different variable types are identified. FILE_MODEL_IDENTITY = "__dify__file__" + + +def maybe_file_object(o: Any) -> bool: + return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 7b984da922..b006bf1d4b 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -37,7 +37,6 @@ class OAuthHandler(BasePluginClient): return resp raise ValueError("No response received from plugin daemon for authorization URL request.") - def get_credentials( self, tenant_id: str, @@ -76,7 +75,6 @@ class OAuthHandler(BasePluginClient): return resp raise ValueError("No response received from plugin daemon for authorization URL request.") - def _convert_request_to_raw_data(self, request: Request) -> bytes: """ Convert a Request object to raw HTTP data.
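`_save_draft_var_for_event` above wraps its writes in `with Session(bind=db.engine) as session, session.begin():`. In SQLAlchemy this idiom gives each save its own short-lived unit of work: the outer context manager closes the session, the inner one commits on a clean exit and rolls back if the block raises. A standalone sketch of the same idiom (the in-memory SQLite engine and `draft_vars` table are placeholders, not Dify's schema):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # placeholder engine, stands in for db.engine

with engine.begin() as conn:
    conn.execute(text("CREATE TABLE draft_vars (app_id TEXT, node_id TEXT, name TEXT, value TEXT)"))

with Session(bind=engine) as session, session.begin():
    # Everything in this block is one transaction: a clean exit commits,
    # an exception rolls back, and the session is closed either way.
    session.execute(
        text("INSERT INTO draft_vars (app_id, node_id, name, value) VALUES (:a, :n, :k, :v)"),
        {"a": "app-1", "n": "node-1", "k": "result", "v": '"hello"'},
    )

with Session(bind=engine) as session:
    print(session.execute(text("SELECT count(*) FROM draft_vars")).scalar_one())  # 1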
@@ -105,4 +103,4 @@ class OAuthHandler(BasePluginClient): if body: raw_data += body - return raw_data \ No newline at end of file + return raw_data diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index 849852ac23..c97765b1dc 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -68,22 +68,17 @@ class MarkdownExtractor(BaseExtractor): continue header_match = re.match(r"^#+\s", line) if header_match: - if current_header is not None: - markdown_tups.append((current_header, current_text)) - + markdown_tups.append((current_header, current_text)) current_header = line current_text = "" else: current_text += line + "\n" markdown_tups.append((current_header, current_text)) - if current_header is not None: - # pass linting, assert keys are defined - markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups - ] - else: - markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] + markdown_tups = [ + (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value)) + for key, value in markdown_tups + ] return markdown_tups diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e5ead9dc56..cdec92aee7 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,7 +6,7 @@ import json import logging from typing import Optional, Union -from sqlalchemy import func, select +from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import ( WorkflowType, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -146,26 +147,17 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): db_model.workflow_id = domain_model.workflow_id db_model.triggered_from = self._triggered_from - # Check if this is a new record - with self._session_factory() as session: - existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_)) - if not existing: - # For new records, get the next sequence number - stmt = select(func.max(WorkflowRun.sequence_number)).where( - WorkflowRun.app_id == self._app_id, - WorkflowRun.tenant_id == self._tenant_id, - ) - max_sequence = session.scalar(stmt) - db_model.sequence_number = (max_sequence or 0) + 1 - else: - # For updates, keep the existing sequence number - db_model.sequence_number = existing.sequence_number + # No sequence number generation needed anymore db_model.type = domain_model.workflow_type db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.outputs = ( + json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs)) + if domain_model.outputs + else None + ) db_model.status = domain_model.status db_model.error = domain_model.error_message if domain_model.error_message 
else None db_model.total_tokens = domain_model.total_tokens diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 16b27591b0..46ff9e63a4 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if not self._creator_user_role: raise ValueError("created_by_role is required in repository constructor") + json_converter = WorkflowRuntimeTypeConverter() db_model = WorkflowNodeExecutionModel() db_model.id = domain_model.id db_model.tenant_id = self._tenant_id @@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) db_model.node_id = domain_model.node_id db_model.node_type = domain_model.node_type db_model.title = domain_model.title - db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.inputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None + ) + db_model.process_data = ( + json.dumps(json_converter.to_json_encodable(domain_model.process_data)) + if domain_model.process_data + else None + ) + db_model.outputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None + ) db_model.status = domain_model.status db_model.error = domain_model.error db_model.elapsed_time = domain_model.elapsed_time diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 64ba16c367..6cf09e0372 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -75,6 +75,20 @@ class StringSegment(Segment): class FloatSegment(Segment): value_type: SegmentType = SegmentType.NUMBER value: float + # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. + # The following tests cannot pass. 
+ # + # def test_float_segment_and_nan(): + # nan = float("nan") + # assert nan != nan + # + # f1 = FloatSegment(value=float("nan")) + # f2 = FloatSegment(value=float("nan")) + # assert f1 != f2 + # + # f3 = FloatSegment(value=nan) + # f4 = FloatSegment(value=nan) + # assert f3 != f4 class IntegerSegment(Segment): diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 4387e9693e..68d3d82883 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -18,3 +18,17 @@ class SegmentType(StrEnum): NONE = "none" GROUP = "group" + + def is_array_type(self): + return self in _ARRAY_TYPES + + +_ARRAY_TYPES = frozenset( + [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] +) diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index e5d222af7d..692db3502e 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -1,8 +1,26 @@ +import json from collections.abc import Iterable, Sequence +from .segment_group import SegmentGroup +from .segments import ArrayFileSegment, FileSegment, Segment + def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: selectors = [node_id, name] if paths: selectors.extend(paths) return selectors + + +class SegmentJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ArrayFileSegment): + return [v.model_dump() for v in o.value] + elif isinstance(o, FileSegment): + return o.value.model_dump() + elif isinstance(o, SegmentGroup): + return [self.default(seg) for seg in o.value] + elif isinstance(o, Segment): + return o.value + else: + super().default(o) diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py new file mode 100644 index 0000000000..84e99bb582 --- /dev/null +++ b/api/core/workflow/conversation_variable_updater.py @@ -0,0 +1,39 @@ +import abc +from typing import Protocol + +from core.variables import Variable + + +class ConversationVariableUpdater(Protocol): + """ + ConversationVariableUpdater defines an abstraction for updating conversation variable values. + + It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating + conversation variables. + + Implementations may choose to batch updates. If batching is used, the `flush` method + should be implemented to persist buffered changes, and `update` + should handle buffering accordingly. + + Note: Since implementations may buffer updates, instances of ConversationVariableUpdater + are not thread-safe. Each VariableAssignerNode should create its own instance during execution. + """ + + @abc.abstractmethod + def update(self, conversation_id: str, variable: "Variable") -> None: + """ + Updates the value of the specified conversation variable in the underlying storage. + + :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. + :param variable: The `Variable` instance containing the updated value. + """ + pass + + @abc.abstractmethod + def flush(self): + """ + Flushes all pending updates to the underlying storage system. + + If the implementation does not buffer updates, this method can be a no-op. 
+ """ + pass diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 21ea26862a..e6196f48fe 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -7,17 +7,12 @@ from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey from factories import variable_factory -from ..constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from ..enums import SystemVariableKey - VariableValue = Union[str, int, float, dict, list, File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") @@ -35,9 +30,11 @@ class VariablePool(BaseModel): # TODO: This user inputs is not used for pool. user_inputs: Mapping[str, Any] = Field( description="User inputs", + default_factory=dict, ) system_variables: Mapping[SystemVariableKey, Any] = Field( description="System variables", + default_factory=dict, ) environment_variables: Sequence[Variable] = Field( description="Environment variables.", @@ -52,31 +49,7 @@ class VariablePool(BaseModel): default_factory=dict, ) - def __init__( - self, - *, - system_variables: Mapping[SystemVariableKey, Any] | None = None, - user_inputs: Mapping[str, Any] | None = None, - environment_variables: Sequence[Variable] | None = None, - conversation_variables: Sequence[Variable] | None = None, - rag_pipeline_variables: Mapping[str, Any] | None = None, - **kwargs, - ): - environment_variables = environment_variables or [] - conversation_variables = conversation_variables or [] - user_inputs = user_inputs or {} - system_variables = system_variables or {} - rag_pipeline_variables = rag_pipeline_variables or {} - - super().__init__( - system_variables=system_variables, - user_inputs=user_inputs, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - rag_pipeline_variables=rag_pipeline_variables, - **kwargs, - ) - + def model_post_init(self, context: Any, /) -> None: for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) # Add environment variables to the variable pool @@ -106,12 +79,12 @@ class VariablePool(BaseModel): Returns: None """ - if len(selector) < 2: + if len(selector) < MIN_SELECTORS_LENGTH: raise ValueError("Invalid selector") if isinstance(value, Variable): variable = value - if isinstance(value, Segment): + elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) else: segment = variable_factory.build_segment(value) @@ -133,7 +106,7 @@ class VariablePool(BaseModel): Raises: ValueError: If the selector is invalid. 
""" - if len(selector) < 2: + if len(selector) < MIN_SELECTORS_LENGTH: return None hash_key = hash(tuple(selector[1:])) diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 063216dd49..89149c91db 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent): """iteration id if node is in iteration""" in_loop_id: Optional[str] = None """loop id if node is in loop""" + # The version of the node, or "1" if not specified. + node_version: str = "1" class NodeRunStartedEvent(BaseNodeEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index e1771b7ed9..eaa558b02c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -314,6 +315,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) raise e @@ -627,6 +629,7 @@ class GraphEngine: parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, + node_version=node_instance.version(), ) max_retries = node_instance.node_data.retry_config.max_retries @@ -677,6 +680,7 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, + node_version=node_instance.version(), ) time.sleep(retry_interval) break @@ -712,6 +716,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False else: @@ -726,6 +731,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -786,6 +792,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) should_continue_retry = False @@ -803,6 +810,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -817,6 +825,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) except 
GenerateTaskStoppedError: # trigger node run failed event @@ -833,6 +842,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + node_version=node_instance.version(), ) return except Exception as e: @@ -847,16 +857,12 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - node_id=node_id, variable_key_list=new_key_list, variable_value=value - ) + variable_utils.append_variables_recursively( + self.graph_runtime_state.variable_pool, + node_id, + variable_key_list, + variable_value, + ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 22c564c1fc..2f28363955 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -39,6 +39,10 @@ class AgentNode(ToolNode): _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the agent node diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index aa030870e2..38c2bcbdf5 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser class AnswerNode(BaseNode[AnswerNodeData]): _node_data_cls = AnswerNodeData - _node_type: NodeType = NodeType.ANSWER + _node_type = NodeType.ANSWER + + @classmethod + def version(cls) -> str: + return "1" def _run(self) -> NodeRunResult: """ @@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"answer": answer, "files": ArrayFileSegment(value=files)}, + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ba6ba16e36..f3e4a62ade 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor): parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, from_variable_selector=[answer_node_id, "answer"], + node_version=event.node_version, ) else: route_chunk = cast(VarGenerateRouteChunk, route_chunk) @@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + node_version=event.node_version, ) self.route_position[answer_node_id] += 1 diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 7da0c19740..6973401429 100644 --- 
a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) class BaseNode(Generic[GenericNodeData]): _node_data_cls: type[GenericNodeData] - _node_type: NodeType + _node_type: ClassVar[NodeType] def __init__( self, @@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]): graph_config: Mapping[str, Any], config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping + """Extracts referenced variable selectors from the node configuration. + + The `config` parameter represents the configuration for a specific node type and corresponds + to the `data` field in the node definition object. + + The returned mapping has the following structure: + + {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} + + For loop and iteration nodes, the mapping may look like this: + + { + "1748332301644.input_selector": ["1748332363630", "result"], + "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], + } + + where `1748332301644` is the ID of the loop / iteration node, + and `1748332325079` is the ID of the node inside the loop or iteration node. + + Here, the key consists of two parts: the current node ID (provided as the `node_id` + parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, + enclosed in `#` symbols. These two parts are separated by a dot (`.`). + + The value is a list of strings representing the variable selector, where the first element is the node ID + of the referenced variable, and the second element is the variable name within that node. + + The meaning of the first example above is: + + The node with ID `1747829548239` references the variable `result` from the node with + ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a + reference to the `result` output variable of node `1747829667553`. + :param graph_config: graph config :param config: node config :return: @@ -101,9 +131,10 @@ raise ValueError("Node ID is required when extracting variable selector to variable mapping.") node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping( + data = cls._extract_variable_selector_to_variable_mapping( graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) ) + return data @classmethod def _extract_variable_selector_to_variable_mapping( @@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]): """ return self._node_type + @classmethod + @abstractmethod + def version(cls) -> str: + """Returns the version of the current node type.""" + # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. + # + # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` + # in `api/core/workflow/nodes/__init__.py`.
+ raise NotImplementedError("subclasses of BaseNode must implement `version` method.") + @property def should_continue_on_error(self) -> bool: """judge if should continue on error diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 61c08a7d71..22ed9e2651 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]): return code_provider.get_default_config() + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get code language code_language = self.node_data.code_language @@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]): prefix: str = "", depth: int = 1, ): + # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. + # Note that `_transform_result` may produce lists containing `None` values, + # which don't conform to the type requirements of `Array*Segment` classes. if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 429fed2d04..8e6150f9cc 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -24,7 +24,7 @@ from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment -from core.variables.segments import FileSegment +from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self): variable_selector = self.node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) @@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"text": extracted_text_list}, + outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): extracted_text = _extract_text_from_file(value) @@ -447,7 +451,7 @@ def _extract_text_from_excel(file_content: bytes) -> str: df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns]) + df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) # Manually construct the Markdown table markdown_table += _construct_markdown_table(df) + "\n\n" diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 0e9756b243..17a0b3adeb 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]): _node_data_cls = EndNodeData _node_type = NodeType.END + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ 
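The `str(col)` coercion added to `_extract_text_from_excel` above guards against non-string column labels, which pandas happily produces (integers, timestamps) and which would make the previous `col.splitlines()` call raise `AttributeError`. A quick illustration with a made-up DataFrame:

import pandas as pd

# One header contains a line break, another is not a string at all.
df = pd.DataFrame({"multi\nline header": [1, 2], 2024: ["a", "b"]})

# Before the fix this comprehension assumed every column label was a str and
# would fail on the integer label 2024.
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])

print(list(df.columns))  # ['multi line header', '2024']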
Run node diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 3ae5af7137..a6fb2ffc18 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor): route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + node_version=event.node_version, ) self.route_position[end_node_id] += 1 diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 6b1ac57c06..971e0f73e7 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -6,6 +6,7 @@ from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): }, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: process_data = {} try: @@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ "status_code": response.status_code, - "body": response.text if not files else "", + "body": response.text if not files.value else "", "headers": response.headers, "files": files, }, @@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): return mapping - def extract_files(self, url: str, response: Response) -> list[File]: + def extract_files(self, url: str, response: Response) -> ArrayFileSegment: """ Extract files from response by checking both Content-Type header and URL """ @@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): content_disposition_type = None if not is_file: - return files + return ArrayFileSegment(value=[]) if parsed_content_disposition: content_disposition_filename = parsed_content_disposition.get_filename() @@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): ) files.append(file) - return files + return ArrayFileSegment(value=files) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 976922f75d..22b748030c 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,4 +1,5 @@ -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal from typing_extensions import deprecated @@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run node @@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]): return data + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, list[str]] = {} + for case in node_data.cases or []: + for condition in case.conditions: + key = 
"{}.#{}#".format(node_id, ".".join(condition.variable_selector)) + var_mapping[key] = condition.variable_selector + + return var_mapping + @deprecated("This function is deprecated. You should use the new cases structure.") def _should_not_use_old_function( diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 42b6795fb0..151efc28ec 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,6 +11,7 @@ from flask import Flask, current_app from configs import dify_config from core.variables import ArrayVariable, IntegerVariable, NoneVariable +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import ( NodeRunResult, ) @@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from factories.variable_factory import build_segment from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]): }, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. @@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") if isinstance(variable, NoneVariable) or len(variable.value) == 0: + # Try our best to preserve the type informat. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": []}, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? + outputs={"output": output}, ) ) return @@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]): # Flatten the list of lists if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): outputs = [item for sublist in outputs for item in sublist] + output_segment = build_segment(outputs) yield IterationRunSucceededEvent( iteration_id=self.id, @@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": output_segment}, metadata={ WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index bee481ebdb..9900aa225d 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]): _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. 
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5cf5848d54..0b9e98f28a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment +from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType @@ -70,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode): _node_data_cls = KnowledgeRetrievalNodeData # type: ignore _node_type = NodeType.KNOWLEDGE_RETRIEVAL + @classmethod + def version(cls): + return "1" + def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables @@ -115,9 +120,12 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: results = self._fetch_dataset_retriever(node_data=node_data, query=query) - outputs = {"result": results} + outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs, # type: ignore ) except KnowledgeRetrievalNodeError as e: diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index e698d3f5d8..3c9ba44cf1 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_data_cls = ListOperatorNodeData _node_type = NodeType.LIST_OPERATOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self): inputs: dict[str, list] = {} process_data: dict[str, list] = {} @@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): if not variable.value: inputs = {"variable": []} process_data = {"variable": []} - outputs = {"result": [], "first_record": None, "last_record": None} + if isinstance(variable, ArraySegment): + result = variable.model_copy(update={"value": []}) + else: + result = ArrayAnySegment(value=[]) + outputs = {"result": result, "first_record": None, "last_record": None} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): variable = self._apply_slice(variable) outputs = { - "result": variable.value, + "result": variable, "first_record": variable.value[0] if variable.value else None, "last_record": variable.value[-1] if variable.value else None, } diff --git 
a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py index c85baade03..a4b45ce652 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver): size=len(data), related_id=tool_file.id, url=url, - # TODO(QuantumGhost): how should I set the following key? - # What's the difference between `remote_url` and `url`? - # What's the purpose of `storage_key` and `dify_model_identity`? storage_key=tool_file.file_key, ) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d27124d62c..124ae6d75d 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def process_structured_output(text: str) -> Optional[dict[str, Any]]: """Process structured output if enabled""" @@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]): if structured_output: outputs["structured_output"] = structured_output if self._file_outputs is not None: - outputs["files"] = self._file_outputs + outputs["files"] = ArrayFileSegment(value=self._file_outputs) yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 327b9e234b..b144021bab 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]): _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index fafa205386..368d662a75 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs @@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) + for loop_variable in node_data.loop_variables or []: + if loop_variable.value_type == "variable": + assert loop_variable.value is not None, "Loop variable value must be provided for variable type" + # add loop variable to variable mapping + selector = loop_variable.value + variable_mapping[f"{node_id}.{loop_variable.label}"] = selector + # remove variable out from loop variable_mapping = { key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 5a15f36044..f5e38b7516 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]): _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: """ Run the node. 
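Every node class in these hunks now reports a `version()` classmethod, and the `node_mapping.py` hunk that follows still asks maintainers to keep `NODE_TYPE_CLASSES_MAPPING` in sync by hand; its TODO floats automating that with a metaclass or `__init_subclass__`. A rough sketch of what such automation could look like (names are illustrative, not the real Dify classes):

from typing import ClassVar

NODE_REGISTRY: dict[tuple[str, str], type] = {}


class RegisteredNode:
    node_type: ClassVar[str] = ""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Concrete subclasses register themselves by (node_type, version),
        # so the mapping can no longer drift out of sync with the class list.
        if cls.node_type:
            NODE_REGISTRY[(cls.node_type, cls.version())] = cls

    @classmethod
    def version(cls) -> str:
        return "1"


class StartNodeSketch(RegisteredNode):
    node_type = "start"


class LLMNodeSketch(RegisteredNode):
    node_type = "llm"

    @classmethod
    def version(cls) -> str:
        return "2"


print(NODE_REGISTRY[("start", "1")].__name__)  # StartNodeSketch
print(NODE_REGISTRY[("llm", "2")].__name__)    # LLMNodeSketch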
diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index e328c20096..f7ec8fe737 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -27,6 +27,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var LATEST_VERSION = "latest" +# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. +# Specifically, if you have introduced new node types, you should add them here. +# +# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__` +# hook. Try to avoid duplication of node information. NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { NodeType.START: { LATEST_VERSION: StartNode, diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 369eb13b04..916778d167 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.llm import ModelConfig, VisionConfig +class _ParameterConfigError(Exception): + pass + + class ParameterConfig(BaseModel): """ Parameter Config. @@ -27,6 +31,19 @@ class ParameterConfig(BaseModel): raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return str(value) + def is_array_type(self) -> bool: + return self.type in ("array[string]", "array[number]", "array[object]") + + def element_type(self) -> Literal["string", "number", "object"]: + if self.type == "array[number]": + return "number" + elif self.type == "array[string]": + return "string" + elif self.type == "array[object]": + return "object" + else: + raise _ParameterConfigError(f"{self.type} is not array type.") + class ParameterExtractorNodeData(BaseNodeData): """ diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2552784762..8d6c2d0a5c 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser +from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( @@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode): } } + @classmethod + def version(cls) -> str: + return "1" + def _run(self): """ Run the node. 
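The next hunk rewrites the array branch of the parameter extractor around the new `ParameterConfig.is_array_type()` / `element_type()` helpers: extracted items are accumulated inside a typed segment created by `build_segment_with_type` rather than a plain list. A rough sketch of that flow, assuming the Dify `api` package is importable and that `SegmentType` uses string values such as `"array[number]"` (which is what `SegmentType(parameter.type)` in the hunk implies):

```python
from core.variables.types import SegmentType
from factories.variable_factory import build_segment_with_type

# Hypothetical raw LLM extraction result for an "array[number]" parameter.
raw_items = ["3", "4.5", 6, "not-a-number"]

# Start from an empty, correctly typed array segment instead of [].
segment_value = build_segment_with_type(segment_type=SegmentType("array[number]"), value=[])

for item in raw_items:
    if isinstance(item, (int, float)):
        segment_value.value.append(item)
    elif isinstance(item, str):
        try:
            # Mirror the hunk: "." means float, otherwise int; silently skip bad values.
            segment_value.value.append(float(item) if "." in item else int(item))
        except ValueError:
            pass

assert segment_value.value == [3, 4.5, 6]
```

The fallback at the end of the same hunk follows the same idea: a missing array parameter now defaults to an empty typed segment built from `SegmentType(parameter.type)` rather than a bare `[]`.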
@@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith("array"): + elif parameter.is_array_type(): if isinstance(result[parameter.name], list): - nested_type = parameter.type[6:-1] - transformed_result[parameter.name] = [] + nested_type = parameter.element_type() + assert nested_type is not None + segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) + transformed_result[parameter.name] = segment_value for item in result[parameter.name]: if nested_type == "number": if isinstance(item, int | float): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif isinstance(item, str): try: if "." in item: - transformed_result[parameter.name].append(float(item)) + segment_value.value.append(float(item)) else: - transformed_result[parameter.name].append(int(item)) + segment_value.value.append(int(item)) except ValueError: pass elif nested_type == "string": if isinstance(item, str): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif nested_type == "object": if isinstance(item, dict): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) if parameter.name not in transformed_result: if parameter.type == "number": @@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): - transformed_result[parameter.name] = [] + transformed_result[parameter.name] = build_segment_with_type( + segment_type=SegmentType(parameter.type), value=[] + ) return transformed_result diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 1f50700c7e..a518167cc6 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER + @classmethod + def version(cls): + return "1" + def _run(self): node_data = cast(QuestionClassifierNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 8839aec9d6..5ee9bc331f 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]): _node_data_cls = StartNodeData _node_type = NodeType.START + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) system_inputs = self.graph_runtime_state.variable_pool.system_variables @@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]): # Set system variables as node outputs. for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] + outputs = dict(node_inputs) - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 476cf7eee4..ba573074c3 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get variables variables = {} diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index aaecc7b989..aa15d69931 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the tool node @@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]): variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None + assert isinstance(message.meta, File) files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index db3e25b015..96bb3e793a 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping + +from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = 
NodeType.VARIABLE_AGGREGATOR + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> NodeRunResult: # Get variables - outputs = {} + outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = {"output": variable.to_object()} + outputs = {"output": variable} inputs = {".".join(selector[1:]): variable.to_object()} break @@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = {"output": variable.to_object()} + outputs[group.group_name] = {"output": variable} inputs[".".join(selector[1:])] = variable.to_object() break diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 8031b57fa8..0d2822233e 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -1,19 +1,55 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, TypeVar -from core.variables import Variable -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from extensions.ext_database import db -from models import ConversationVariable +from pydantic import BaseModel + +from core.variables import Segment +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.types import SegmentType + +# Use double underscore (`__`) prefix for internal variables +# to minimize risk of collision with user-defined variable names. 
+_UPDATED_VARIABLES_KEY = "__updated_variables" -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id +class UpdatedVariable(BaseModel): + name: str + selector: Sequence[str] + value_type: SegmentType + new_value: Any + + +_T = TypeVar("_T", bound=MutableMapping[str, Any]) + + +def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: + if len(selector) < MIN_SELECTORS_LENGTH: + raise Exception("selector too short") + node_id, var_name = selector[:2] + return UpdatedVariable( + name=var_name, + selector=list(selector[:2]), + value_type=seg.value_type, + new_value=seg.value, ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() + + +def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: + m[_UPDATED_VARIABLES_KEY] = updates + return m + + +def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: + updated_values = m.get(_UPDATED_VARIABLES_KEY, None) + if updated_values is None: + return None + result = [] + for items in updated_values: + if isinstance(items, UpdatedVariable): + result.append(items) + elif isinstance(items, dict): + items = UpdatedVariable.model_validate(items) + result.append(items) + else: + raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") + return result diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py new file mode 100644 index 0000000000..8f7a44bb62 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/common/impl.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session + +from core.variables.variables import Variable +from models.engine import db +from models.workflow import ConversationVariable + +from .exc import VariableOperatorNodeError + + +class ConversationVariableUpdaterImpl: + _engine: Engine | None + + def __init__(self, engine: Engine | None = None) -> None: + self._engine = engine + + def _get_engine(self) -> Engine: + if self._engine: + return self._engine + return db.engine + + def update(self, conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(self._get_engine()) as session: + row = session.scalar(stmt) + if not row: + raise VariableOperatorNodeError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self): + pass + + +def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: + return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 835e1d77b5..be5083c9c1 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,4 +1,9 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, TypeAlias + from core.variables import SegmentType, Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from 
core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory +from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode +if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + + +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls = VariableAssignerData _node_type = NodeType.VARIABLE_ASSIGNER + _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + self._conv_var_updater_factory = conv_var_updater_factory + + @classmethod + def version(cls) -> str: + return "1" + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerData, + ) -> Mapping[str, Sequence[str]]: + mapping = {} + assigned_variable_node_id = node_data.assigned_variable_selector[0] + if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector + + selector_key = ".".join(node_data.input_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.input_variable_selector + return mapping def _run(self) -> NodeRunResult: + assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) + original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") @@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") # Over write the variable. - self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) + self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. 
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: raise VariableOperatorNodeError("conversation_id not found") - common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater = self._conv_var_updater_factory() + conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater.flush() + updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ "value": income_value.to_object(), }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still set `output_variables` as a list to ensure the schema of output is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, ) diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 7d7922abd4..3797bfa77a 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -8,5 +8,4 @@ EMPTY_VALUE_MAPPING = { SegmentType.ARRAY_STRING: [], SegmentType.ARRAY_NUMBER: [], SegmentType.ARRAY_OBJECT: [], - SegmentType.ARRAY_FILE: [], } diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/core/workflow/nodes/variable_assigner/v2/entities.py index 01df33b6d4..d93affcd15 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/entities.py +++ b/api/core/workflow/nodes/variable_assigner/v2/entities.py @@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel): variable_selector: Sequence[str] input_type: InputType operation: Operation + # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: + # + # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. + # 2. For VARIABLE input_type: Initially contains the selector of the source variable. + # 3. During the variable updating procedure: The `value` field is reassigned to hold + # the resolved actual value that will be applied to the target variable. 
value: Any | None = None diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/core/workflow/nodes/variable_assigner/v2/exc.py index b67af6d73c..fd6c304a9a 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/exc.py +++ b/api/core/workflow/nodes/variable_assigner/v2/exc.py @@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError): class ConversationIDNotFoundError(VariableOperatorNodeError): def __init__(self): super().__init__("conversation_id not found") + + +class InvalidDataError(VariableOperatorNodeError): + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index f33f406145..8fb2a27388 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,5 @@ from typing import Any -from core.file import File from core.variables import SegmentType from .enums import Operation @@ -86,8 +85,6 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va return isinstance(value, int | float) case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: return isinstance(value, dict) - case SegmentType.ARRAY_FILE if operation == Operation.APPEND: - return isinstance(value, File) # Array & Extend / Overwrite case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: @@ -98,8 +95,6 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va return isinstance(value, list) and all(isinstance(item, int | float) for item in value) case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_FILE if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, File) for item in value) case _: return False diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 8759a55b34..9292da6f1c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,34 +1,84 @@ import json -from collections.abc import Sequence -from typing import Any, cast +from collections.abc import Callable, Mapping, MutableMapping, Sequence +from typing import Any, TypeAlias, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . 
import helpers from .constants import EMPTY_VALUE_MAPPING -from .entities import VariableAssignerNodeData +from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( ConversationIDNotFoundError, InputTypeNotSupportedError, + InvalidDataError, InvalidInputValueError, OperationNotSupportedError, VariableNotFoundError, ) +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + + +def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): + selector_node_id = item.variable_selector[0] + if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: + return + selector_str = ".".join(item.variable_selector) + key = f"{node_id}.#{selector_str}#" + mapping[key] = item.variable_selector + + +def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): + # Keep this in sync with the logic in _run methods... + if item.input_type != InputType.VARIABLE: + return + selector = item.value + if not isinstance(selector, list): + raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") + if len(selector) < MIN_SELECTORS_LENGTH: + raise InvalidDataError(f"selector too short, {node_id=}, {item=}") + selector_str = ".".join(selector) + key = f"{node_id}.#{selector_str}#" + mapping[key] = selector + class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_ASSIGNER + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: + return conversation_variable_updater_factory() + + @classmethod + def version(cls) -> str: + return "2" + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, Sequence[str]] = {} + for item in node_data.items: + _target_mapping_from_item(var_mapping, node_id, item) + _source_mapping_from_item(var_mapping, node_id, item) + return var_mapping + def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} @@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): # remove the duplicated items first. 
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + conv_var_updater = self._conv_var_updater_factory() # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) @@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): raise ConversationIDNotFoundError else: conversation_id = conversation_id.value - common_helpers.update_conversation_variable( + conv_var_updater.update( conversation_id=cast(str, conversation_id), variable=variable, ) + conv_var_updater.flush() + updated_variables = [ + common_helpers.variable_to_processed_data(selector, seg) + for selector in updated_variable_selectors + if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + ] + process_data = common_helpers.set_updated_variables(process_data, updated_variables) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, + outputs={}, ) def _handle_item( diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py new file mode 100644 index 0000000000..e68d990b60 --- /dev/null +++ b/api/core/workflow/utils/variable_utils.py @@ -0,0 +1,28 @@ +from core.variables.segments import ObjectSegment, Segment +from core.workflow.entities.variable_pool import VariablePool, VariableValue + + +def append_variables_recursively( + pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment +): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, ObjectSegment): + variable_dict = variable_value.value + elif isinstance(variable_value, dict): + variable_dict = variable_value + else: + return + + for key, value in variable_dict.items(): + # construct new key list + new_key_list = variable_key_list + [key] + append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py new file mode 100644 index 0000000000..1e13871d0a --- /dev/null +++ b/api/core/workflow/variable_loader.py @@ -0,0 +1,84 @@ +import abc +from collections.abc import Mapping, Sequence +from typing import Any, Protocol + +from core.variables import Variable +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils import variable_utils + + +class VariableLoader(Protocol): + """Interface for loading variables based on selectors. + + A `VariableLoader` is responsible for retrieving additional variables required during the execution + of a single node, which are not provided as user inputs. + + NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same + application and share the same `app_id`. However, this interface does not enforce that constraint, + and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of + concern and allow for flexible implementations. + + Implementations of `VariableLoader` should almost always have an `app_id` parameter in + their constructor. + + TODO(QuantumGhost): this is a temporally workaround. 
If we can move the creation of node instance into + `WorkflowService.single_step_run`, we may get rid of this interface. + """ + + @abc.abstractmethod + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + """Load variables based on the provided selectors. If the selectors are empty, + this method should return an empty list. + + The order of the returned variables is not guaranteed. If the caller wants to ensure + a specific order, they should sort the returned list themselves. + + :param: selectors: a list of string list, each inner list should have at least two elements: + - the first element is the node ID, + - the second element is the variable name. + :return: a list of Variable objects that match the provided selectors. + """ + pass + + +class _DummyVariableLoader(VariableLoader): + """A dummy implementation of VariableLoader that does not load any variables. + Serves as a placeholder when no variable loading is needed. + """ + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + return [] + + +DUMMY_VARIABLE_LOADER = _DummyVariableLoader() + + +def load_into_variable_pool( + variable_loader: VariableLoader, + variable_pool: VariablePool, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: Mapping[str, Any], +): + # Loading missing variable from draft var here, and set it into + # variable_pool. + variables_to_load: list[list[str]] = [] + for key, selector in variable_mapping.items(): + # NOTE(QuantumGhost): this logic needs to be in sync with + # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. + node_variable_list = key.split(".") + if len(node_variable_list) < 1: + raise ValueError(f"Invalid variable key: {key}. It should have at least one element.") + if key in user_inputs: + continue + node_variable_key = ".".join(node_variable_list[1:]) + if node_variable_key in user_inputs: + continue + if variable_pool.get(selector) is None: + variables_to_load.append(list(selector)) + loaded = variable_loader.load_variables(variables_to_load) + for var in loaded: + assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" + variable_utils.append_variables_recursively( + variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var + ) diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index b88f9edd03..6ee562fc8d 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -92,7 +92,7 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(outputs) + # outputs = WorkflowEntry.handle_special_values(outputs) workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED workflow_execution.outputs = outputs or {} @@ -125,7 +125,7 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) + # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED execution.outputs = outputs or {} @@ -242,9 +242,9 @@ class WorkflowCycleManager: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") # Process data - inputs = 
WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + inputs = event.inputs + process_data = event.process_data + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -289,7 +289,7 @@ class WorkflowCycleManager: # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -326,7 +326,7 @@ class WorkflowCycleManager: finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings origin_metadata = { diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7648947fca..182c54fa77 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom from models.workflow import ( @@ -119,7 +120,9 @@ class WorkflowEntry: workflow: Workflow, node_id: str, user_id: str, - user_inputs: dict, + user_inputs: Mapping[str, Any], + variable_pool: VariablePool, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -129,29 +132,14 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - # fetch node info from workflow graph - workflow_graph = workflow.graph_dict - if not workflow_graph: - raise ValueError("workflow graph not found") - - nodes = workflow_graph.get("nodes") - if not nodes: - raise ValueError("nodes not found in workflow graph") - - # fetch node config from node id - try: - node_config = next(filter(lambda node: node["id"] == node_id, nodes)) - except StopIteration: - raise ValueError("node id not found in workflow graph") + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config.get("data", {}) # Get node class - node_type = NodeType(node_config.get("data", {}).get("type")) - node_version = node_config.get("data", {}).get("version", "1") + node_type = NodeType(node_config_data.get("type")) + node_version = node_config_data.get("version", "1") node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - # init variable pool - variable_pool = VariablePool(environment_variables=workflow.environment_variables) - # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -182,16 +170,33 @@ class WorkflowEntry: except NotImplementedError: variable_mapping = {} + # Loading missing variable from draft var here, and set it into + # variable_pool. 
+ load_into_variable_pool( + variable_loader=variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, variable_pool=variable_pool, tenant_id=workflow.tenant_id, ) + try: # run node generator = node_instance.run() except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + workflow.id, + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) return node_instance, generator @@ -294,10 +299,20 @@ class WorkflowEntry: return node_instance, generator except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + # NOTE(QuantumGhost): Avoid using this function in new code. + # Keep values structured as long as possible and only convert to dict + # immediately before serialization (e.g., JSON serialization) to maintain + # data integrity and type information. result = WorkflowEntry._handle_special_values(value) return result if isinstance(result, Mapping) or result is None else dict(result) @@ -324,10 +339,17 @@ class WorkflowEntry: cls, *, variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, + user_inputs: Mapping[str, Any], variable_pool: VariablePool, tenant_id: str, ) -> None: + # NOTE(QuantumGhost): This logic should remain synchronized with + # the implementation of `load_into_variable_pool`, specifically the logic about + # variable existence checking. + + # WARNING(QuantumGhost): The semantics of this method are not clearly defined, + # and multiple parts of the codebase depend on its current behavior. + # Modify with caution. 
for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable node_variable_list = node_variable.split(".") diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py new file mode 100644 index 0000000000..0123fdac18 --- /dev/null +++ b/api/core/workflow/workflow_type_encoder.py @@ -0,0 +1,49 @@ +import json +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel + +from core.file.models import File +from core.variables import Segment + + +class WorkflowRuntimeTypeEncoder(json.JSONEncoder): + def default(self, o: Any): + if isinstance(o, Segment): + return o.value + elif isinstance(o, File): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump(mode="json") + else: + return super().default(o) + + +class WorkflowRuntimeTypeConverter: + def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + result = self._to_json_encodable_recursive(value) + return result if isinstance(result, Mapping) or result is None else dict(result) + + def _to_json_encodable_recursive(self, value: Any) -> Any: + if value is None: + return value + if isinstance(value, (bool, int, str, float)): + return value + if isinstance(value, Segment): + return self._to_json_encodable_recursive(value.value) + if isinstance(value, File): + return value.to_dict() + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = self._to_json_encodable_recursive(v) + return res + if isinstance(value, list): + res_list = [] + for item in value: + res_list.append(self._to_json_encodable_recursive(item)) + return res_list + return value diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 81606594e0..00c9a7fff9 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -5,6 +5,7 @@ from typing import Any, cast import httpx from sqlalchemy import select +from sqlalchemy.orm import Session from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers @@ -92,6 +93,8 @@ def build_from_mappings( tenant_id: str, strict_type_validation: bool = False, ) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. files = [ build_from_mapping( mapping=mapping, @@ -425,3 +428,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: def get_file_type_by_mime_type(mime_type: str) -> FileType: return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM + + +class StorageKeyLoader: + """FileKeyLoader load the storage key from database for a list of files. 
+ This loader is batched, the + """ + + def __init__(self, session: Session, tenant_id: str) -> None: + self._session = session + self._tenant_id = tenant_id + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + + return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} + + def load_storage_keys(self, files: Sequence[File]): + """Loads storage keys for a sequence of files by retrieving the corresponding + `UploadFile` or `ToolFile` records from the database based on their transfer method. + + This method doesn't modify the input sequence structure but updates the `_storage_key` + property of each file object by extracting the relevant key from its database record. + + Performance note: This is a batched operation where database query count remains constant + regardless of input size. However, for optimal performance, input sequences should contain + fewer than 1000 files. For larger collections, split into smaller batches and process each + batch separately. + """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + related_model_id = file.related_id + if file.related_id is None: + raise ValueError("file id should not be None.") + if file.tenant_id != self._tenant_id: + err_msg = ( + f"invalid file, expected tenant_id={self._tenant_id}, " + f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" + ) + raise ValueError(err_msg) + model_id = uuid.UUID(related_model_id) + + if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + model_id = uuid.UUID(file.related_id) + if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(...) + file._storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(...) 
+ file._storage_key = tool_file_row.file_key diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 202f10044b..8dd4ec2e4a 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -46,6 +46,10 @@ class UnsupportedSegmentTypeError(Exception): pass +class TypeMismatchError(Exception): + pass + + # Define the constant SEGMENT_TO_VARIABLE_MAP = { StringSegment: StringVariable, @@ -110,8 +114,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_FILE if isinstance(value, list): - result = ArrayFileVariable.model_validate(mapping) case _: raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: @@ -121,6 +123,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen return cast(Variable, result) +def infer_segment_type_from_value(value: Any, /) -> SegmentType: + return build_segment(value).value_type + + def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() @@ -151,10 +157,80 @@ def build_segment(value: Any, /) -> Segment: case SegmentType.NONE: return ArrayAnySegment(value=value) case _: + # This should be unreachable. raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") +def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: + """ + Build a segment with explicit type checking. + + This function creates a segment from a value while enforcing type compatibility + with the specified segment_type. It provides stricter type validation compared + to the standard build_segment function. 
+ + Args: + segment_type: The expected SegmentType for the resulting segment + value: The value to be converted into a segment + + Returns: + Segment: A segment instance of the appropriate type + + Raises: + TypeMismatchError: If the value type doesn't match the expected segment_type + + Special Cases: + - For empty list [] values, if segment_type is array[*], returns the corresponding array type + - Type validation is performed before segment creation + + Examples: + >>> build_segment_with_type(SegmentType.STRING, "hello") + StringSegment(value="hello") + + >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) + ArrayStringSegment(value=[]) + + >>> build_segment_with_type(SegmentType.STRING, 123) + # Raises TypeMismatchError + """ + # Handle None values + if value is None: + if segment_type == SegmentType.NONE: + return NoneSegment() + else: + raise TypeMismatchError(f"Expected {segment_type}, but got None") + + # Handle empty list special case for array types + if isinstance(value, list) and len(value) == 0: + if segment_type == SegmentType.ARRAY_ANY: + return ArrayAnySegment(value=value) + elif segment_type == SegmentType.ARRAY_STRING: + return ArrayStringSegment(value=value) + elif segment_type == SegmentType.ARRAY_NUMBER: + return ArrayNumberSegment(value=value) + elif segment_type == SegmentType.ARRAY_OBJECT: + return ArrayObjectSegment(value=value) + elif segment_type == SegmentType.ARRAY_FILE: + return ArrayFileSegment(value=value) + else: + raise TypeMismatchError(f"Expected {segment_type}, but got empty list") + + # Build segment using existing logic to infer actual type + inferred_segment = build_segment(value) + inferred_type = inferred_segment.value_type + + # Type compatibility checking + if inferred_type == segment_type: + return inferred_segment + + # Type mismatch - raise error with descriptive message + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but value '{value}' " + f"(type: {type(value).__name__}) corresponds to {inferred_type}" + ) + + def segment_to_variable( *, segment: Segment, diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 74fdf8bd97..a106728e9c 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -19,7 +19,6 @@ workflow_run_for_log_fields = { workflow_run_for_list_fields = { "id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "status": fields.String, "elapsed_time": fields.Float, @@ -36,7 +35,6 @@ advanced_chat_workflow_run_for_list_fields = { "id": fields.String, "conversation_id": fields.String, "message_id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "status": fields.String, "elapsed_time": fields.Float, @@ -63,7 +61,6 @@ workflow_run_pagination_fields = { workflow_run_detail_fields = { "id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "graph": fields.Raw(attribute="graph_dict"), "inputs": fields.Raw(attribute="inputs_dict"), diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py new file mode 100644 index 0000000000..e576a34629 --- /dev/null +++ b/api/libs/datetime_utils.py @@ -0,0 +1,22 @@ +import abc +import datetime +from typing import Protocol + + +class _NowFunction(Protocol): + @abc.abstractmethod + def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: + pass + + +# _now_func is a callable with the _NowFunction signature. 
+# Its sole purpose is to abstract time retrieval, enabling +# developers to mock this behavior in tests and time-dependent scenarios. +_now_func: _NowFunction = datetime.datetime.now + + +def naive_utc_now() -> datetime.datetime: + """Return a naive datetime object (without timezone information) + representing current UTC time. + """ + return _now_func(datetime.UTC).replace(tzinfo=None) diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py new file mode 100644 index 0000000000..fa29671034 --- /dev/null +++ b/api/libs/jsonutil.py @@ -0,0 +1,11 @@ +import json + +from pydantic import BaseModel + + +class PydanticModelEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, BaseModel): + return o.model_dump() + else: + super().default(o) diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index 6217e9f4a6..5409e3eeeb 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -35,7 +35,10 @@ class SendGridClient: logging.exception("SendGridClient Timeout occurred while sending email") raise except (UnauthorizedError, ForbiddenError) as e: - logging.exception("SendGridClient Authentication failed. Verify that your credentials and the 'from") + logging.exception( + "SendGridClient Authentication failed. " + "Verify that your credentials and the 'from' email address are correct" + ) raise except Exception as e: logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}") diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 35561f071c..b94386660e 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -22,7 +22,11 @@ class SMTPClient: if self.use_tls: if self.opportunistic_tls: smtp = smtplib.SMTP(self.server, self.port, timeout=10) + # Send EHLO command with the HELO domain name as the server address + smtp.ehlo(self.server) smtp.starttls() + # Resend EHLO command to identify the TLS session + smtp.ehlo(self.server) else: smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: diff --git a/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py new file mode 100644 index 0000000000..29fef77798 --- /dev/null +++ b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py @@ -0,0 +1,66 @@ +"""remove sequence_number from workflow_runs + +Revision ID: 0ab65e1cc7fa +Revises: 4474872b0ee6 +Create Date: 2025-06-19 16:33:13.377215 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0ab65e1cc7fa' +down_revision = '4474872b0ee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_run_tenant_app_sequence_idx')) + batch_op.drop_column('sequence_number') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # WARNING: This downgrade CANNOT recover the original sequence_number values! + # The original sequence numbers are permanently lost after the upgrade. + # This downgrade will regenerate sequence numbers based on created_at order, + # which may result in different values than the original sequence numbers. 
+ # + # If you need to preserve original sequence numbers, use the alternative + # migration approach that creates a backup table before removal. + + # Step 1: Add sequence_number column as nullable first + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sequence_number', sa.INTEGER(), autoincrement=False, nullable=True)) + + # Step 2: Populate sequence_number values based on created_at order within each app + # NOTE: This recreates sequence numbering logic but values will be different + # from the original sequence numbers that were removed in the upgrade + connection = op.get_bind() + connection.execute(sa.text(""" + UPDATE workflow_runs + SET sequence_number = subquery.row_num + FROM ( + SELECT id, ROW_NUMBER() OVER ( + PARTITION BY tenant_id, app_id + ORDER BY created_at, id + ) as row_num + FROM workflow_runs + ) subquery + WHERE workflow_runs.id = subquery.id + """)) + + # Step 3: Make the column NOT NULL and add the index + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('sequence_number', nullable=False) + batch_op.create_index(batch_op.f('workflow_run_tenant_app_sequence_idx'), ['tenant_id', 'app_id', 'sequence_number'], unique=False) + + # ### end Alembic commands ### diff --git a/api/models/_workflow_exc.py b/api/models/_workflow_exc.py new file mode 100644 index 0000000000..f6271bda47 --- /dev/null +++ b/api/models/_workflow_exc.py @@ -0,0 +1,20 @@ +"""All these exceptions are not meant to be caught by callers.""" + + +class WorkflowDataError(Exception): + """Base class for all workflow data related exceptions. + + This should be used to indicate issues with workflow data integrity, such as + no `graph` configuration, missing `nodes` field in `graph` configuration, or + similar issues. + """ + + pass + + +class NodeNotFoundError(WorkflowDataError): + """Raised when a node with the specified ID is not found in the workflow.""" + + def __init__(self, node_id: str): + super().__init__(f"Node with ID '{node_id}' not found in the workflow.") + self.node_id = node_id diff --git a/api/models/model.py b/api/models/model.py index 4ab9faa69e..1c034f8867 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -612,6 +612,14 @@ class InstalledApp(Base): return tenant +class ConversationSource(StrEnum): + """This enumeration is designed for use with `Conversation.from_source`.""" + + # NOTE(QuantumGhost): The enumeration members may not cover all possible cases. + API = "api" + CONSOLE = "console" + + class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( @@ -633,7 +641,14 @@ class Conversation(Base): system_instruction = db.Column(db.Text) system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) + + # The `invoke_from` records how the conversation is created. + # + # Its value corresponds to the members of `InvokeFrom`. + # (api/core/app/entities/app_invoke_entities.py) invoke_from = db.Column(db.String(255), nullable=True) + + # ref: ConversationSource. 
from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) @@ -818,7 +833,12 @@ class Conversation(Base): @property def first_message(self): - return db.session.query(Message).filter(Message.conversation_id == self.id).first() + return ( + db.session.query(Message) + .filter(Message.conversation_id == self.id) + .order_by(Message.created_at.asc()) + .first() + ) @property def app(self): diff --git a/api/models/workflow.py b/api/models/workflow.py index 741422db06..645089ae7f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,10 +7,16 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 from flask_login import current_user +from sqlalchemy import orm +from core.file.constants import maybe_file_object +from core.file.models import File from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment +from core.workflow.nodes.enums import NodeType +from factories.variable_factory import TypeMismatchError, build_segment_with_type + +from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode @@ -73,6 +79,10 @@ class WorkflowType(Enum): return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT +class _InvalidGraphDefinitionError(Exception): + pass + + class Workflow(Base): """ Workflow, for `Workflow App` and `Chat App workflow mode`. @@ -140,6 +150,8 @@ class Workflow(Base): "rag_pipeline_variables", db.Text, nullable=False, server_default="{}" ) + VERSION_DRAFT = "draft" + @classmethod def new( cls, @@ -185,8 +197,72 @@ class Workflow(Base): @property def graph_dict(self) -> Mapping[str, Any]: + # TODO(QuantumGhost): Consider caching `graph_dict` to avoid repeated JSON decoding. + # + # Using `functools.cached_property` could help, but some code in the codebase may + # modify the returned dict, which can cause issues elsewhere. + # + # For example, changing this property to a cached property led to errors like the + # following when single stepping an `Iteration` node: + # + # Root node id 1748401971780start not found in the graph + # + # There is currently no standard way to make a dict deeply immutable in Python, + # and tracking modifications to the returned dict is difficult. For now, we leave + # the code as-is to avoid these issues. + # + # Currently, the following functions / methods would mutate the returned dict: + # + # - `_get_graph_and_variable_pool_of_single_iteration`. + # - `_get_graph_and_variable_pool_of_single_loop`. return json.loads(self.graph) if self.graph else {} + def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + """Extract a node configuration from the workflow graph by node ID. + A node configuration is a dictionary containing the node's properties, including + the node's id, title, and its data as a dict. 
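+
+        The returned mapping looks roughly like this (shape illustrative only):
+
+            {"id": "<node_id>", "data": {"type": "llm", "title": "...", ...}, ...}
+
+        Raises `WorkflowDataError` if the graph or its `nodes` field is missing,
+        and `NodeNotFoundError` if no node with the given ID exists.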
+ """ + workflow_graph = self.graph_dict + + if not workflow_graph: + raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}") + + nodes = workflow_graph.get("nodes") + if not nodes: + raise WorkflowDataError("nodes not found in workflow graph") + + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: + raise NodeNotFoundError(node_id) + assert isinstance(node_config, dict) + return node_config + + @staticmethod + def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: + """Extract type of a node from the node configuration returned by `get_node_config_by_id`.""" + node_config_data = node_config.get("data", {}) + # Get node class + node_type = NodeType(node_config_data.get("type")) + return node_type + + @staticmethod + def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None: + in_loop = node_config.get("isInLoop", False) + in_iteration = node_config.get("isInIteration", False) + if in_loop: + loop_id = node_config.get("loop_id") + if loop_id is None: + raise _InvalidGraphDefinitionError("invalid graph") + return NodeType.LOOP, loop_id + elif in_iteration: + iteration_id = node_config.get("iteration_id") + if iteration_id is None: + raise _InvalidGraphDefinitionError("invalid graph") + return NodeType.ITERATION, iteration_id + else: + return None + @property def features(self) -> str: """ @@ -400,6 +476,10 @@ class Workflow(Base): ensure_ascii=False, ) + @staticmethod + def version_from_datetime(d: datetime) -> str: + return str(d) + class WorkflowRun(Base): """ @@ -410,7 +490,7 @@ class WorkflowRun(Base): - id (uuid) Run ID - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1 + - workflow_id (uuid) Workflow ID - type (string) Workflow type - triggered_from (string) Trigger source @@ -443,13 +523,12 @@ class WorkflowRun(Base): __table_args__ = ( db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), - db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) - sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) type: Mapped[str] = mapped_column(db.String(255)) triggered_from: Mapped[str] = mapped_column(db.String(255)) @@ -509,7 +588,6 @@ class WorkflowRun(Base): "id": self.id, "tenant_id": self.tenant_id, "app_id": self.app_id, - "sequence_number": self.sequence_number, "workflow_id": self.workflow_id, "type": self.type, "triggered_from": self.triggered_from, @@ -535,7 +613,6 @@ class WorkflowRun(Base): id=data.get("id"), tenant_id=data.get("tenant_id"), app_id=data.get("app_id"), - sequence_number=data.get("sequence_number"), workflow_id=data.get("workflow_id"), type=data.get("type"), triggered_from=data.get("triggered_from"), @@ -863,8 +940,18 @@ def _naive_utc_datetime(): class WorkflowDraftVariable(Base): + """`WorkflowDraftVariable` record variables and outputs generated during + debugging worfklow or chatflow. + + IMPORTANT: This model maintains multiple invariant rules that must be preserved. + Do not instantiate this class directly with the constructor. 
+ + Instead, use the factory methods (`new_conversation_variable`, `new_sys_variable`, + `new_node_variable`) defined below to ensure all invariants are properly maintained. + """ + @staticmethod - def unique_columns() -> list[str]: + def unique_app_id_node_id_name() -> list[str]: return [ "app_id", "node_id", @@ -872,7 +959,9 @@ class WorkflowDraftVariable(Base): ] __tablename__ = "workflow_draft_variables" - __table_args__ = (UniqueConstraint(*unique_columns()),) + __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),) + # Required for instance variable annotation. + __allow_unmapped__ = True # id is the unique identifier of a draft variable. id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) @@ -953,6 +1042,36 @@ class WorkflowDraftVariable(Base): default=None, ) + # Cache for deserialized value + # + # NOTE(QuantumGhost): This field serves two purposes: + # + # 1. Caches deserialized values to reduce repeated parsing costs + # 2. Allows modification of the deserialized value after retrieval, + # particularly important for `File`` variables which require database + # lookups to obtain storage_key and other metadata + # + # Use double underscore prefix for better encapsulation, + # making this attribute harder to access from outside the class. + __value: Segment | None + + def __init__(self, *args, **kwargs): + """ + The constructor of `WorkflowDraftVariable` is not intended for + direct use outside this file. Its solo purpose is setup private state + used by the model instance. + + Please use the factory methods + (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`) + defined below to create instances of this class. + """ + super().__init__(*args, **kwargs) + self.__value = None + + @orm.reconstructor + def _init_on_load(self): + self.__value = None + def get_selector(self) -> list[str]: selector = json.loads(self.selector) if not isinstance(selector, list): @@ -967,15 +1086,92 @@ class WorkflowDraftVariable(Base): def _set_selector(self, value: list[str]): self.selector = json.dumps(value) - def get_value(self) -> Segment | None: - return build_segment(json.loads(self.value)) + def _loads_value(self) -> Segment: + value = json.loads(self.value) + return self.build_segment_with_type(self.value_type, value) + + @staticmethod + def rebuild_file_types(value: Any) -> Any: + # NOTE(QuantumGhost): Temporary workaround for structured data handling. + # By this point, `output` has been converted to dict by + # `WorkflowEntry.handle_special_values`, so we need to + # reconstruct File objects from their serialized form + # to maintain proper variable saving behavior. + # + # Ideally, we should work with structured data objects directly + # rather than their serialized forms. + # However, multiple components in the codebase depend on + # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. 
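+        #
+        # Rough sketch of the intended mapping (illustrative, not exhaustive):
+        #
+        #   serialized File dict (detected by maybe_file_object)  -> File instance
+        #   list whose first element is such a dict               -> list[File]
+        #   anything else (plain dict, str, empty list, ...)      -> returned unchanged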
+ if isinstance(value, dict): + if not maybe_file_object(value): + return value + return File.model_validate(value) + elif isinstance(value, list) and value: + first = value[0] + if not maybe_file_object(first): + return value + return [File.model_validate(i) for i in value] + else: + return value + + @classmethod + def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: + # Extends `variable_factory.build_segment_with_type` functionality by + # reconstructing `FileSegment`` or `ArrayFileSegment`` objects from + # their serialized dictionary or list representations, respectively. + if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + def get_value(self) -> Segment: + """Decode the serialized value into its corresponding `Segment` object. + + This method caches the result, so repeated calls will return the same + object instance without re-parsing the serialized data. + + If you need to modify the returned `Segment`, use `value.model_copy()` + to create a copy first to avoid affecting the cached instance. + + For more information about the caching mechanism, see the documentation + of the `__value` field. + + Returns: + Segment: The deserialized value as a Segment object. + """ + + if self.__value is not None: + return self.__value + value = self._loads_value() + self.__value = value + return value def set_name(self, name: str): self.name = name self._set_selector([self.node_id, name]) def set_value(self, value: Segment): - self.value = json.dumps(value.value) + """Updates the `value` and corresponding `value_type` fields in the database model. + + This method also stores the provided Segment object in the deserialized cache + without creating a copy, allowing for efficient value access. + + Args: + value: The Segment object to store as the variable's value. 
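+
+        Example (illustrative; assumes a string-typed variable):
+
+            draft_var.set_value(WorkflowDraftVariable.build_segment_with_type(SegmentType.STRING, "hello"))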
+ """ + self.__value = value + self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) self.value_type = value.value_type def get_node_id(self) -> str | None: @@ -1001,6 +1197,7 @@ class WorkflowDraftVariable(Base): node_id: str, name: str, value: Segment, + node_execution_id: str | None, description: str = "", ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() @@ -1012,6 +1209,7 @@ class WorkflowDraftVariable(Base): variable.name = name variable.set_value(value) variable._set_selector(list(variable_utils.to_selector(node_id, name))) + variable.node_execution_id = node_execution_id return variable @classmethod @@ -1021,13 +1219,17 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + description: str = "", ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, value=value, + description=description, + node_execution_id=None, ) + variable.editable = True return variable @classmethod @@ -1037,9 +1239,16 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + node_execution_id: str, editable: bool = False, ) -> "WorkflowDraftVariable": - variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) + variable = cls._new( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=name, + node_execution_id=node_execution_id, + value=value, + ) variable.editable = editable return variable @@ -1051,11 +1260,19 @@ class WorkflowDraftVariable(Base): node_id: str, name: str, value: Segment, + node_execution_id: str, visible: bool = True, + editable: bool = True, ) -> "WorkflowDraftVariable": - variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) + variable = cls._new( + app_id=app_id, + node_id=node_id, + name=name, + node_execution_id=node_execution_id, + value=value, + ) variable.visible = visible - variable.editable = True + variable.editable = editable return variable @property diff --git a/api/pyproject.toml b/api/pyproject.toml index 38cc9ae75d..fed0128b90 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -149,6 +149,7 @@ dev = [ "types-ujson>=5.10.0", "boto3-stubs>=1.38.20", "types-jmespath>=1.0.2.20240106", + "hypothesis>=6.131.15", "types_pyOpenSSL>=24.1.0", "types_cffi>=1.17.0", "types_setuptools>=80.9.0", diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 1b026acfd6..20257fa345 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,6 +32,7 @@ from models import Account, App, AppMode from models.model import AppModelConfig from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.workflow_draft_variable_service import WorkflowDraftVariableService from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -292,6 +293,8 @@ class AppDslService: dependencies=check_dependencies_pending_data, ) + draft_var_srv = WorkflowDraftVariableService(session=self._session) + draft_var_srv.delete_workflow_variables(app_id=app.id) return Import( id=import_id, status=status, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 6f23f98e67..f83359d456 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -64,6 +64,7 @@ from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService from 
services.tag_service import TagService from services.vector_service import VectorService +from tasks.add_document_to_index_task import add_document_to_index_task from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task @@ -76,6 +77,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -323,188 +325,343 @@ class DatasetService: @staticmethod def update_dataset(dataset_id, data, user): + """ + Update dataset configuration and settings. + + Args: + dataset_id: The unique identifier of the dataset to update + data: Dictionary containing the update data + user: The user performing the update operation + + Returns: + Dataset: The updated dataset object + + Raises: + ValueError: If dataset not found or validation fails + NoPermissionError: If user lacks permission to update the dataset + """ + # Retrieve and validate dataset existence dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") - DatasetService.check_dataset_permission(dataset, user) - if dataset.provider == "external": - external_retrieval_model = data.get("external_retrieval_model", None) - if external_retrieval_model: - dataset.retrieval_model = external_retrieval_model - dataset.name = data.get("name", dataset.name) - # check if dataset name is exists - if ( - db.session.query(Dataset) + # check if dataset name is exists + if ( + db.session.query(Dataset) .filter( - Dataset.id != dataset_id, - Dataset.name == dataset.name, - Dataset.tenant_id == dataset.tenant_id, - ) - .first() - ): - raise ValueError("Dataset name already exists") - dataset.description = data.get("description", "") - permission = data.get("permission") - if permission: - dataset.permission = permission - external_knowledge_id = data.get("external_knowledge_id", None) - db.session.add(dataset) - if not external_knowledge_id: - raise ValueError("External knowledge id is required.") - external_knowledge_api_id = data.get("external_knowledge_api_id", None) - if not external_knowledge_api_id: - raise ValueError("External knowledge api id is required.") + Dataset.id != dataset_id, + Dataset.name == data.get("name", dataset.name), + Dataset.tenant_id == dataset.tenant_id, + ) + .first() + ): + raise ValueError("Dataset name already exists") - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() - ) + # Verify user has permission to update this dataset + DatasetService.check_dataset_permission(dataset, user) - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") - - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id 
- external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) - db.session.commit() + # Handle external dataset updates + if dataset.provider == "external": + return DatasetService._update_external_dataset(dataset, data, user) else: - data.pop("partial_member_list", None) - data.pop("external_knowledge_api_id", None) - data.pop("external_knowledge_id", None) - data.pop("external_retrieval_model", None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} - action = None - if dataset.indexing_technique != data["indexing_technique"]: - # if update indexing_technique - if data["indexing_technique"] == "economy": - action = "remove" - filtered_data["embedding_model"] = None - filtered_data["embedding_model_provider"] = None - filtered_data["collection_binding_id"] = None - elif data["indexing_technique"] == "high_quality": - action = "add" - # get embedding model setting - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - else: - # add default plugin id to both setting sets, to make sure the plugin model provider is consistent - # Skip embedding model checks if not provided in the update request - if ( - "embedding_model_provider" not in data - or "embedding_model" not in data - or not data.get("embedding_model_provider") - or not data.get("embedding_model") - ): - # If the dataset already has embedding model settings, use those - if dataset.embedding_model_provider and dataset.embedding_model: - # Keep existing values - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - # If collection_binding_id exists, keep it too - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Otherwise, don't try to update embedding model settings at all - # Remove these fields from filtered_data if they exist but are None/empty - if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: - del filtered_data["embedding_model_provider"] - if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: - del filtered_data["embedding_model"] - else: - skip_embedding_update = False - try: - # Handle existing model provider - plugin_model_provider = dataset.embedding_model_provider - plugin_model_provider_str = None - if plugin_model_provider: - plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + return DatasetService._update_internal_dataset(dataset, data, user) - # Handle new model provider from request - new_plugin_model_provider = data["embedding_model_provider"] - new_plugin_model_provider_str = None - if 
new_plugin_model_provider: - new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + @staticmethod + def _update_external_dataset(dataset, data, user): + """ + Update external dataset configuration. - # Only update embedding model if both values are provided and different from current - if ( - plugin_model_provider_str != new_plugin_model_provider_str - or data["embedding_model"] != dataset.embedding_model - ): - action = "update" - model_manager = ModelManager() - try: - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - except ProviderTokenNotInitError: - # If we can't get the embedding model, skip updating it - # and keep the existing settings if available - if dataset.embedding_model_provider and dataset.embedding_model: - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Skip the rest of the embedding model update - skip_embedding_update = True - if not skip_embedding_update: - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." 
- ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update - filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now() + Returns: + Dataset: Updated dataset object + """ + # Update retrieval model if provided + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model - # update Retrieval model - filtered_data["retrieval_model"] = data["retrieval_model"] + # Update basic dataset properties + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", dataset.description) - # update icon info - if data.get("icon_info"): - filtered_data["icon_info"] = data.get("icon_info") - db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) + # Update permission if provided + permission = data.get("permission") + if permission: + dataset.permission = permission + + # Validate and update external knowledge configuration + external_knowledge_id = data.get("external_knowledge_id", None) + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + # Update metadata fields + dataset.updated_by = user.id if user else None + dataset.updated_at = datetime.datetime.utcnow() + db.session.add(dataset) + + # Update external knowledge binding + DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) + + # Commit changes to database + db.session.commit() - db.session.commit() - if action: - deal_dataset_vector_index_task.delay(dataset_id, action) return dataset + @staticmethod + def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id): + """ + Update external knowledge binding configuration. + + Args: + dataset_id: Dataset identifier + external_knowledge_id: External knowledge identifier + external_knowledge_api_id: External knowledge API identifier + """ + with Session(db.engine) as session: + external_knowledge_binding = ( + session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + ) + + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") + + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + + @staticmethod + def _update_internal_dataset(dataset, data, user): + """ + Update internal dataset configuration. 
+ + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update + + Returns: + Dataset: Updated dataset object + """ + # Remove external-specific fields from update data + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + + # Filter out None values except for description field + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + + # Handle indexing technique changes and embedding model updates + action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) + + # Add metadata fields + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] + # update icon info + if data.get("icon_info"): + filtered_data["icon_info"] = data.get("icon_info") + + # Update dataset in database + db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.commit() + + # Trigger vector index task if indexing technique changed + if action: + deal_dataset_vector_index_task.delay(dataset.id, action) + + return dataset + + @staticmethod + def _handle_indexing_technique_change(dataset, data, filtered_data): + """ + Handle changes in indexing technique and configure embedding models accordingly. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data + + Returns: + str: Action to perform ('add', 'remove', 'update', or None) + """ + if dataset.indexing_technique != data["indexing_technique"]: + if data["indexing_technique"] == "economy": + # Remove embedding model configuration for economy mode + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + return "remove" + elif data["indexing_technique"] == "high_quality": + # Configure embedding model for high quality mode + DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) + return "add" + else: + # Handle embedding model updates when indexing technique remains the same + return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data) + return None + + @staticmethod + def _configure_embedding_model_for_high_quality(data, filtered_data): + """ + Configure embedding model settings for high quality indexing. + + Args: + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + + @staticmethod + def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): + """ + Handle embedding model updates when indexing technique remains the same. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + # Skip embedding model checks if not provided in the update request + if ( + "embedding_model_provider" not in data + or "embedding_model" not in data + or not data.get("embedding_model_provider") + or not data.get("embedding_model") + ): + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + return None + else: + return DatasetService._update_embedding_model_settings(dataset, data, filtered_data) + + @staticmethod + def _preserve_existing_embedding_settings(dataset, filtered_data): + """ + Preserve existing embedding model settings when not provided in update. + + Args: + dataset: Current dataset object + filtered_data: Filtered update data to modify + """ + # If the dataset already has embedding model settings, use those + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + # If collection_binding_id exists, keep it too + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Otherwise, don't try to update embedding model settings at all + # Remove these fields from filtered_data if they exist but are None/empty + if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: + del filtered_data["embedding_model_provider"] + if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: + del filtered_data["embedding_model"] + + @staticmethod + def _update_embedding_model_settings(dataset, data, filtered_data): + """ + Update embedding model settings with new values. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + try: + # Compare current and new model provider settings + current_provider_str = ( + str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None + ) + new_provider_str = ( + str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None + ) + + # Only update if values are different + if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: + DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) + return "update" + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + return None + + @staticmethod + def _apply_new_embedding_settings(dataset, data, filtered_data): + """ + Apply new embedding model settings to the dataset. 
+ + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, preserve existing settings + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Skip the rest of the embedding model update + return + + # Apply new embedding model settings + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod def update_rag_pipeline_dataset_settings( session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False @@ -1157,12 +1314,17 @@ class DocumentService: process_rule = knowledge_config.process_rule if process_rule: if process_rule.mode in ("custom", "hierarchical"): - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json() if process_rule.rules else None, - created_by=account.id, - ) + if process_rule.rules: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + else: + dataset_process_rule = dataset.latest_process_rule + if not dataset_process_rule: + raise ValueError("No process rule found.") elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, @@ -2061,6 +2223,191 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + @staticmethod + def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + """ + Batch update document status. + + Args: + dataset (Dataset): The dataset object + document_ids (list[str]): List of document IDs to update + action (str): Action to perform (enable, disable, archive, un_archive) + user: Current user performing the action + + Raises: + DocumentIndexingError: If document is being indexed or not in correct state + ValueError: If action is invalid + """ + if not document_ids: + return + + # Early validation of action parameter + valid_actions = ["enable", "disable", "archive", "un_archive"] + if action not in valid_actions: + raise ValueError(f"Invalid action: {action}. 
Must be one of {valid_actions}") + + documents_to_update = [] + + # First pass: validate all documents and prepare updates + for document_id in document_ids: + document = DocumentService.get_document(dataset.id, document_id) + if not document: + continue + + # Check if document is being indexed + indexing_cache_key = f"document_{document.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later") + + # Prepare update based on action + update_info = DocumentService._prepare_document_status_update(document, action, user) + if update_info: + documents_to_update.append(update_info) + + # Second pass: apply all updates in a single transaction + if documents_to_update: + try: + for update_info in documents_to_update: + document = update_info["document"] + updates = update_info["updates"] + + # Apply updates to the document + for field, value in updates.items(): + setattr(document, field, value) + + db.session.add(document) + + # Batch commit all changes + db.session.commit() + except Exception as e: + # Rollback on any error + db.session.rollback() + raise e + # Execute async tasks and set Redis cache after successful commit + # propagation_error is used to capture any errors for submitting async task execution + propagation_error = None + for update_info in documents_to_update: + try: + # Execute async tasks after successful commit + if update_info["async_task"]: + task_info = update_info["async_task"] + task_func = task_info["function"] + task_args = task_info["args"] + task_func.delay(*task_args) + except Exception as e: + # Log the error but do not rollback the transaction + logging.exception(f"Error executing async task for document {update_info['document'].id}") + # don't raise the error immediately, but capture it for later + propagation_error = e + try: + # Set Redis cache if needed after successful commit + if update_info["set_cache"]: + document = update_info["document"] + indexing_cache_key = f"document_{document.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + except Exception as e: + # Log the error but do not rollback the transaction + logging.exception(f"Error setting cache for document {update_info['document'].id}") + # Raise any propagation error after all updates + if propagation_error: + raise propagation_error + + @staticmethod + def _prepare_document_status_update(document, action: str, user): + """ + Prepare document status update information. 
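+
+        The returned dict (when an update is needed) has the shape:
+
+            {
+                "document": <Document>,
+                "updates": {<field>: <new value>, ...},
+                "async_task": {"function": <celery task>, "args": [document.id]} or None,
+                "set_cache": bool,
+            }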
+ + Args: + document: Document object to update + action: Action to perform + user: Current user + + Returns: + dict: Update information or None if no update needed + """ + now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + if action == "enable": + return DocumentService._prepare_enable_update(document, now) + elif action == "disable": + return DocumentService._prepare_disable_update(document, user, now) + elif action == "archive": + return DocumentService._prepare_archive_update(document, user, now) + elif action == "un_archive": + return DocumentService._prepare_unarchive_update(document, now) + + return None + + @staticmethod + def _prepare_enable_update(document, now): + """Prepare updates for enabling a document.""" + if document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}, + "async_task": {"function": add_document_to_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_disable_update(document, user, now): + """Prepare updates for disabling a document.""" + if not document.completed_at or document.indexing_status != "completed": + raise DocumentIndexingError(f"Document: {document.name} is not completed.") + + if not document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now}, + "async_task": {"function": remove_document_from_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_archive_update(document, user, now): + """Prepare updates for archiving a document.""" + if document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only set async task and cache if document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + + @staticmethod + def _prepare_unarchive_update(document, now): + """Prepare updates for unarchiving a document.""" + if not document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only re-index if the document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + class SegmentService: @classmethod diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 87e9e9247d..5d348c61be 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -4,3 +4,7 @@ class MoreLikeThisDisabledError(Exception): class WorkflowHashNotEqualError(Exception): pass + + +class IsDraftWorkflowError(Exception): + pass diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 26d6d4ce18..cfcb121153 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -19,6 +19,10 @@ from services.entities.knowledge_entities.knowledge_entities import ( class MetadataService: @staticmethod def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: + # check if metadata name 
is too long + if len(metadata_args.name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + # check if metadata name already exists if ( db.session.query(DatasetMetadata) @@ -42,6 +46,10 @@ class MetadataService: @staticmethod def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore + # check if metadata name is too long + if len(name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists if ( diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 02de5a79d7..5324036414 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -22,7 +22,7 @@ class PluginDataMigration: cls.migrate_datasets() cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID) - cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID) + cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID) @classmethod def migrate_datasets(cls) -> None: diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419b..4077ec38d3 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,61 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + + @staticmethod + def create_proxy_context(user_id, tenant_id, plugin_id, provider): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `use_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. + """ + seconds, _ = redis_client.time() + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + # encode redis time to avoid distribution time skew + "timestamp": seconds, + } + # ignore nonce collision + redis_client.setex( + f"oauth_proxy_context:{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id + + @staticmethod + def use_proxy_context(context_id, max_age=__MAX_AGE__): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. 
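+
+        The context is read with GETDEL, so it is removed from Redis on first use;
+        a replayed callback carrying the same context_id therefore fails with
+        "context_id is invalid".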
+ """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + if not data: + raise ValueError("context_id is invalid") + # check if data is expired + seconds, _ = redis_client.time() + state = json.loads(data) + if state.get("timestamp") < seconds - max_age: + raise ValueError("context_id is expired") + return state diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py new file mode 100644 index 0000000000..164693c2e1 --- /dev/null +++ b/api/services/workflow_draft_variable_service.py @@ -0,0 +1,722 @@ +import dataclasses +import datetime +import logging +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, ClassVar + +from sqlalchemy import Engine, orm, select +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import and_, or_ + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File +from core.variables import Segment, StringSegment, Variable +from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.segments import ArrayFileSegment, FileSegment +from core.variables.types import SegmentType +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes import NodeType +from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables +from core.workflow.variable_loader import VariableLoader +from factories.file_factory import StorageKeyLoader +from factories.variable_factory import build_segment, segment_to_variable +from models import App, Conversation +from models.enums import DraftVariableType +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable + +_logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class WorkflowDraftVariableList: + variables: list[WorkflowDraftVariable] + total: int | None = None + + +class WorkflowDraftVariableError(Exception): + pass + + +class VariableResetError(WorkflowDraftVariableError): + pass + + +class UpdateNotSupportedError(WorkflowDraftVariableError): + pass + + +class DraftVarLoader(VariableLoader): + # This implements the VariableLoader interface for loading draft variables. + # + # ref: core.workflow.variable_loader.VariableLoader + + # Database engine used for loading variables. + _engine: Engine + # Application ID for which variables are being loaded. + _app_id: str + _tenant_id: str + _fallback_variables: Sequence[Variable] + + def __init__( + self, + engine: Engine, + app_id: str, + tenant_id: str, + fallback_variables: Sequence[Variable] | None = None, + ) -> None: + self._engine = engine + self._app_id = app_id + self._tenant_id = tenant_id + self._fallback_variables = fallback_variables or [] + + def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: + return (selector[0], selector[1]) + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + if not selectors: + return [] + + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. 
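+        # Each selector is a [node_id, variable_name] pair (only the first two
+        # elements are used), e.g. ["sys", "conversation_id"] or ["<node_id>", "text"].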
+ variable_by_selector: dict[tuple[str, str], Variable] = {} + + with Session(bind=self._engine, expire_on_commit=False) as session: + srv = WorkflowDraftVariableService(session) + draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) + + for draft_var in draft_vars: + segment = draft_var.get_value() + variable = segment_to_variable( + segment=segment, + selector=draft_var.get_selector(), + id=draft_var.id, + name=draft_var.name, + description=draft_var.description, + ) + selector_tuple = self._selector_to_tuple(variable.selector) + variable_by_selector[selector_tuple] = variable + + # Important: + files: list[File] = [] + for draft_var in draft_vars: + value = draft_var.get_value() + if isinstance(value, FileSegment): + files.append(value.value) + elif isinstance(value, ArrayFileSegment): + files.extend(value.value) + with Session(bind=self._engine) as session: + storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader.load_storage_keys(files) + + return list(variable_by_selector.values()) + + +class WorkflowDraftVariableService: + _session: Session + + def __init__(self, session: Session) -> None: + self._session = session + + def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: + return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + + def get_draft_variables_by_selectors( + self, + app_id: str, + selectors: Sequence[list[str]], + ) -> list[WorkflowDraftVariable]: + ors = [] + for selector in selectors: + assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + node_id, name = selector[:2] + ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) + + # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as + # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index), + # PostgreSQL can efficiently retrieve the results using a bitmap index scan. + # + # Alternatively, a `SELECT` statement could be constructed for each selector and + # combined using `UNION` to fetch all rows. + # Benchmarking indicates that both approaches yield comparable performance. + variables = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all() + ) + return variables + + def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: + criteria = WorkflowDraftVariable.app_id == app_id + total = None + query = self._session.query(WorkflowDraftVariable).filter(criteria) + if page == 1: + total = query.count() + variables = ( + # Do not load the `value` field. 
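+            # Deferring the potentially large serialized value keeps this listing
+            # query light; the deferred column is only loaded if a caller later
+            # accesses it (e.g. via get_value()).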
+ query.options(orm.defer(WorkflowDraftVariable.value)) + .order_by(WorkflowDraftVariable.id.desc()) + .limit(limit) + .offset((page - 1) * limit) + .all() + ) + + return WorkflowDraftVariableList(variables=variables, total=total) + + def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: + criteria = ( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + ) + query = self._session.query(WorkflowDraftVariable).filter(*criteria) + variables = query.order_by(WorkflowDraftVariable.id.desc()).all() + return WorkflowDraftVariableList(variables=variables) + + def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, node_id) + + def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID) + + def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID) + + def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name) + + def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name) + + def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id, node_id, name) + + def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: + variable = ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + WorkflowDraftVariable.name == name, + ) + .first() + ) + return variable + + def update_variable( + self, + variable: WorkflowDraftVariable, + name: str | None = None, + value: Segment | None = None, + ) -> WorkflowDraftVariable: + if not variable.editable: + raise UpdateNotSupportedError(f"variable not support updating, id={variable.id}") + if name is not None: + variable.set_name(name) + if value is not None: + variable.set_value(value) + variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + self._session.flush() + return variable + + def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + conv_var_by_name = {i.name: i for i in workflow.conversation_variables} + conv_var = conv_var_by_name.get(variable.name) + + if conv_var is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning( + "Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name + ) + return None + + variable.set_value(conv_var) + variable.last_edited_at = None + self._session.add(variable) + self._session.flush() + return variable + + def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + # If a variable does not allow updating, it makes no sence to resetting it. + if not variable.editable: + return variable + # No execution record for this variable, delete the variable instead. 
+ if variable.node_execution_id is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) + return None + + query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) + node_exec = self._session.scalars(query).first() + if node_exec is None: + _logger.warning( + "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", + variable.id, + variable.name, + variable.node_execution_id, + ) + self._session.delete(instance=variable) + self._session.flush() + return None + + # Get node type for proper value extraction + node_config = workflow.get_node_config_by_id(variable.node_id) + node_type = workflow.get_node_type_from_node_config(node_config) + + outputs_dict = node_exec.outputs_dict or {} + + # Note: Based on the implementation in `_build_from_variable_assigner_mapping`, + # VariableAssignerNode (both v1 and v2) can only create conversation draft variables. + # For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes. + # + # This implementation must remain synchronized with the `_build_from_variable_assigner_mapping` + # and `save` methods. + if node_type == NodeType.VARIABLE_ASSIGNER: + return variable + + if variable.name not in outputs_dict: + # If variable not found in execution data, delete the variable + self._session.delete(instance=variable) + self._session.flush() + return None + value = outputs_dict[variable.name] + value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, value) + # Extract variable value using unified logic + variable.set_value(value_seg) + variable.last_edited_at = None # Reset to indicate this is a reset operation + self._session.flush() + return variable + + def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.CONVERSATION: + return self._reset_conv_var(workflow, variable) + elif variable_type == DraftVariableType.NODE: + return self._reset_node_var(workflow, variable) + else: + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") + + def delete_variable(self, variable: WorkflowDraftVariable): + self._session.delete(variable) + + def delete_workflow_variables(self, app_id: str): + ( + self._session.query(WorkflowDraftVariable) + .filter(WorkflowDraftVariable.app_id == app_id) + .delete(synchronize_session=False) + ) + + def delete_node_variables(self, app_id: str, node_id: str): + return self._delete_node_variables(app_id, node_id) + + def _delete_node_variables(self, app_id: str, node_id: str): + self._session.query(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + ).delete() + + def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None: + draft_var = self._get_variable( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=str(SystemVariableKey.CONVERSATION_ID), + ) + if draft_var is None: + return None + segment = draft_var.get_value() + if not isinstance(segment, StringSegment): + _logger.warning( + "sys.conversation_id variable is not a string: app_id=%s, id=%s", + app_id, + draft_var.id, + ) + return None + return segment.value + + def get_or_create_conversation( + self, + account_id: str, + app: App, + workflow: Workflow, + ) -> str: + """ + 
+        get_or_create_conversation creates and returns the ID of a conversation for debugging.
+
+        If a conversation already exists, as determined by the following criteria, its ID is returned:
+        - The system variable `sys.conversation_id` exists in the draft variable table, and
+        - A corresponding conversation record is found in the database.
+
+        If no such conversation exists, a new conversation is created and its ID is returned.
+        """
+        conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
+
+        if conv_id is not None:
+            conversation = (
+                self._session.query(Conversation)
+                .filter(
+                    Conversation.id == conv_id,
+                    Conversation.app_id == workflow.app_id,
+                )
+                .first()
+            )
+            # Only return the conversation ID if a corresponding conversation record exists in the DB.
+            if conversation is not None:
+                return conv_id
+        conversation = Conversation(
+            app_id=workflow.app_id,
+            app_model_config_id=app.app_model_config_id,
+            model_provider=None,
+            model_id="",
+            override_model_configs=None,
+            mode=app.mode,
+            name="Draft Debugging Conversation",
+            inputs={},
+            introduction="",
+            system_instruction="",
+            system_instruction_tokens=0,
+            status="normal",
+            invoke_from=InvokeFrom.DEBUGGER.value,
+            from_source="console",
+            from_end_user_id=None,
+            from_account_id=account_id,
+        )
+
+        self._session.add(conversation)
+        self._session.flush()
+        return conversation.id
+
+    def prefill_conversation_variable_default_values(self, workflow: Workflow):
+        """Prefill draft conversation variables with the default values defined on the
+        workflow, leaving any existing draft records untouched."""
+        draft_conv_vars: list[WorkflowDraftVariable] = []
+        for conv_var in workflow.conversation_variables:
+            draft_var = WorkflowDraftVariable.new_conversation_variable(
+                app_id=workflow.app_id,
+                name=conv_var.name,
+                value=conv_var,
+                description=conv_var.description,
+            )
+            draft_conv_vars.append(draft_var)
+        _batch_upsert_draft_varaible(
+            self._session,
+            draft_conv_vars,
+            policy=_UpsertPolicy.IGNORE,
+        )
+
+
+class _UpsertPolicy(StrEnum):
+    IGNORE = "ignore"
+    OVERWRITE = "overwrite"
+
+
+def _batch_upsert_draft_varaible(
+    session: Session,
+    draft_vars: Sequence[WorkflowDraftVariable],
+    policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
+) -> None:
+    if not draft_vars:
+        return None
+    # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
+    #
+    # 1. The variable saving process involves writing multiple rows to the
+    #    `workflow_draft_variables` table. Batch insertion significantly improves performance.
+    # 2. Using the ORM would require either:
+    #
+    #    a. Checking for the existence of each variable before insertion,
+    #       resulting in 2n SQL statements for n variables and potential concurrency issues.
+    #    b. Attempting insertion first, then updating if a unique index violation occurs,
+    #       which still results in n to 2n SQL statements.
+    #
+    # Both approaches are inefficient and suboptimal.
+    # 3. We do not need to retrieve the results of the SQL execution or populate ORM
+    #    model instances with the returned values.
+    # 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all
+    #    variables in a single SQL statement, avoiding the issues above.
+    #
+    # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
+    # insert operations instead of the ORM layer.
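+    #
+    # As a rough illustration (assuming the PostgreSQL dialect and a unique index over
+    # (app_id, node_id, name)), the OVERWRITE policy below compiles to a single statement
+    # along the lines of:
+    #
+    #   INSERT INTO workflow_draft_variables (app_id, node_id, name, value, value_type, ...)
+    #   VALUES (...), (...), ...
+    #   ON CONFLICT (app_id, node_id, name)
+    #   DO UPDATE SET value = EXCLUDED.value, updated_at = EXCLUDED.updated_at, ...
+    #
+    # while the IGNORE policy instead ends with `ON CONFLICT ... DO NOTHING`.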
+    stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
+    if policy == _UpsertPolicy.OVERWRITE:
+        stmt = stmt.on_conflict_do_update(
+            index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
+            set_={
+                "updated_at": stmt.excluded.updated_at,
+                "last_edited_at": stmt.excluded.last_edited_at,
+                "description": stmt.excluded.description,
+                "value_type": stmt.excluded.value_type,
+                "value": stmt.excluded.value,
+                "visible": stmt.excluded.visible,
+                "editable": stmt.excluded.editable,
+                "node_execution_id": stmt.excluded.node_execution_id,
+            },
+        )
+    elif policy == _UpsertPolicy.IGNORE:
+        stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
+    else:
+        raise Exception("Invalid value for update policy.")
+    session.execute(stmt)
+
+
+def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
+    d: dict[str, Any] = {
+        "app_id": model.app_id,
+        "last_edited_at": None,
+        "node_id": model.node_id,
+        "name": model.name,
+        "selector": model.selector,
+        "value_type": model.value_type,
+        "value": model.value,
+        "node_execution_id": model.node_execution_id,
+    }
+    if model.visible is not None:
+        d["visible"] = model.visible
+    if model.editable is not None:
+        d["editable"] = model.editable
+    if model.created_at is not None:
+        d["created_at"] = model.created_at
+    if model.updated_at is not None:
+        d["updated_at"] = model.updated_at
+    if model.description is not None:
+        d["description"] = model.description
+    return d
+
+
+def _build_segment_for_serialized_values(v: Any) -> Segment:
+    """
+    Reconstructs Segment objects from serialized values, with special handling
+    for FileSegment and ArrayFileSegment types.
+
+    This function should only be used when:
+    1. No explicit type information is available
+    2. The input value is in serialized form (dict or list)
+
+    It detects potential file objects in the serialized data and properly rebuilds the
+    appropriate segment type.
+    """
+    return build_segment(WorkflowDraftVariable.rebuild_file_types(v))
+
+
+class DraftVariableSaver:
+    # _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
+    # Its sole possible value is `None`.
+    #
+    # This is used to signal the execution of a workflow node when it has no other outputs.
+    _DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__"
+    _DUMMY_OUTPUT_VALUE: ClassVar[None] = None
+
+    # _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that
+    # should be excluded when saving draft variables. This prevents certain internal or
+    # technical variables from being exposed in the draft environment, particularly those
+    # that aren't meant to be directly edited or viewed by users.
+    _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
+        NodeType.LLM: frozenset(["finish_reason"]),
+        NodeType.LOOP: frozenset(["loop_round"]),
+    }
+
+    # Database session used for persisting draft variables.
+    _session: Session
+
+    # The application ID associated with the draft variables.
+    # This should match the `Workflow.app_id` of the workflow to which the current node belongs.
+    _app_id: str
+
+    # The ID of the node for which DraftVariableSaver is saving output variables.
+    _node_id: str
+
+    # The type of the current node (see NodeType).
+    _node_type: NodeType
+
+    # Indicates how the workflow execution was triggered (see InvokeFrom).
+    _invoke_from: InvokeFrom
+
+    # The ID of the node execution that produced the variables being saved.
+    _node_execution_id: str
+
+    # _enclosing_node_id identifies the container node that the current node belongs to.
+ # For example, if the current node is an LLM node inside an Iteration node + # or Loop node, then `_enclosing_node_id` refers to the ID of + # the containing Iteration or Loop node. + # + # If the current node is not nested within another node, `_enclosing_node_id` is + # `None`. + _enclosing_node_id: str | None + + def __init__( + self, + session: Session, + app_id: str, + node_id: str, + node_type: NodeType, + invoke_from: InvokeFrom, + node_execution_id: str, + enclosing_node_id: str | None = None, + ): + self._session = session + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._invoke_from = invoke_from + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def _create_dummy_output_variable(self): + return WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=self._DUMMY_OUTPUT_IDENTITY, + node_execution_id=self._node_execution_id, + value=build_segment(self._DUMMY_OUTPUT_VALUE), + visible=False, + editable=False, + ) + + def _should_save_output_variables_for_draft(self) -> bool: + # Only save output variables for debugging execution of workflow. + if self._invoke_from != InvokeFrom.DEBUGGER: + return False + if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER: + # Currently we do not save output variables for nodes inside loop or iteration. + return False + return True + + def _build_from_variable_assigner_mapping(self, process_data: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars: list[WorkflowDraftVariable] = [] + updated_variables = get_updated_variables(process_data) or [] + + for item in updated_variables: + selector = item.selector + if len(selector) < MIN_SELECTORS_LENGTH: + raise Exception("selector too short") + # NOTE(QuantumGhost): only the following two kinds of variable could be updated by + # VariableAssigner: ConversationVariable and iteration variable. + # We only save conversation variable here. + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + segment = WorkflowDraftVariable.build_segment_with_type(segment_type=item.value_type, value=item.new_value) + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + name=item.name, + value=segment, + ) + ) + # Add a dummy output variable to indicate that this node is executed. + draft_vars.append(self._create_dummy_output_variable()) + return draft_vars + + def _build_variables_from_start_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + has_non_sys_variables = False + for name, value in output.items(): + value_seg = _build_segment_for_serialized_values(value) + node_id, name = self._normalize_variable_for_start_node(name) + # If node_id is not `sys`, it means that the variable is a user-defined input field + # in `Start` node. + if node_id != SYSTEM_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=True, + editable=True, + ) + ) + has_non_sys_variables = True + else: + if name == SystemVariableKey.FILES: + # Here we know the type of variable must be `array[file]`, we + # just build files from the value. 
+ files = [File.model_validate(v) for v in value] + if files: + value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) + else: + value_seg = ArrayFileSegment(value=[]) + + draft_vars.append( + WorkflowDraftVariable.new_sys_variable( + app_id=self._app_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + editable=self._should_variable_be_editable(node_id, name), + ) + ) + if not has_non_sys_variables: + draft_vars.append(self._create_dummy_output_variable()) + return draft_vars + + def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: + if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): + return self._node_id, name + _, name_ = name.split(".", maxsplit=1) + return SYSTEM_VARIABLE_NODE_ID, name_ + + def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + for name, value in output.items(): + if not self._should_variable_be_saved(name): + _logger.debug( + "Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s", + name, + self._node_type, + ) + continue + if isinstance(value, Segment): + value_seg = value + else: + value_seg = _build_segment_for_serialized_values(value) + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(self._node_id, self._node_type, name), + ) + ) + return draft_vars + + def save( + self, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + ): + draft_vars: list[WorkflowDraftVariable] = [] + if outputs is None: + outputs = {} + if process_data is None: + process_data = {} + if not self._should_save_output_variables_for_draft(): + return + if self._node_type == NodeType.VARIABLE_ASSIGNER: + draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data) + elif self._node_type == NodeType.START: + draft_vars = self._build_variables_from_start_mapping(outputs) + else: + draft_vars = self._build_variables_from_mapping(outputs) + _batch_upsert_draft_varaible(self._session, draft_vars) + + @staticmethod + def _should_variable_be_editable(node_id: str, name: str) -> bool: + if node_id in (CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID): + return False + if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): + return False + return True + + @staticmethod + def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool: + if node_type in NodeType.IF_ELSE: + return False + if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): + return False + return True + + def _should_variable_be_saved(self, name: str) -> bool: + exclude_var_names = self._EXCLUDE_VARIABLE_NAMES_MAPPING.get(self._node_type) + if exclude_var_names is None: + return True + return name not in exclude_var_names diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index aac6cb27b1..2aee5d8cee 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,6 +1,7 @@ import json import time -from collections.abc import Callable, Generator, Sequence +import uuid +from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional from uuid import uuid4 @@ -8,12 +9,17 @@ from uuid import uuid4 from 
sqlalchemy import select from sqlalchemy.orm import Session +from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -22,9 +28,11 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.nodes.start.entities import StartNodeData from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from factories.file_factory import build_from_mapping, build_from_mappings from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -34,10 +42,15 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError +from .workflow_draft_variable_service import ( + DraftVariableSaver, + DraftVarLoader, + WorkflowDraftVariableService, +) class WorkflowService: @@ -45,6 +58,33 @@ class WorkflowService: Workflow Service """ + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: + # TODO(QuantumGhost): This query is not fully covered by index. 
+ criteria = ( + WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, + WorkflowNodeExecutionModel.app_id == app_model.id, + WorkflowNodeExecutionModel.workflow_id == workflow.id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + node_exec = ( + db.session.query(WorkflowNodeExecutionModel) + .filter(*criteria) + .order_by(WorkflowNodeExecutionModel.created_at.desc()) + .first() + ) + return node_exec + + def is_workflow_exist(self, app_model: App) -> bool: + return ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) + .count() + ) > 0 + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ Get draft workflow @@ -61,6 +101,23 @@ class WorkflowService: # return draft workflow return workflow + def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + # fetch published workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id, + ) + .first() + ) + if not workflow: + return None + if workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}") + return workflow + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ Get published workflow @@ -199,7 +256,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=str(datetime.now(UTC).replace(tzinfo=None)), + version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, created_by=account.id, environment_variables=draft_workflow.environment_variables, @@ -252,26 +309,85 @@ class WorkflowService: return default_config def run_draft_workflow_node( - self, app_model: App, node_id: str, user_inputs: dict, account: Account + self, + app_model: App, + draft_workflow: Workflow, + node_id: str, + user_inputs: Mapping[str, Any], + account: Account, + query: str = "", + files: Sequence[File] | None = None, ) -> WorkflowNodeExecutionModel: """ Run draft workflow node """ - # fetch draft workflow by app_model - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") + files = files or [] + + with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = Workflow.get_node_type_from_node_config(node_config) + node_data = node_config.get("data", {}) + if node_type == NodeType.START: + with Session(bind=db.engine) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + conversation_id = draft_var_srv.get_or_create_conversation( + account_id=account.id, + app=app_model, + workflow=draft_workflow, + ) + start_data = StartNodeData.model_validate(node_data) + user_inputs = _rebuild_file_for_user_inputs_in_start_node( + tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs + ) + # init variable pool + variable_pool = _setup_variable_pool( + query=query, + files=files or [], + user_id=account.id, + user_inputs=user_inputs, + workflow=draft_workflow, + # NOTE(QuantumGhost): We rely on `DraftVarLoader` to 
load conversation variables. + conversation_variables=[], + node_type=node_type, + conversation_id=conversation_id, + ) + + else: + variable_pool = VariablePool( + system_variables={}, + user_inputs=user_inputs, + environment_variables=draft_workflow.environment_variables, + conversation_variables=[], + ) + + variable_loader = DraftVarLoader( + engine=db.engine, + app_id=app_model.id, + tenant_id=app_model.tenant_id, + ) + + eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + if eclosing_node_type_and_id: + _, enclosing_node_id = eclosing_node_type_and_id + else: + enclosing_node_id = None + + run = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + variable_pool=variable_pool, + variable_loader=variable_loader, + ) # run draft workflow node start_at = time.perf_counter() - node_execution = self._handle_node_run_result( - invoke_node_fn=lambda: WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - ), + invoke_node_fn=lambda: run, start_at=start_at, node_id=node_id, ) @@ -291,6 +407,18 @@ class WorkflowService: # Convert node_execution to WorkflowNodeExecution after save workflow_node_execution = repository.to_db_model(node_execution) + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=app_model.id, + node_id=workflow_node_execution.node_id, + node_type=NodeType(workflow_node_execution.node_type), + invoke_from=InvokeFrom.DEBUGGER, + enclosing_node_id=enclosing_node_id, + node_execution_id=node_execution.id, + ) + draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) + session.commit() return workflow_node_execution def run_free_workflow_node( @@ -331,7 +459,7 @@ class WorkflowService: node_run_result = event.run_result # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) break if not node_run_result: @@ -393,7 +521,7 @@ class WorkflowService: if node_run_result.process_data else None ) - outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + outputs = node_run_result.outputs node_execution.inputs = inputs node_execution.process_data = process_data @@ -530,3 +658,83 @@ class WorkflowService: session.delete(workflow) return True + + +def _setup_variable_pool( + query: str, + files: Sequence[File], + user_id: str, + user_inputs: Mapping[str, Any], + workflow: Workflow, + node_type: NodeType, + conversation_id: str, + conversation_variables: list[Variable], +): + # Only inject system variables for START node type. + if node_type == NodeType.START: + # Create a variable pool. + system_inputs: dict[SystemVariableKey, Any] = { + # From inputs: + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + # From workflow model + SystemVariableKey.APP_ID: workflow.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + # Randomly generated. 
+ SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), + } + + # Only add chatflow-specific variables for non-workflow types + if workflow.type != WorkflowType.WORKFLOW.value: + system_inputs.update( + { + SystemVariableKey.QUERY: query, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.DIALOGUE_COUNT: 0, + } + ) + else: + system_inputs = {} + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + return variable_pool + + +def _rebuild_file_for_user_inputs_in_start_node( + tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any] +) -> Mapping[str, Any]: + inputs_copy = dict(user_inputs) + + for variable in start_node_data.variables: + if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST): + continue + if variable.variable not in user_inputs: + continue + value = user_inputs[variable.variable] + file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type) + inputs_copy[variable.variable] = file + return inputs_copy + + +def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]: + if variable_entity_type == VariableEntityType.FILE: + if not isinstance(value, dict): + raise ValueError(f"expected dict for file object, got {type(value)}") + return build_from_mapping(mapping=value, tenant_id=tenant_id) + elif variable_entity_type == VariableEntityType.FILE_LIST: + if not isinstance(value, list): + raise ValueError(f"expected list for file list object, got {type(value)}") + if len(value) == 0: + return [] + if not isinstance(value[0], dict): + raise ValueError(f"expected dict for first element in the file list, got {type(value)}") + return build_from_mappings(mappings=value, tenant_id=tenant_id) + else: + raise Exception("unreachable") diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 9e40a8494d..4046096c27 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -1,107 +1,217 @@ -# OpenAI API Key -OPENAI_API_KEY= +FLASK_APP=app.py +FLASK_DEBUG=0 +SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI' -# Azure OpenAI API Base Endpoint & API Key -AZURE_OPENAI_API_BASE= -AZURE_OPENAI_API_KEY= +CONSOLE_API_URL=http://127.0.0.1:5001 +CONSOLE_WEB_URL=http://127.0.0.1:3000 -# Anthropic API Key -ANTHROPIC_API_KEY= +# Service API base URL +SERVICE_API_URL=http://127.0.0.1:5001 -# Replicate API Key -REPLICATE_API_KEY= +# Web APP base URL +APP_WEB_URL=http://127.0.0.1:3000 -# Hugging Face API Key -HUGGINGFACE_API_KEY= -HUGGINGFACE_TEXT_GEN_ENDPOINT_URL= -HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL= -HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL= +# Files URL +FILES_URL=http://127.0.0.1:5001 -# Minimax Credentials -MINIMAX_API_KEY= -MINIMAX_GROUP_ID= +# The time in seconds after the signature is rejected +FILES_ACCESS_TIMEOUT=300 -# Spark Credentials -SPARK_APP_ID= -SPARK_API_KEY= -SPARK_API_SECRET= +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 -# Tongyi Credentials -TONGYI_DASHSCOPE_API_KEY= +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 -# Wenxin Credentials -WENXIN_API_KEY= -WENXIN_SECRET_KEY= +# celery configuration +CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 -# ZhipuAI Credentials 
-ZHIPUAI_API_KEY= +# redis configuration +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_USERNAME= +REDIS_PASSWORD=difyai123456 +REDIS_USE_SSL=false +REDIS_DB=0 -# Baichuan Credentials -BAICHUAN_API_KEY= -BAICHUAN_SECRET_KEY= +# PostgreSQL database configuration +DB_USERNAME=postgres +DB_PASSWORD=difyai123456 +DB_HOST=localhost +DB_PORT=5432 +DB_DATABASE=dify -# ChatGLM Credentials -CHATGLM_API_BASE= +# Storage configuration +# use for store upload files, private keys... +# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase +STORAGE_TYPE=opendal -# Xinference Credentials -XINFERENCE_SERVER_URL= -XINFERENCE_GENERATION_MODEL_UID= -XINFERENCE_CHAT_MODEL_UID= -XINFERENCE_EMBEDDINGS_MODEL_UID= -XINFERENCE_RERANK_MODEL_UID= +# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal +OPENDAL_SCHEME=fs +OPENDAL_FS_ROOT=storage -# OpenLLM Credentials -OPENLLM_SERVER_URL= +# CORS configuration +WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* +CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# LocalAI Credentials -LOCALAI_SERVER_URL= +# Vector database configuration +# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase +VECTOR_STORE=weaviate +# Weaviate configuration +WEAVIATE_ENDPOINT=http://localhost:8080 +WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih +WEAVIATE_GRPC_ENABLED=false +WEAVIATE_BATCH_SIZE=100 -# Cohere Credentials -COHERE_API_KEY= -# Jina Credentials -JINA_API_KEY= +# Upload configuration +UPLOAD_FILE_SIZE_LIMIT=15 +UPLOAD_FILE_BATCH_LIMIT=5 +UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 -# Ollama Credentials -OLLAMA_BASE_URL= +# Model configuration +MULTIMODAL_SEND_FORMAT=base64 +PROMPT_GENERATION_MAX_TOKENS=4096 +CODE_GENERATION_MAX_TOKENS=1024 -# Together API Key -TOGETHER_API_KEY= +# Mail configuration, support: resend, smtp +MAIL_TYPE= +MAIL_DEFAULT_SEND_FROM=no-reply +RESEND_API_KEY= +RESEND_API_URL=https://api.resend.com +# smtp configuration +SMTP_SERVER=smtp.example.com +SMTP_PORT=465 +SMTP_USERNAME=123 +SMTP_PASSWORD=abc +SMTP_USE_TLS=true +SMTP_OPPORTUNISTIC_TLS=false -# Mock Switch -MOCK_SWITCH=false +# Sentry configuration +SENTRY_DSN= + +# DEBUG +DEBUG=false +SQLALCHEMY_ECHO=false + +# Notion import configuration, support public and internal +NOTION_INTEGRATION_TYPE=public +NOTION_CLIENT_SECRET=you-client-secret +NOTION_CLIENT_ID=you-client-id +NOTION_INTERNAL_SECRET=you-internal-secret + +ETL_TYPE=dify +UNSTRUCTURED_API_URL= +UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=false + +#ssrf +SSRF_PROXY_HTTP_URL= +SSRF_PROXY_HTTPS_URL= +SSRF_DEFAULT_MAX_RETRIES=3 +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 + +BATCH_UPLOAD_LIMIT=10 +KEYWORD_DATA_SOURCE_TYPE=database + +# Workflow file upload limit +WORKFLOW_FILE_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION -CODE_EXECUTION_ENDPOINT= -CODE_EXECUTION_API_KEY= +CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 +CODE_EXECUTION_API_KEY=dify-sandbox +CODE_MAX_NUMBER=9223372036854775807 +CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_STRING_LENGTH=80000 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 +CODE_MAX_STRING_ARRAY_LENGTH=30 +CODE_MAX_OBJECT_ARRAY_LENGTH=30 +CODE_MAX_NUMBER_ARRAY_LENGTH=1000 -# Volcengine MaaS Credentials -VOLC_API_KEY= -VOLC_SECRET_KEY= -VOLC_MODEL_ENDPOINT_ID= 
-VOLC_EMBEDDING_ENDPOINT_ID= +# API Tool configuration +API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 +API_TOOL_DEFAULT_READ_TIMEOUT=60 -# 360 AI Credentials -ZHINAO_API_KEY= +# HTTP Node configuration +HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 +HTTP_REQUEST_MAX_READ_TIMEOUT=600 +HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 + +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + +# Log file path +LOG_FILE= +# Log file max size, the unit is MB +LOG_FILE_MAX_SIZE=20 +# Log file max backup count +LOG_FILE_BACKUP_COUNT=5 +# Log dateformat +LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S +# Log Timezone +LOG_TZ=UTC +# Log format +LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s + +# Indexing configuration +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 + +# Workflow runtime configuration +WORKFLOW_MAX_EXECUTION_STEPS=500 +WORKFLOW_MAX_EXECUTION_TIME=1200 +WORKFLOW_CALL_MAX_DEPTH=5 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 +MAX_VARIABLE_SIZE=204800 + +# App configuration +APP_MAX_EXECUTION_TIME=1200 +APP_MAX_ACTIVE_REQUESTS=0 + +# Celery beat configuration +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= # Plugin configuration -PLUGIN_DAEMON_KEY= -PLUGIN_DAEMON_URL= +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DAEMON_URL=http://127.0.0.1:5002 +PLUGIN_REMOTE_INSTALL_PORT=5003 +PLUGIN_REMOTE_INSTALL_HOST=localhost +PLUGIN_MAX_PACKAGE_SIZE=15728640 +INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration -MARKETPLACE_API_URL= -# VESSL AI Credentials -VESSL_AI_MODEL_NAME= -VESSL_AI_API_KEY= -VESSL_AI_ENDPOINT_URL= +MARKETPLACE_ENABLED=true +MARKETPLACE_API_URL=https://marketplace.dify.ai -# GPUStack Credentials -GPUSTACK_SERVER_URL= -GPUSTACK_API_KEY= +# Endpoint configuration +ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} -# Gitee AI Credentials -GITEE_AI_API_KEY= +# Reset password token expiry minutes +RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 -# xAI Credentials -XAI_API_KEY= -XAI_API_BASE= +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 +# Lockout duration in seconds +LOGIN_LOCKOUT_DURATION=86400 + +HTTP_PROXY='http://127.0.0.1:1092' +HTTPS_PROXY='http://127.0.0.1:1092' +NO_PROXY='localhost,127.0.0.1' +LOG_LEVEL=INFO diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 6e3ab4b74b..d9f90f992e 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,19 +1,91 @@ -import os +import pathlib +import random +import secrets +from collections.abc import Generator -# Getting the absolute path of the current file's directory -ABS_PATH = os.path.dirname(os.path.abspath(__file__)) +import pytest +from flask import Flask +from flask.testing import FlaskClient +from sqlalchemy.orm import Session -# Getting the absolute path of the project's root directory -PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) +from app_factory import create_app +from models import Account, DifySetup, Tenant, TenantAccountJoin, db +from services.account_service import AccountService, RegisterService # Loading the .env file if it exists def 
_load_env() -> None: - dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env") - if os.path.exists(dotenv_path): + current_file_path = pathlib.Path(__file__).absolute() + # Items later in the list have higher precedence. + files_to_load = [".env", "vdb.env"] + + env_file_paths = [current_file_path.parent / i for i in files_to_load] + for path in env_file_paths: + if not path.exists(): + continue + from dotenv import load_dotenv - load_dotenv(dotenv_path) + # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env` + load_dotenv(str(path), override=True) _load_env() + +_CACHED_APP = create_app() + + +@pytest.fixture +def flask_app() -> Flask: + return _CACHED_APP + + +@pytest.fixture(scope="session") +def setup_account(request) -> Generator[Account, None, None]: + """`dify_setup` completes the setup process for the Dify application. + + It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database. + + Most tests in the `controllers` package may require dify has been successfully setup. + """ + with _CACHED_APP.test_request_context(): + rand_suffix = random.randint(int(1e6), int(1e7)) # noqa + name = f"test-user-{rand_suffix}" + email = f"{name}@example.com" + RegisterService.setup( + email=email, + name=name, + password=secrets.token_hex(16), + ip_address="localhost", + ) + + with _CACHED_APP.test_request_context(): + with Session(bind=db.engine, expire_on_commit=False) as session: + account = session.query(Account).filter_by(email=email).one() + + yield account + + with _CACHED_APP.test_request_context(): + db.session.query(DifySetup).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Account).delete() + db.session.query(Tenant).delete() + db.session.commit() + + +@pytest.fixture +def flask_req_ctx(): + with _CACHED_APP.test_request_context(): + yield + + +@pytest.fixture +def auth_header(setup_account) -> dict[str, str]: + token = AccountService.get_account_jwt_token(setup_account) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture +def test_client() -> Generator[FlaskClient, None, None]: + with _CACHED_APP.test_client() as client: + yield client diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py deleted file mode 100644 index 32e8c11d19..0000000000 --- a/api/tests/integration_tests/controllers/app_fixture.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest - -from app_factory import create_app -from configs import dify_config - -mock_user = type( - "MockUser", - (object,), - { - "is_authenticated": True, - "id": "123", - "is_editor": True, - "is_dataset_editor": True, - "status": "active", - "get_id": "123", - "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", - }, -) - - -@pytest.fixture -def app(): - app = create_app() - dify_config.LOGIN_DISABLED = True - return app diff --git a/api/tests/integration_tests/controllers/console/__init__.py b/api/tests/integration_tests/controllers/console/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/controllers/console/app/__init__.py b/api/tests/integration_tests/controllers/console/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..038f37af5f --- /dev/null +++ 
b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,47 @@ +import uuid +from unittest import mock + +from controllers.console.app import workflow_draft_variable as draft_variable_api +from controllers.console.app import wraps +from factories.variable_factory import build_segment +from models import App, AppMode +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService + + +def _get_mock_srv_class() -> type[WorkflowDraftVariableService]: + return mock.create_autospec(WorkflowDraftVariableService) + + +class TestWorkflowDraftNodeVariableListApi: + def test_get(self, test_client, auth_header, monkeypatch): + srv_class = _get_mock_srv_class() + mock_app_model: App = App() + mock_app_model.id = str(uuid.uuid4()) + test_node_id = "test_node_id" + mock_app_model.mode = AppMode.ADVANCED_CHAT + mock_load_app_model = mock.Mock(return_value=mock_app_model) + + monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + var1 = WorkflowDraftVariable.new_node_variable( + app_id="test_app_1", + node_id="test_node_1", + name="str_var", + value=build_segment("str_value"), + node_execution_id=str(uuid.uuid4()), + ) + srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True) + srv_class.return_value = srv_instance + srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1]) + + response = test_client.get( + f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables", + headers=auth_header, + ) + assert response.status_code == 200 + response_dict = response.json + assert isinstance(response_dict, dict) + assert "items" in response_dict + assert len(response_dict["items"]) == 1 diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py deleted file mode 100644 index 276ad3a7ed..0000000000 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ /dev/null @@ -1,9 +0,0 @@ -from unittest.mock import patch - -from app_fixture import mock_user # type: ignore - - -def test_post_requires_login(app): - with app.test_client() as client, patch("flask_login.utils._get_user", mock_user): - response = client.get("/console/api/data-source/integrates") - assert response.status_code == 200 diff --git a/api/tests/integration_tests/factories/__init__.py b/api/tests/integration_tests/factories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py new file mode 100644 index 0000000000..fecb3f6d95 --- /dev/null +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -0,0 +1,371 @@ +import unittest +from datetime import UTC, datetime +from typing import Optional +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.file import File, FileTransferMethod, FileType +from extensions.ext_database import db +from factories.file_factory import StorageKeyLoader +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestStorageKeyLoader(unittest.TestCase): + """ + Integration tests for StorageKeyLoader class. 
+ + Tests the batched loading of storage keys from the database for files + with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. + """ + + def setUp(self): + """Set up test data before each test method.""" + self.session = db.session() + self.tenant_id = str(uuid4()) + self.user_id = str(uuid4()) + self.conversation_id = str(uuid4()) + + # Create test data that will be cleaned up after each test + self.test_upload_files = [] + self.test_tool_files = [] + + # Create StorageKeyLoader instance + self.loader = StorageKeyLoader(self.session, self.tenant_id) + + def tearDown(self): + """Clean up test data after each test method.""" + self.session.rollback() + + def _create_upload_file( + self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> UploadFile: + """Helper method to create an UploadFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if storage_key is None: + storage_key = f"test_storage_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=storage_key, + name="test_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file.id = file_id + + self.session.add(upload_file) + self.session.flush() + self.test_upload_files.append(upload_file) + + return upload_file + + def _create_tool_file( + self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> ToolFile: + """Helper method to create a ToolFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if file_key is None: + file_key = f"test_file_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + tool_file = ToolFile() + tool_file.id = file_id + tool_file.user_id = self.user_id + tool_file.tenant_id = tenant_id + tool_file.conversation_id = self.conversation_id + tool_file.file_key = file_key + tool_file.mimetype = "text/plain" + tool_file.original_url = "http://example.com/file.txt" + tool_file.name = "test_tool_file.txt" + tool_file.size = 2048 + + self.session.add(tool_file) + self.session.flush() + self.test_tool_files.append(tool_file) + + return tool_file + + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: + """Helper method to create a File object for testing.""" + if tenant_id is None: + tenant_id = self.tenant_id + + # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods + file_related_id = None + remote_url = None + + if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): + file_related_id = related_id + elif transfer_method == FileTransferMethod.REMOTE_URL: + remote_url = "https://example.com/test_file.txt" + file_related_id = related_id + + return File( + id=str(uuid4()), # Generate new UUID for File.id + tenant_id=tenant_id, + type=FileType.DOCUMENT, + transfer_method=transfer_method, + related_id=file_related_id, + remote_url=remote_url, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="initial_key", + ) + + def test_load_storage_keys_local_file(self): + """Test loading storage keys for LOCAL_FILE transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, 
transfer_method=FileTransferMethod.LOCAL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_remote_url(self): + """Test loading storage keys for REMOTE_URL transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_tool_file(self): + """Test loading storage keys for TOOL_FILE transfer method.""" + # Create test data + tool_file = self._create_tool_file() + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == tool_file.file_key + + def test_load_storage_keys_mixed_methods(self): + """Test batch loading with mixed transfer methods.""" + # Create test data for different transfer methods + upload_file1 = self._create_upload_file() + upload_file2 = self._create_upload_file() + tool_file = self._create_tool_file() + + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + files = [file1, file2, file3] + + # Load storage keys + self.loader.load_storage_keys(files) + + # Verify all storage keys were loaded correctly + assert file1._storage_key == upload_file1.key + assert file2._storage_key == upload_file2.key + assert file3._storage_key == tool_file.file_key + + def test_load_storage_keys_empty_list(self): + """Test with empty file list.""" + # Should not raise any exceptions + self.loader.load_storage_keys([]) + + def test_load_storage_keys_tenant_mismatch(self): + """Test tenant_id validation.""" + # Create file with different tenant_id + upload_file = self._create_upload_file() + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) + + # Should raise ValueError for tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_missing_file_id(self): + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None + + # Should raise ValueError for None file related_id + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert str(context.value) == "file id should not be None." 
+ + def test_load_storage_keys_nonexistent_upload_file_records(self): + """Test with missing UploadFile database records.""" + # Create file with non-existent upload file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_nonexistent_tool_file_records(self): + """Test with missing ToolFile database records.""" + # Create file with non-existent tool file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_invalid_uuid(self): + """Test with invalid UUID format.""" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" + + # Should raise ValueError for invalid UUID + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_batch_efficiency(self): + """Test batched operations use efficient queries.""" + # Create multiple files of different types + upload_files = [self._create_upload_file() for _ in range(3)] + tool_files = [self._create_tool_file() for _ in range(2)] + + files = [] + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) + + # Mock the session to count queries + with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: + self.loader.load_storage_keys(files) + + # Should make exactly 2 queries (one for upload_files, one for tool_files) + assert mock_scalars.call_count == 2 + + # Verify all storage keys were loaded correctly + for i, file in enumerate(files[:3]): + assert file._storage_key == upload_files[i].key + for i, file in enumerate(files[3:]): + assert file._storage_key == tool_files[i].file_key + + def test_load_storage_keys_tenant_isolation(self): + """Test that tenant isolation works correctly.""" + # Create files for different tenants + other_tenant_id = str(uuid4()) + + # Create upload file for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create upload file for other tenant (but don't add to cleanup list) + upload_file_other = UploadFile( + tenant_id=other_tenant_id, + storage_type="local", + key="other_tenant_key", + name="other_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file_other.id = str(uuid4()) + self.session.add(upload_file_other) + self.session.flush() + + # Create file for other tenant but try to load with current tenant's loader + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError due to tenant mismatch + with 
pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + # Current tenant's file should still work + self.loader.load_storage_keys([file_current]) + assert file_current._storage_key == upload_file_current.key + + def test_load_storage_keys_mixed_tenant_batch(self): + """Test batch with mixed tenant files (should fail on first mismatch).""" + # Create files for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create file for different tenant + other_tenant_id = str(uuid4()) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError on tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_current, file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_duplicate_file_ids(self): + """Test handling of duplicate file IDs in the batch.""" + # Create upload file + upload_file = self._create_upload_file() + + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should handle duplicates gracefully + self.loader.load_storage_keys([file1, file2]) + + # Both files should have the same storage key + assert file1._storage_key == upload_file.key + assert file2._storage_key == upload_file.key + + def test_load_storage_keys_session_isolation(self): + """Test that the loader uses the provided session correctly.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Create loader with different session (same underlying connection) + + with Session(bind=db.engine) as other_session: + other_loader = StorageKeyLoader(other_session, self.tenant_id) + with pytest.raises(ValueError): + other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/services/__init__.py b/api/tests/integration_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..30cd2e60cb --- /dev/null +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -0,0 +1,501 @@ +import json +import unittest +import uuid + +import pytest +from sqlalchemy.orm import Session + +from core.variables.variables import StringVariable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.nodes import NodeType +from factories.variable_factory import build_segment +from libs import datetime_utils +from models import db +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel +from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestWorkflowDraftVariableService(unittest.TestCase): + _test_app_id: str + _session: Session + 
_node1_id = "test_node_1" + _node2_id = "test_node_2" + _node_exec_id = str(uuid.uuid4()) + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._session: Session = db.session() + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node2_vars = [ + WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node2_id, + name="int_var", + value=build_segment(1), + visible=False, + node_execution_id=self._node_exec_id, + ), + WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node2_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ), + ] + node1_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ) + _variables = list(node2_vars) + _variables.extend( + [ + node1_var, + sys_var, + conv_var, + ] + ) + + db.session.add_all(_variables) + db.session.flush() + self._variable_ids = [v.id for v in _variables] + self._node1_str_var_id = node1_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + self._node2_var_ids = [v.id for v in node2_vars] + + def _get_test_srv(self) -> WorkflowDraftVariableService: + return WorkflowDraftVariableService(session=self._session) + + def tearDown(self): + self._session.rollback() + + def test_list_variables(self): + srv = self._get_test_srv() + var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2) + assert var_list.total == 5 + assert len(var_list.variables) == 2 + page1_var_ids = {v.id for v in var_list.variables} + assert page1_var_ids.issubset(self._variable_ids) + + var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2) + assert var_list_2.total is None + assert len(var_list_2.variables) == 2 + page2_var_ids = {v.id for v in var_list_2.variables} + assert page2_var_ids.isdisjoint(page1_var_ids) + assert page2_var_ids.issubset(self._variable_ids) + + def test_get_node_variable(self): + srv = self._get_test_srv() + node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var") + assert node_var is not None + assert node_var.id == self._node1_str_var_id + assert node_var.name == "str_var" + assert node_var.get_value() == build_segment("str_value") + + def test_get_system_variable(self): + srv = self._get_test_srv() + sys_var = srv.get_system_variable(self._test_app_id, "sys_var") + assert sys_var is not None + assert sys_var.id == self._sys_var_id + assert sys_var.name == "sys_var" + assert sys_var.get_value() == build_segment("sys_value") + + def test_get_conversation_variable(self): + srv = self._get_test_srv() + conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var") + assert conv_var is not None + assert conv_var.id == self._conv_var_id + assert conv_var.name == "conv_var" + assert conv_var.get_value() == build_segment("conv_value") + + def test_delete_node_variables(self): + srv = self._get_test_srv() + srv.delete_node_variables(self._test_app_id, self._node2_id) + node2_var_count = ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == self._test_app_id, + 
WorkflowDraftVariable.node_id == self._node2_id, + ) + .count() + ) + assert node2_var_count == 0 + + def test_delete_variable(self): + srv = self._get_test_srv() + node_1_var = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() + ) + srv.delete_variable(node_1_var) + exists = bool( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + ) + assert exists is False + + def test__list_node_variables(self): + srv = self._get_test_srv() + node_vars = srv._list_node_variables(self._test_app_id, self._node2_id) + assert len(node_vars.variables) == 2 + assert {v.id for v in node_vars.variables} == set(self._node2_var_ids) + + def test_get_draft_variables_by_selectors(self): + srv = self._get_test_srv() + selectors = [ + [self._node1_id, "str_var"], + [self._node2_id, "str_var"], + [self._node2_id, "int_var"], + ] + variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors) + assert len(variables) == 3 + assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids) + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestDraftVariableLoader(unittest.TestCase): + _test_app_id: str + _test_tenant_id: str + + _node1_id = "test_loader_node_1" + _node_exec_id = str(uuid.uuid4()) + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + node_execution_id=self._node_exec_id, + ) + _variables = [ + node_var, + sys_var, + conv_var, + ] + + with Session(bind=db.engine, expire_on_commit=False) as session: + session.add_all(_variables) + session.flush() + session.commit() + self._variable_ids = [v.id for v in _variables] + self._node_var_id = node_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + + def tearDown(self): + with Session(bind=db.engine, expire_on_commit=False) as session: + session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete( + synchronize_session=False + ) + session.commit() + + def test_variable_loader_with_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + variables = var_loader.load_variables([]) + assert len(variables) == 0 + + def test_variable_loader_with_non_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + variables = var_loader.load_variables( + [ + [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + [self._node1_id, "str_var"], + ] + ) + assert len(variables) == 3 + conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID) + assert conv_var.id == self._conv_var_id + sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.id == self._sys_var_id + node1_var = next(v for v in variables if v.selector[0] == self._node1_id) + assert node1_var.id == 
self._node_var_id + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): + """Integration tests for reset_variable functionality using real database""" + + _test_app_id: str + _test_tenant_id: str + _test_workflow_id: str + _session: Session + _node_id = "test_reset_node" + _node_exec_id: str + _workflow_node_exec_id: str + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) + self._test_workflow_id = str(uuid.uuid4()) + self._node_exec_id = str(uuid.uuid4()) + self._workflow_node_exec_id = str(uuid.uuid4()) + self._session: Session = db.session() + + # Create a workflow node execution record with outputs + # Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable + self._workflow_node_execution = WorkflowNodeExecutionModel( + id=self._node_exec_id, # This should match the node_execution_id in the variable + tenant_id=self._test_tenant_id, + app_id=self._test_app_id, + workflow_id=self._test_workflow_id, + triggered_from="workflow-run", + workflow_run_id=str(uuid.uuid4()), + index=1, + node_execution_id=self._node_exec_id, + node_id=self._node_id, + node_type=NodeType.LLM.value, + title="Test Node", + inputs='{"input": "test input"}', + process_data='{"test_var": "process_value", "other_var": "other_process"}', + outputs='{"test_var": "output_value", "other_var": "other_output"}', + status="succeeded", + elapsed_time=1.5, + created_by_role="account", + created_by=str(uuid.uuid4()), + ) + + # Create conversation variables for the workflow + self._conv_variables = [ + StringVariable( + id=str(uuid.uuid4()), + name="conv_var_1", + description="Test conversation variable 1", + value="default_value_1", + ), + StringVariable( + id=str(uuid.uuid4()), + name="conv_var_2", + description="Test conversation variable 2", + value="default_value_2", + ), + ] + + # Create test variables + self._node_var_with_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="test_var", + value=build_segment("old_value"), + node_execution_id=self._node_exec_id, + ) + self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now() + + self._node_var_without_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="no_exec_var", + value=build_segment("some_value"), + node_execution_id="temp_exec_id", + ) + # Manually set node_execution_id to None after creation + self._node_var_without_exec.node_execution_id = None + + self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node_id, + name="missing_exec_var", + value=build_segment("some_value"), + node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database + ) + + self._conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var_1", + value=build_segment("old_conv_value"), + ) + self._conv_var.last_edited_at = datetime_utils.naive_utc_now() + + # Add all to database + db.session.add_all( + [ + self._workflow_node_execution, + self._node_var_with_exec, + self._node_var_without_exec, + self._node_var_missing_exec, + self._conv_var, + ] + ) + db.session.flush() + + # Store IDs for assertions + self._node_var_with_exec_id = self._node_var_with_exec.id + self._node_var_without_exec_id = self._node_var_without_exec.id + self._node_var_missing_exec_id = 
self._node_var_missing_exec.id + self._conv_var_id = self._conv_var.id + + def _get_test_srv(self) -> WorkflowDraftVariableService: + return WorkflowDraftVariableService(session=self._session) + + def _create_mock_workflow(self) -> Workflow: + """Create a real workflow with conversation variables and graph""" + conversation_vars = self._conv_variables + + # Create a simple graph with the test node + graph = { + "nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}], + "edges": [], + } + + workflow = Workflow.new( + tenant_id=str(uuid.uuid4()), + app_id=self._test_app_id, + type="workflow", + version="1.0", + graph=json.dumps(graph), + features="{}", + created_by=str(uuid.uuid4()), + environment_variables=[], + conversation_variables=conversation_vars, + ) + return workflow + + def tearDown(self): + self._session.rollback() + + def test_reset_node_variable_with_valid_execution_record(self): + """Test resetting a node variable with valid execution record - should restore from execution""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_with_exec_id) + assert variable is not None + assert variable.get_value().value == "old_value" + assert variable.last_edited_at is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return the updated variable + assert result is not None + assert result.id == self._node_var_with_exec_id + assert result.node_execution_id == self._workflow_node_execution.id + assert result.last_edited_at is None # Should be reset to None + + # The returned variable should have the updated value from execution record + assert result.get_value().value == "output_value" + + # Verify the variable was updated in database + updated_variable = srv.get_variable(self._node_var_with_exec_id) + assert updated_variable is not None + # The value should be updated from the execution record's outputs + assert updated_variable.get_value().value == "output_value" + assert updated_variable.last_edited_at is None + assert updated_variable.node_execution_id == self._workflow_node_execution.id + + def test_reset_node_variable_with_no_execution_id(self): + """Test resetting a node variable with no execution ID - should delete variable""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_without_exec_id) + assert variable is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return None (variable deleted) + assert result is None + + # Verify the variable was deleted + deleted_variable = srv.get_variable(self._node_var_without_exec_id) + assert deleted_variable is None + + def test_reset_node_variable_with_missing_execution_record(self): + """Test resetting a node variable when execution record doesn't exist""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._node_var_missing_exec_id) + assert variable is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return None (variable deleted) + assert result is None + + # Verify the variable was deleted + deleted_variable = srv.get_variable(self._node_var_missing_exec_id) + assert deleted_variable is None + + def test_reset_conversation_variable(self): + 
"""Test resetting a conversation variable""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Get the variable before reset + variable = srv.get_variable(self._conv_var_id) + assert variable is not None + assert variable.get_value().value == "old_conv_value" + assert variable.last_edited_at is not None + + # Reset the variable + result = srv.reset_variable(mock_workflow, variable) + + # Should return the updated variable + assert result is not None + assert result.id == self._conv_var_id + assert result.last_edited_at is None # Should be reset to None + + # Verify the variable was updated with default value from workflow + updated_variable = srv.get_variable(self._conv_var_id) + assert updated_variable is not None + # The value should be updated from the workflow's conversation variable default + assert updated_variable.get_value().value == "default_value_1" + assert updated_variable.last_edited_at is None + + def test_reset_system_variable_raises_error(self): + """Test that resetting a system variable raises an error""" + srv = self._get_test_srv() + mock_workflow = self._create_mock_workflow() + + # Create a system variable + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + node_execution_id=self._node_exec_id, + ) + db.session.add(sys_var) + db.session.flush() + + # Attempt to reset the system variable + with pytest.raises(VariableResetError) as exc_info: + srv.reset_variable(mock_workflow, sys_var) + + assert "cannot reset system variable" in str(exc_info.value) + assert sys_var.id in str(exc_info.value) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 6aa48b1cbb..a3b2fdc376 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -8,8 +8,6 @@ from unittest.mock import MagicMock, patch import pytest -from app_factory import create_app -from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage @@ -30,21 +28,6 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.fixture(scope="session") -def app(): - # Set up storage configuration - os.environ["STORAGE_TYPE"] = "opendal" - os.environ["OPENDAL_SCHEME"] = "fs" - os.environ["OPENDAL_FS_ROOT"] = "storage" - - # Ensure storage directory exists - os.makedirs("storage", exist_ok=True) - - app = create_app() - dify_config.LOGIN_DISABLED = True - return app - - def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -102,197 +85,195 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(app): - with app.app_context(): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", - }, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, +def 
test_execute_llm(flask_req_ctx): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - ) + }, + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), - prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "langgenius/openai/openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): - return mock_model_instance, mock_model_config + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Also mock ModelManager.get_model_instance to avoid database calls + def 
mock_get_model_instance(_self, **kwargs): + return mock_model_instance - with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), - ): - # execute node - result = node._run() - assert isinstance(result, Generator) + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(app, setup_code_executor_mock): +def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - with app.app_context(): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", - }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", - }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - ) + }, + ) - # Mock db.session.close() - db.session.close = MagicMock() + # Mock db.session.close() + db.session.close = MagicMock() - # Create a proper LLM result with real entities - mock_usage = LLMUsage( - prompt_tokens=30, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), - 
prompt_price=Decimal("0.00003"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), - completion_price=Decimal("0.00004"), - total_tokens=50, - total_price=Decimal("0.00007"), - currency="USD", - latency=0.5, - ) + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") - mock_llm_result = LLMResult( - model="gpt-3.5-turbo", - prompt_messages=[], - message=mock_message, - usage=mock_usage, - ) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - # Create a simple mock model instance that doesn't call real providers - mock_model_instance = MagicMock() - mock_model_instance.invoke_llm.return_value = mock_llm_result + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create a simple mock model config with required attributes - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" - # Mock the _fetch_model_config method - def mock_fetch_model_config_func(_node_data_model): - return mock_model_instance, mock_model_config + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config - # Also mock ModelManager.get_model_instance to avoid database calls - def mock_get_model_instance(_self, **kwargs): - return mock_model_instance + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance - with ( - patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), - patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), - ): - # execute node - result = node._run() + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" 
in json.dumps(item.run_result.process_data) + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) def test_extract_json(): diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index e09acc4c39..077ffe3408 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -1,4 +1,5 @@ import os +from unittest.mock import MagicMock, patch import pytest from flask import Flask @@ -11,6 +12,24 @@ PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) CACHED_APP = Flask(__name__) +# set global mock for Redis client +redis_mock = MagicMock() +redis_mock.get = MagicMock(return_value=None) +redis_mock.setex = MagicMock() +redis_mock.setnx = MagicMock() +redis_mock.delete = MagicMock() +redis_mock.lock = MagicMock() +redis_mock.exists = MagicMock(return_value=False) +redis_mock.set = MagicMock() +redis_mock.expire = MagicMock() +redis_mock.hgetall = MagicMock(return_value={}) +redis_mock.hdel = MagicMock() +redis_mock.incr = MagicMock(return_value=1) + +# apply the mock to the Redis client in the Flask app +redis_patcher = patch("extensions.ext_redis.redis_client", redis_mock) +redis_patcher.start() + @pytest.fixture def app() -> Flask: @@ -21,3 +40,19 @@ def app() -> Flask: def _provide_app_context(app: Flask): with app.app_context(): yield + + +@pytest.fixture(autouse=True) +def reset_redis_mock(): + """reset the Redis mock before each test""" + redis_mock.reset_mock() + redis_mock.get.return_value = None + redis_mock.setex.return_value = None + redis_mock.setnx.return_value = None + redis_mock.delete.return_value = None + redis_mock.exists.return_value = False + redis_mock.set.return_value = None + redis_mock.expire.return_value = None + redis_mock.hgetall.return_value = {} + redis_mock.hdel.return_value = None + redis_mock.incr.return_value = 1 diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py new file mode 100644 index 0000000000..f26be6702a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -0,0 +1,302 @@ +import datetime +import uuid +from collections import OrderedDict +from typing import Any, NamedTuple + +from flask_restful import marshal + +from controllers.console.app.workflow_draft_variable import ( + _WORKFLOW_DRAFT_VARIABLE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, +) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import WorkflowDraftVariableList + +_TEST_APP_ID = "test_app_id" +_TEST_NODE_EXEC_ID = str(uuid.uuid4()) + + +class TestWorkflowDraftVariableFields: + def test_conversation_variable(self): + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) + ) + + conv_var.id = str(uuid.uuid4()) + conv_var.visible = True + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": 
str(conv_var.id), + "type": conv_var.get_variable_type().value, + "name": "conv_var", + "description": "", + "selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + "value_type": "number", + "edited": False, + "visible": True, + } + ) + + assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = 1 + assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + def test_create_sys_variable(self): + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=_TEST_APP_ID, + name="sys_var", + value=build_segment("a"), + editable=True, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + sys_var.id = str(uuid.uuid4()) + sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + sys_var.visible = True + + expected_without_value = OrderedDict( + { + "id": str(sys_var.id), + "type": sys_var.get_variable_type().value, + "name": "sys_var", + "description": "", + "selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + "value_type": "string", + "edited": True, + "visible": True, + } + ) + assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = "a" + assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + def test_node_variable(self): + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="node_var", + value=build_segment([1, "a"]), + visible=False, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + node_var.id = str(uuid.uuid4()) + node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "node_var", + "description": "", + "selector": ["test_node", "node_var"], + "value_type": "array[any]", + "edited": True, + "visible": False, + } + ) + + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = [1, "a"] + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + + +class TestWorkflowDraftVariableList: + def test_workflow_draft_variable_list(self): + class TestCase(NamedTuple): + name: str + var_list: WorkflowDraftVariableList + expected: dict + + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="test_var", + value=build_segment("a"), + visible=True, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + node_var.id = str(uuid.uuid4()) + node_var_dict = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "test_var", + "description": "", + "selector": ["test_node", "test_var"], + "value_type": "string", + "edited": False, + "visible": True, + } + ) + + cases = [ + TestCase( + name="empty variable list", + var_list=WorkflowDraftVariableList(variables=[]), + expected=OrderedDict( + { + "items": [], + "total": None, + } + ), + ), + TestCase( + name="empty variable list with total", + var_list=WorkflowDraftVariableList(variables=[], total=10), + expected=OrderedDict( + { + "items": [], + "total": 10, + } + ), + ), + TestCase( + name="non-empty variable list", + var_list=WorkflowDraftVariableList(variables=[node_var], 
total=None), + expected=OrderedDict( + { + "items": [node_var_dict], + "total": None, + } + ), + ), + TestCase( + name="non-empty variable list with total", + var_list=WorkflowDraftVariableList(variables=[node_var], total=10), + expected=OrderedDict( + { + "items": [node_var_dict], + "total": 10, + } + ), + ), + ] + + for idx, case in enumerate(cases, 1): + assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, ( + f"Test case {idx} failed, {case.name=}" + ) + + +def test_workflow_node_variables_fields(): + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) + ) + resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "conv_var" + assert item_dict["value"] == 1 + + +def test_workflow_file_variable_with_signed_url(): + """Test that File type variables include signed URLs in API responses.""" + from core.file.enums import FileTransferMethod, FileType + from core.file.models import File + + # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_upload_file_id", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, + ) + + # Create a WorkflowDraftVariable with the File + file_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="file_var", + value=build_segment(test_file), + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + # Marshal the variable using the API fields + resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + + # Verify the response structure + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "file_var" + + # Verify the value is a dict (File.to_dict() result) and contains expected fields + value = item_dict["value"] + assert isinstance(value, dict) + + # Verify the File fields are preserved + assert value["id"] == test_file.id + assert value["filename"] == test_file.filename + assert value["type"] == test_file.type.value + assert value["transfer_method"] == test_file.transfer_method.value + assert value["size"] == test_file.size + + # Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method) + remote_url = value["remote_url"] + assert remote_url is not None + + assert isinstance(remote_url, str) + # For LOCAL_FILE, the URL should contain signature parameters + assert "timestamp=" in remote_url + assert "nonce=" in remote_url + assert "sign=" in remote_url + + +def test_workflow_file_variable_remote_url(): + """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" + from core.file.enums import FileTransferMethod, FileType + from core.file.models import File + + # Create a File object with REMOTE_URL transfer method + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/test.jpg", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, + ) + + # Create a WorkflowDraftVariable with the File + file_var = 
WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="file_var", + value=build_segment(test_file), + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + # Marshal the variable using the API fields + resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + + # Verify the response structure + assert isinstance(resp, dict) + assert len(resp["items"]) == 1 + item_dict = resp["items"][0] + assert item_dict["name"] == "file_var" + + # Verify the value is a dict (File.to_dict() result) and contains expected fields + value = item_dict["value"] + assert isinstance(value, dict) + remote_url = value["remote_url"] + + # For REMOTE_URL, the URL should be the original remote URL + assert remote_url == test_file.remote_url diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py deleted file mode 100644 index e6e289c12a..0000000000 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ /dev/null @@ -1,165 +0,0 @@ -from uuid import uuid4 - -import pytest - -from core.variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, - ObjectSegment, - SecretVariable, - StringVariable, -) -from core.variables.exc import VariableError -from core.variables.segments import ArrayAnySegment -from factories import variable_factory - - -def test_string_variable(): - test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, StringVariable) - - -def test_integer_variable(): - test_data = {"value_type": "number", "name": "test_int", "value": 42} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, IntegerVariable) - - -def test_float_variable(): - test_data = {"value_type": "number", "name": "test_float", "value": 3.14} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, FloatVariable) - - -def test_secret_variable(): - test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} - result = variable_factory.build_conversation_variable_from_mapping(test_data) - assert isinstance(result, SecretVariable) - - -def test_invalid_value_type(): - test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} - with pytest.raises(VariableError): - variable_factory.build_conversation_variable_from_mapping(test_data) - - -def test_build_a_blank_string(): - result = variable_factory.build_conversation_variable_from_mapping( - { - "value_type": "string", - "name": "blank", - "value": "", - } - ) - assert isinstance(result, StringVariable) - assert result.value == "" - - -def test_build_a_object_variable_with_none_value(): - var = variable_factory.build_segment( - { - "key1": None, - } - ) - assert isinstance(var, ObjectSegment) - assert var.value["key1"] is None - - -def test_object_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "object", - "name": "test_object", - "description": "Description of the variable.", - "value": { - "key1": "text", - "key2": 2, - }, - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value["key1"], str) - assert isinstance(variable.value["key2"], int) - - -def 
test_array_string_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[string]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - "text", - "text", - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayStringVariable) - assert isinstance(variable.value[0], str) - assert isinstance(variable.value[1], str) - - -def test_array_number_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[number]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - 1, - 2.0, - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayNumberVariable) - assert isinstance(variable.value[0], int) - assert isinstance(variable.value[1], float) - - -def test_array_object_variable(): - mapping = { - "id": str(uuid4()), - "value_type": "array[object]", - "name": "test_array", - "description": "Description of the variable.", - "value": [ - { - "key1": "text", - "key2": 1, - }, - { - "key1": "text", - "key2": 1, - }, - ], - } - variable = variable_factory.build_conversation_variable_from_mapping(mapping) - assert isinstance(variable, ArrayObjectVariable) - assert isinstance(variable.value[0], dict) - assert isinstance(variable.value[1], dict) - assert isinstance(variable.value[0]["key1"], str) - assert isinstance(variable.value[0]["key2"], int) - assert isinstance(variable.value[1]["key1"], str) - assert isinstance(variable.value[1]["key2"], int) - - -def test_variable_cannot_large_than_200_kb(): - with pytest.raises(VariableError): - variable_factory.build_conversation_variable_from_mapping( - { - "id": str(uuid4()), - "value_type": "string", - "name": "test_text", - "value": "a" * 1024 * 201, - } - ) - - -def test_array_none_variable(): - var = variable_factory.build_segment([None, None, None, None]) - assert isinstance(var, ArrayAnySegment) - assert var.value == [None, None, None, None] diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py new file mode 100644 index 0000000000..3ada2087c6 --- /dev/null +++ b/api/tests/unit_tests/core/file/test_models.py @@ -0,0 +1,25 @@ +from core.file import File, FileTransferMethod, FileType + + +def test_file(): + file = File( + id="test-file", + tenant_id="test-tenant-id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="test-related-id", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) + assert file.tenant_id == "test-tenant-id" + assert file.type == FileType.IMAGE + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.related_id == "test-related-id" + assert file.filename == "image.png" + assert file.extension == ".png" + assert file.mime_type == "image/png" + assert file.size == 67 diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py new file mode 100644 index 0000000000..d4cf534c56 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py @@ -0,0 +1,22 @@ +from core.rag.extractor.markdown_extractor import MarkdownExtractor + + +def test_markdown_to_tups(): + markdown = """ +this is some text without header + +# title 1 +this is balabala text + +## title 2 +this is more specific text. 
+ """ + extractor = MarkdownExtractor(file_path="dummy_path") + updated_output = extractor.markdown_to_tups(markdown) + assert len(updated_output) == 3 + key, header_value = updated_output[0] + assert key == None + assert header_value.strip() == "this is some text without header" + title_1, value = updated_output[1] + assert title_1.strip() == "title 1" + assert value.strip() == "this is balabala text" diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py similarity index 100% rename from api/tests/unit_tests/core/app/segments/test_segment.py rename to api/tests/unit_tests/core/variables/test_segment.py diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py similarity index 100% rename from api/tests/unit_tests/core/app/segments/test_variables.py rename to api/tests/unit_tests/core/variables/test_variables.py diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py new file mode 100644 index 0000000000..8712b61a23 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -0,0 +1,36 @@ +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import NodeType + +# Ensures that all node classes are imported. +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + +_ = NODE_TYPE_CLASSES_MAPPING + + +def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]: + subclasses = [] + queue = [root] + while queue: + cls = queue.pop() + + subclasses.extend(cls.__subclasses__()) + queue.extend(cls.__subclasses__()) + + return subclasses + + +def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined(): + classes = _get_all_subclasses(BaseNode) # type: ignore + type_version_set: set[tuple[NodeType, str]] = set() + + for cls in classes: + # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ + assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" + node_type = cls._node_type + node_version = cls.version() + + assert isinstance(cls._node_type, NodeType) + assert isinstance(node_version, str) + node_type_and_version = (node_type, node_version) + assert node_type_and_version not in type_version_set + type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 6d854c950d..362072a3db 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -3,6 +3,7 @@ import uuid from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables.segments import ArrayAnySegment, ArrayStringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -197,7 +198,7 @@ def test_run(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", 
"dify 123"])} assert count == 20 @@ -413,7 +414,7 @@ def test_run_parallel(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 32 @@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode(): parallel_arr.append(item) if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 32 for item in sequential_result: @@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} assert count == 64 @@ -846,7 +847,7 @@ def test_iteration_run_error_handle(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": [None, None]} + assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])} assert count == 14 # execute remove abnormal output @@ -857,5 +858,5 @@ def test_iteration_run_error_handle(): count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.outputs == {"output": []} + assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])} assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 4cb1aa93f9..66c7818adf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P from core.file import File, FileTransferMethod from core.variables import ArrayFileSegment +from core.variables.segments import ArrayStringSegment from core.variables.variables import StringVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s @pytest.mark.parametrize( ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ - ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"), + ( + "text/plain", + b"Hello, world!", + ["Hello, world!"], + FileTransferMethod.LOCAL_FILE, + ".txt", + ), ( "application/pdf", b"%PDF-1.5\n%Test PDF content", @@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s FileTransferMethod.REMOTE_URL, "", ), - ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None), + ( + "text/plain", + b"Remote content", + ["Remote content"], + FileTransferMethod.REMOTE_URL, + None, + ), ], ) def test_run_extract_text( @@ -131,7 +144,7 @@ def test_run_extract_text( assert 
isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error assert result.outputs is not None - assert result.outputs["text"] == expected_text + assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") @@ -329,3 +342,26 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert result == "" assert mock_excel_instance.parse.call_count == 2 + + +@patch("pandas.ExcelFile") +def test_extract_text_from_excel_numeric_type_column(mock_excel_file): + """Test extracting text from Excel file with numeric column names.""" + + # Test numeric type column + data = {1: ["Test"], 1.1: ["Test"]} + + df = pd.DataFrame(data) + + # Mock ExcelFile + mock_excel_instance = Mock() + mock_excel_instance.sheet_names = ["Sheet1"] + mock_excel_instance.parse.return_value = df + mock_excel_file.return_value = mock_excel_instance + + file_content = b"fake_excel_content" + result = _extract_text_from_excel(file_content) + + expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n" + + assert expected_manual == result diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 77d42e2692..7d3a1d6a2d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node): }, ] assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - for expected_file, result_file in zip(expected_files, result.outputs["result"]): + for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type assert expected_file["tenant_id"] == result_file.tenant_id diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 9793da129d..deb3e29b86 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -5,6 +5,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -63,10 +64,11 @@ def test_overwrite_string_variable(): name="test_string_variable", value="the second value", ) + conversation_id = str(uuid.uuid4()) # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -77,6 +79,9 @@ def test_overwrite_string_variable(): input_variable, ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = 
VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -91,11 +96,20 @@ def test_overwrite_string_variable(): "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_var = StringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=input_variable.value, + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -148,9 +162,10 @@ def test_append_variable_to_array(): name="test_string_variable", value="the second value", ) + conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -160,6 +175,9 @@ def test_append_variable_to_array(): input_variable, ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -174,11 +192,22 @@ def test_append_variable_to_array(): "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_value = list(conversation_variable.value) + expected_value.append(input_variable.value) + expected_var = ArrayStringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=expected_value, + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -225,13 +254,17 @@ def test_clear_array(): value=["the first value"], ) + conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], ) + mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) + mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, @@ -246,11 +279,20 @@ def test_clear_array(): "input_variable_selector": [], }, }, + conv_var_updater_factory=mock_conv_var_updater_factory, ) - with 
mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: - list(node.run()) - mock_run.assert_called_once() + list(node.run()) + expected_var = ArrayStringVariable( + id=conversation_variable.id, + name=conversation_variable.name, + description=conversation_variable.description, + selector=conversation_variable.selector, + value_type=conversation_variable.value_type, + value=[], + ) + mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) + mock_conv_var_updater.flush.assert_called_once() got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index efbcdc760c..bb8d34fad5 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -1,8 +1,12 @@ import pytest +from pydantic import ValidationError from core.file import File, FileTransferMethod, FileType from core.variables import FileSegment, StringSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture @@ -44,3 +48,38 @@ def test_use_long_selector(pool): result = pool.get(("node_1", "part_1", "part_2")) assert result is not None assert result.value == "test_value" + + +class TestVariablePool: + def test_constructor(self): + pool = VariablePool() + pool = VariablePool( + variable_dictionary={}, + user_inputs={}, + system_variables={}, + environment_variables=[], + conversation_variables=[], + ) + + pool = VariablePool( + user_inputs={"key": "value"}, + system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, + environment_variables=[ + segment_to_variable( + segment=build_segment(1), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"], + name="env_var_1", + ) + ], + conversation_variables=[ + segment_to_variable( + segment=build_segment("1"), + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"], + name="conv_var_1", + ) + ], + ) + + def test_constructor_with_invalid_system_variable_key(self): + with pytest.raises(ValidationError): + VariablePool(system_variables={"invalid_key": "value"}) # type: ignore diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index fddc182594..646de8bf3a 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -163,7 +163,6 @@ def real_workflow_run(): workflow_run.tenant_id = "test-tenant-id" workflow_run.app_id = "test-app-id" workflow_run.workflow_id = "test-workflow-id" - workflow_run.sequence_number = 1 workflow_run.type = "chat" workflow_run.triggered_from = "app-run" workflow_run.version = "1.0" diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 2f90afcf89..28ef05edde 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,22 +1,10 @@ -from core.variables import SecretVariable 
+import dataclasses + from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.utils import variable_template_parser def test_extract_selectors_from_template(): - variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, - user_inputs={}, - environment_variables=[ - SecretVariable(name="secret_key", value="fake-secret-key"), - ], - conversation_variables=[], - ) - variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) @@ -26,3 +14,35 @@ def test_extract_selectors_from_template(): VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), ] + + +def test_invalid_references(): + @dataclasses.dataclass + class TestCase: + name: str + template: str + + cases = [ + TestCase( + name="lack of closing brace", + template="Hello, {{#sys.user_id#", + ), + TestCase( + name="lack of opening brace", + template="Hello, #sys.user_id#}}", + ), + TestCase( + name="lack selector name", + template="Hello, {{#sys#}}", + ), + TestCase( + name="empty node name part", + template="Hello, {{#.user_id#}}", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + selectors = variable_template_parser.extract_selectors_from_template(c.template) + assert selectors == [], fail_msg + parser = variable_template_parser.VariableTemplateParser(c.template) + assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py new file mode 100644 index 0000000000..f1cb937bb3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -0,0 +1,148 @@ +from typing import Any + +from core.variables.segments import ObjectSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.variable_utils import append_variables_recursively + + +class TestAppendVariablesRecursively: + """Test cases for append_variables_recursively function""" + + def test_append_simple_dict_value(self): + """Test appending a simple dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["output"] + variable_value = {"name": "John", "age": 30} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Check that nested variables are added recursively + name_var = pool.get([node_id] + variable_key_list + ["name"]) + assert name_var is not None + assert name_var.value == "John" + + age_var = pool.get([node_id] + variable_key_list + ["age"]) + assert age_var is not None + assert age_var.value == 30 + + def test_append_object_segment_value(self): + """Test appending an ObjectSegment value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["result"] + + # Create an ObjectSegment + obj_data = {"status": "success", "code": 200} + variable_value = ObjectSegment(value=obj_data) + + append_variables_recursively(pool, 
node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, ObjectSegment) + assert main_var.value == obj_data + + # Check that nested variables are added recursively + status_var = pool.get([node_id] + variable_key_list + ["status"]) + assert status_var is not None + assert status_var.value == "success" + + code_var = pool.get([node_id] + variable_key_list + ["code"]) + assert code_var is not None + assert code_var.value == 200 + + def test_append_nested_dict_value(self): + """Test appending a nested dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["data"] + + variable_value = { + "user": { + "profile": {"name": "Alice", "email": "alice@example.com"}, + "settings": {"theme": "dark", "notifications": True}, + }, + "metadata": {"version": "1.0", "timestamp": 1234567890}, + } + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check deeply nested variables + name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) + assert name_var is not None + assert name_var.value == "Alice" + + email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) + assert email_var is not None + assert email_var.value == "alice@example.com" + + theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) + assert theme_var is not None + assert theme_var.value == "dark" + + notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) + assert notifications_var is not None + assert notifications_var.value == 1 # Boolean True is converted to integer 1 + + version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) + assert version_var is not None + assert version_var.value == "1.0" + + def test_append_non_dict_value(self): + """Test appending a non-dictionary value (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["simple"] + variable_value = "simple_string" + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_segment_non_object_value(self): + """Test appending a Segment that is not ObjectSegment (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["text"] + variable_value = StringSegment(value="Hello World") + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, StringSegment) + assert main_var.value == "Hello World" + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_empty_dict_value(self): + """Test appending an empty dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["empty"] + variable_value: dict[str, Any] = {} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + 
variable_key_list) + assert main_var is not None + assert main_var.value == {} + + # Ensure only the main variable is created (no recursion for empty dict) + assert len(pool.variable_dictionary[node_id]) == 1 diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py new file mode 100644 index 0000000000..481fbdc91a --- /dev/null +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -0,0 +1,865 @@ +import math +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from core.file import File, FileTransferMethod, FileType +from core.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectSegment, + SecretVariable, + SegmentType, + StringVariable, +) +from core.variables.exc import VariableError +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) +from core.variables.types import SegmentType +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment_with_type + + +def test_string_variable(): + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, StringVariable) + + +def test_integer_variable(): + test_data = {"value_type": "number", "name": "test_int", "value": 42} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, IntegerVariable) + + +def test_float_variable(): + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, FloatVariable) + + +def test_secret_variable(): + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, SecretVariable) + + +def test_invalid_value_type(): + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping(test_data) + + +def test_build_a_blank_string(): + result = variable_factory.build_conversation_variable_from_mapping( + { + "value_type": "string", + "name": "blank", + "value": "", + } + ) + assert isinstance(result, StringVariable) + assert result.value == "" + + +def test_build_a_object_variable_with_none_value(): + var = variable_factory.build_segment( + { + "key1": None, + } + ) + assert isinstance(var, ObjectSegment) + assert var.value["key1"] is None + + +def test_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, + }, + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ObjectSegment) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) + + +def test_array_string_variable(): + mapping = { + "id": str(uuid4()), + "value_type": 
"array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayStringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) + + +def test_array_number_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + 1, + 2.0, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayNumberVariable) + assert isinstance(variable.value[0], int) + assert isinstance(variable.value[1], float) + + +def test_array_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + { + "key1": "text", + "key2": 1, + }, + { + "key1": "text", + "key2": 1, + }, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayObjectVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) + + +def test_variable_cannot_large_than_200_kb(): + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping( + { + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 201, + } + ) + + +def test_array_none_variable(): + var = variable_factory.build_segment([None, None, None, None]) + assert isinstance(var, ArrayAnySegment) + assert var.value == [None, None, None, None] + + +def test_build_segment_none_type(): + """Test building NoneSegment from None value.""" + segment = variable_factory.build_segment(None) + assert isinstance(segment, NoneSegment) + assert segment.value is None + assert segment.value_type == SegmentType.NONE + + +def test_build_segment_none_type_properties(): + """Test NoneSegment properties and methods.""" + segment = variable_factory.build_segment(None) + assert segment.text == "" + assert segment.log == "" + assert segment.markdown == "" + assert segment.to_object() is None + + +def test_build_segment_array_file_single_file(): + """Test building ArrayFileSegment from list with single file.""" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + segment = variable_factory.build_segment([file]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 1 + assert segment.value[0] == file + assert segment.value_type == SegmentType.ARRAY_FILE + + +def test_build_segment_array_file_multiple_files(): + """Test building ArrayFileSegment from list with multiple files.""" + file1 = File( + id="test_file_id_1", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file1.png", + filename="test-file1", + extension=".png", + mime_type="image/png", + size=1000, + ) + file2 = File( + 
id="test_file_id_2", + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_relation_id", + filename="test-file2", + extension=".txt", + mime_type="text/plain", + size=500, + ) + segment = variable_factory.build_segment([file1, file2]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 2 + assert segment.value[0] == file1 + assert segment.value[1] == file2 + assert segment.value_type == SegmentType.ARRAY_FILE + + +def test_build_segment_array_file_empty_list(): + """Test building ArrayFileSegment from empty list should create ArrayAnySegment.""" + segment = variable_factory.build_segment([]) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == [] + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_mixed_types(): + """Test building ArrayAnySegment from list with mixed types.""" + mixed_values = ["string", 42, 3.14, {"key": "value"}, None] + segment = variable_factory.build_segment(mixed_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == mixed_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_with_nested_arrays(): + """Test building ArrayAnySegment from list containing arrays.""" + nested_values = [["nested", "array"], [1, 2, 3], "string"] + segment = variable_factory.build_segment(nested_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == nested_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_mixed_with_files(): + """Test building ArrayAnySegment from list with files and other types.""" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + mixed_values = [file, "string", 42] + segment = variable_factory.build_segment(mixed_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == mixed_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_any_all_none_values(): + """Test building ArrayAnySegment from list with all None values.""" + none_values = [None, None, None] + segment = variable_factory.build_segment(none_values) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == none_values + assert segment.value_type == SegmentType.ARRAY_ANY + + +def test_build_segment_array_file_properties(): + """Test ArrayFileSegment properties and methods.""" + file1 = File( + id="test_file_id_1", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file1.png", + filename="test-file1", + extension=".png", + mime_type="image/png", + size=1000, + ) + file2 = File( + id="test_file_id_2", + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file2.txt", + filename="test-file2", + extension=".txt", + mime_type="text/plain", + size=500, + ) + segment = variable_factory.build_segment([file1, file2]) + + # Test properties + assert segment.text == "" # ArrayFileSegment text property returns empty string + assert segment.log == "" # ArrayFileSegment log property returns empty string + assert segment.markdown == 
f"{file1.markdown}\n{file2.markdown}" + assert segment.to_object() == [file1, file2] + + +def test_build_segment_array_any_properties(): + """Test ArrayAnySegment properties and methods.""" + mixed_values = ["string", 42, None] + segment = variable_factory.build_segment(mixed_values) + + # Test properties + assert segment.text == str(mixed_values) + assert segment.log == str(mixed_values) + assert segment.markdown == "string\n42\nNone" + assert segment.to_object() == mixed_values + + +def test_build_segment_edge_cases(): + """Test edge cases for build_segment function.""" + # Test with complex nested structures + complex_structure = [{"nested": {"deep": [1, 2, 3]}}, [{"inner": "value"}], "mixed"] + segment = variable_factory.build_segment(complex_structure) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == complex_structure + + # Test with single None in list + single_none = [None] + segment = variable_factory.build_segment(single_none) + assert isinstance(segment, ArrayAnySegment) + assert segment.value == single_none + + +def test_build_segment_file_array_with_different_file_types(): + """Test ArrayFileSegment with different file types.""" + image_file = File( + id="image_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/image.png", + filename="image", + extension=".png", + mime_type="image/png", + size=1000, + ) + + video_file = File( + id="video_id", + tenant_id="test_tenant_id", + type=FileType.VIDEO, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="video_relation_id", + filename="video", + extension=".mp4", + mime_type="video/mp4", + size=5000, + ) + + audio_file = File( + id="audio_id", + tenant_id="test_tenant_id", + type=FileType.AUDIO, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="audio_relation_id", + filename="audio", + extension=".mp3", + mime_type="audio/mpeg", + size=3000, + ) + + segment = variable_factory.build_segment([image_file, video_file, audio_file]) + assert isinstance(segment, ArrayFileSegment) + assert len(segment.value) == 3 + assert segment.value[0].type == FileType.IMAGE + assert segment.value[1].type == FileType.VIDEO + assert segment.value[2].type == FileType.AUDIO + + +@st.composite +def _generate_file(draw) -> File: + file_id = draw(st.text(min_size=1, max_size=10)) + tenant_id = draw(st.text(min_size=1, max_size=10)) + file_type, mime_type, extension = draw( + st.sampled_from( + [ + (FileType.IMAGE, "image/png", ".png"), + (FileType.VIDEO, "video/mp4", ".mp4"), + (FileType.DOCUMENT, "text/plain", ".txt"), + (FileType.AUDIO, "audio/mpeg", ".mp3"), + ] + ) + ) + filename = "test-file" + size = draw(st.integers(min_value=0, max_value=1024 * 1024)) + + transfer_method = draw(st.sampled_from(list(FileTransferMethod))) + if transfer_method == FileTransferMethod.REMOTE_URL: + url = "https://test.example.com/test-file" + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=file_type, + transfer_method=transfer_method, + remote_url=url, + related_id=None, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + else: + relation_id = draw(st.uuids(version=4)) + + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=file_type, + transfer_method=transfer_method, + related_id=str(relation_id), + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + return file + + +def _scalar_value() -> st.SearchStrategy[int | float | str | File 
| None]: + return st.one_of( + st.none(), + st.integers(), + st.floats(), + st.text(), + _generate_file(), + ) + + +@given(_scalar_value()) +def test_build_segment_and_extract_values_for_scalar_types(value): + seg = variable_factory.build_segment(value) + # nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here. + if isinstance(value, float) and math.isnan(value): + assert math.isnan(seg.value) + else: + assert seg.value == value + + +@given(st.lists(_scalar_value())) +def test_build_segment_and_extract_values_for_array_types(values): + seg = variable_factory.build_segment(values) + assert seg.value == values + + +def test_build_segment_type_for_scalar(): + @dataclass(frozen=True) + class TestCase: + value: int | float | str | File + expected_type: SegmentType + + file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + ) + cases = [ + TestCase(0, SegmentType.NUMBER), + TestCase(0.0, SegmentType.NUMBER), + TestCase("", SegmentType.STRING), + TestCase(file, SegmentType.FILE), + ] + + for idx, c in enumerate(cases, 1): + segment = variable_factory.build_segment(c.value) + assert segment.value_type == c.expected_type, f"test case {idx} failed." + + +class TestBuildSegmentWithType: + """Test cases for build_segment_with_type function.""" + + def test_string_type(self): + """Test building a string segment with correct type.""" + result = build_segment_with_type(SegmentType.STRING, "hello") + assert isinstance(result, StringSegment) + assert result.value == "hello" + assert result.value_type == SegmentType.STRING + + def test_number_type_integer(self): + """Test building a number segment with integer value.""" + result = build_segment_with_type(SegmentType.NUMBER, 42) + assert isinstance(result, IntegerSegment) + assert result.value == 42 + assert result.value_type == SegmentType.NUMBER + + def test_number_type_float(self): + """Test building a number segment with float value.""" + result = build_segment_with_type(SegmentType.NUMBER, 3.14) + assert isinstance(result, FloatSegment) + assert result.value == 3.14 + assert result.value_type == SegmentType.NUMBER + + def test_object_type(self): + """Test building an object segment with correct type.""" + test_obj = {"key": "value", "nested": {"inner": 123}} + result = build_segment_with_type(SegmentType.OBJECT, test_obj) + assert isinstance(result, ObjectSegment) + assert result.value == test_obj + assert result.value_type == SegmentType.OBJECT + + def test_file_type(self): + """Test building a file segment with correct type.""" + test_file = File( + id="test_file_id", + tenant_id="test_tenant_id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://test.example.com/test-file.png", + filename="test-file", + extension=".png", + mime_type="image/png", + size=1000, + storage_key="test_storage_key", + ) + result = build_segment_with_type(SegmentType.FILE, test_file) + assert isinstance(result, FileSegment) + assert result.value == test_file + assert result.value_type == SegmentType.FILE + + def test_none_type(self): + """Test building a none segment with None value.""" + result = build_segment_with_type(SegmentType.NONE, None) + assert isinstance(result, NoneSegment) + assert result.value is None + assert result.value_type == SegmentType.NONE + + def 
test_empty_array_string(self): + """Test building an empty array[string] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_STRING, []) + assert isinstance(result, ArrayStringSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_STRING + + def test_empty_array_number(self): + """Test building an empty array[number] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_NUMBER, []) + assert isinstance(result, ArrayNumberSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_NUMBER + + def test_empty_array_object(self): + """Test building an empty array[object] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_OBJECT, []) + assert isinstance(result, ArrayObjectSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_OBJECT + + def test_empty_array_file(self): + """Test building an empty array[file] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_FILE, []) + assert isinstance(result, ArrayFileSegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_FILE + + def test_empty_array_any(self): + """Test building an empty array[any] segment.""" + result = build_segment_with_type(SegmentType.ARRAY_ANY, []) + assert isinstance(result, ArrayAnySegment) + assert result.value == [] + assert result.value_type == SegmentType.ARRAY_ANY + + def test_array_with_values(self): + """Test building array segments with actual values.""" + # Array of strings + result = build_segment_with_type(SegmentType.ARRAY_STRING, ["hello", "world"]) + assert isinstance(result, ArrayStringSegment) + assert result.value == ["hello", "world"] + assert result.value_type == SegmentType.ARRAY_STRING + + # Array of numbers + result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3.14]) + assert isinstance(result, ArrayNumberSegment) + assert result.value == [1, 2, 3.14] + assert result.value_type == SegmentType.ARRAY_NUMBER + + # Array of objects + result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [{"a": 1}, {"b": 2}]) + assert isinstance(result, ArrayObjectSegment) + assert result.value == [{"a": 1}, {"b": 2}] + assert result.value_type == SegmentType.ARRAY_OBJECT + + def test_type_mismatch_string_to_number(self): + """Test type mismatch when expecting number but getting string.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.NUMBER, "not_a_number") + + assert "Type mismatch" in str(exc_info.value) + assert "expected number" in str(exc_info.value) + assert "str" in str(exc_info.value) + + def test_type_mismatch_number_to_string(self): + """Test type mismatch when expecting string but getting number.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, 123) + + assert "Type mismatch" in str(exc_info.value) + assert "expected string" in str(exc_info.value) + assert "int" in str(exc_info.value) + + def test_type_mismatch_none_to_string(self): + """Test type mismatch when expecting string but getting None.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, None) + + assert "Expected string, but got None" in str(exc_info.value) + + def test_type_mismatch_empty_list_to_non_array(self): + """Test type mismatch when expecting non-array type but getting empty list.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.STRING, []) + + assert "Expected string, but got 
empty list" in str(exc_info.value) + + def test_type_mismatch_object_to_array(self): + """Test type mismatch when expecting array but getting object.""" + with pytest.raises(TypeMismatchError) as exc_info: + build_segment_with_type(SegmentType.ARRAY_STRING, {"key": "value"}) + + assert "Type mismatch" in str(exc_info.value) + assert "expected array[string]" in str(exc_info.value) + + def test_compatible_number_types(self): + """Test that int and float are both compatible with NUMBER type.""" + # Integer should work + result_int = build_segment_with_type(SegmentType.NUMBER, 42) + assert isinstance(result_int, IntegerSegment) + assert result_int.value_type == SegmentType.NUMBER + + # Float should work + result_float = build_segment_with_type(SegmentType.NUMBER, 3.14) + assert isinstance(result_float, FloatSegment) + assert result_float.value_type == SegmentType.NUMBER + + @pytest.mark.parametrize( + ("segment_type", "value", "expected_class"), + [ + (SegmentType.STRING, "test", StringSegment), + (SegmentType.NUMBER, 42, IntegerSegment), + (SegmentType.NUMBER, 3.14, FloatSegment), + (SegmentType.OBJECT, {}, ObjectSegment), + (SegmentType.NONE, None, NoneSegment), + (SegmentType.ARRAY_STRING, [], ArrayStringSegment), + (SegmentType.ARRAY_NUMBER, [], ArrayNumberSegment), + (SegmentType.ARRAY_OBJECT, [], ArrayObjectSegment), + (SegmentType.ARRAY_ANY, [], ArrayAnySegment), + ], + ) + def test_parametrized_valid_types(self, segment_type, value, expected_class): + """Parametrized test for valid type combinations.""" + result = build_segment_with_type(segment_type, value) + assert isinstance(result, expected_class) + assert result.value == value + assert result.value_type == segment_type + + @pytest.mark.parametrize( + ("segment_type", "value"), + [ + (SegmentType.STRING, 123), + (SegmentType.NUMBER, "not_a_number"), + (SegmentType.OBJECT, "not_an_object"), + (SegmentType.ARRAY_STRING, "not_an_array"), + (SegmentType.STRING, None), + (SegmentType.NUMBER, None), + ], + ) + def test_parametrized_type_mismatches(self, segment_type, value): + """Parametrized test for type mismatches that should raise TypeMismatchError.""" + with pytest.raises(TypeMismatchError): + build_segment_with_type(segment_type, value) + + +# Test cases for ValueError scenarios in build_segment function +class TestBuildSegmentValueErrors: + """Test cases for ValueError scenarios in the build_segment function.""" + + @dataclass(frozen=True) + class ValueErrorTestCase: + """Test case data for ValueError scenarios.""" + + name: str + description: str + test_value: Any + + def _get_test_cases(self): + """Get all test cases for ValueError scenarios.""" + + # Define inline classes for complex test cases + class CustomType: + pass + + def unsupported_function(): + return "test" + + def gen(): + yield 1 + yield 2 + + return [ + self.ValueErrorTestCase( + name="unsupported_custom_type", + description="custom class that doesn't match any supported type", + test_value=CustomType(), + ), + self.ValueErrorTestCase( + name="unsupported_set_type", + description="set (unsupported collection type)", + test_value={1, 2, 3}, + ), + self.ValueErrorTestCase( + name="unsupported_tuple_type", description="tuple (unsupported type)", test_value=(1, 2, 3) + ), + self.ValueErrorTestCase( + name="unsupported_bytes_type", + description="bytes (unsupported type)", + test_value=b"hello world", + ), + self.ValueErrorTestCase( + name="unsupported_function_type", + description="function (unsupported type)", + test_value=unsupported_function, + ), + 
self.ValueErrorTestCase( + name="unsupported_module_type", description="module (unsupported type)", test_value=math + ), + self.ValueErrorTestCase( + name="array_with_unsupported_element_types", + description="array with unsupported element types", + test_value=[CustomType()], + ), + self.ValueErrorTestCase( + name="mixed_array_with_unsupported_types", + description="array with mix of supported and unsupported types", + test_value=["valid_string", 42, CustomType()], + ), + self.ValueErrorTestCase( + name="nested_unsupported_types", + description="nested structures containing unsupported types", + test_value=[{"valid": "data"}, CustomType()], + ), + self.ValueErrorTestCase( + name="complex_number_type", + description="complex number (unsupported type)", + test_value=3 + 4j, + ), + self.ValueErrorTestCase( + name="range_type", description="range object (unsupported type)", test_value=range(10) + ), + self.ValueErrorTestCase( + name="generator_type", + description="generator (unsupported type)", + test_value=gen(), + ), + self.ValueErrorTestCase( + name="exception_message_contains_value", + description="set to verify error message contains the actual unsupported value", + test_value={1, 2, 3}, + ), + self.ValueErrorTestCase( + name="array_with_mixed_unsupported_segment_types", + description="array processing with unsupported segment types in match", + test_value=[CustomType()], + ), + self.ValueErrorTestCase( + name="frozenset_type", + description="frozenset (unsupported type)", + test_value=frozenset([1, 2, 3]), + ), + self.ValueErrorTestCase( + name="memoryview_type", + description="memoryview (unsupported type)", + test_value=memoryview(b"hello"), + ), + self.ValueErrorTestCase( + name="slice_type", description="slice object (unsupported type)", test_value=slice(1, 10, 2) + ), + self.ValueErrorTestCase(name="type_object", description="type object (unsupported type)", test_value=type), + self.ValueErrorTestCase( + name="generic_object", description="generic object (unsupported type)", test_value=object() + ), + ] + + def test_build_segment_unsupported_types(self): + """Table-driven test for all ValueError scenarios in build_segment function.""" + test_cases = self._get_test_cases() + + for index, test_case in enumerate(test_cases, 1): + # Use test value directly + test_value = test_case.test_value + + with pytest.raises(ValueError) as exc_info: # noqa: PT012 + segment = variable_factory.build_segment(test_value) + pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but not, result={segment}") + + error_message = str(exc_info.value) + assert "not supported value" in error_message, ( + f"Test case {index} ({test_case.name}): Expected 'not supported value' in error message, " + f"but got: {error_message}" + ) + + def test_build_segment_boolean_type_note(self): + """Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError.""" + # Boolean values in Python are subclasses of int, so they get processed as integers + # True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0) + true_segment = variable_factory.build_segment(True) + false_segment = variable_factory.build_segment(False) + + # Verify they are processed as integers, not as errors + assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1" + assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0" + assert true_segment.value_type == SegmentType.NUMBER + assert 
false_segment.value_type == SegmentType.NUMBER diff --git a/api/tests/unit_tests/libs/test_datetime_utils.py b/api/tests/unit_tests/libs/test_datetime_utils.py new file mode 100644 index 0000000000..e7781a5821 --- /dev/null +++ b/api/tests/unit_tests/libs/test_datetime_utils.py @@ -0,0 +1,20 @@ +import datetime + +from libs.datetime_utils import naive_utc_now + + +def test_naive_utc_now(monkeypatch): + tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC) + + def _now_func(tz: datetime.timezone | None) -> datetime.datetime: + return tz_aware_utc_now.astimezone(tz) + + monkeypatch.setattr("libs.datetime_utils._now_func", _now_func) + + naive_datetime = naive_utc_now() + + assert naive_datetime.tzinfo is None + assert naive_datetime.date() == tz_aware_utc_now.date() + naive_time = naive_datetime.time() + utc_time = tz_aware_utc_now.time() + assert naive_time == utc_time diff --git a/api/tests/unit_tests/models/__init__.py b/api/tests/unit_tests/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index b79e95c7ed..69163d48bd 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -1,10 +1,15 @@ +import dataclasses import json from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from models.workflow import Workflow, WorkflowNodeExecutionModel +from core.variables.segments import IntegerSegment, Segment +from factories.variable_factory import build_segment +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable def test_environment_variables(): @@ -163,3 +168,147 @@ class TestWorkflowNodeExecution: original = {"a": 1, "b": ["2"]} node_exec.execution_metadata = json.dumps(original) assert node_exec.execution_metadata_dict == original + + +class TestIsSystemVariableEditable: + def test_is_system_variable(self): + cases = [ + ("query", True), + ("files", True), + ("dialogue_count", False), + ("conversation_id", False), + ("user_id", False), + ("app_id", False), + ("workflow_id", False), + ("workflow_run_id", False), + ] + for name, editable in cases: + assert editable == is_system_variable_editable(name) + + assert is_system_variable_editable("invalid_or_new_system_variable") == False + + +class TestWorkflowDraftVariableGetValue: + def test_get_value_by_case(self): + @dataclasses.dataclass + class TestCase: + name: str + value: Segment + + tenant_id = "test_tenant_id" + + test_file = File( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/example.jpg", + filename="example.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=100, + ) + cases: list[TestCase] = [ + TestCase( + name="number/int", + value=build_segment(1), + ), + TestCase( + name="number/float", + value=build_segment(1.0), + ), + TestCase( + name="string", + value=build_segment("a"), + ), + TestCase( + name="object", + value=build_segment({}), + ), + TestCase( + name="file", + value=build_segment(test_file), + ), + TestCase( + name="array[any]", + value=build_segment([1, "a"]), + ), + TestCase( + name="array[string]", + value=build_segment(["a", "b"]), + ), + TestCase( + name="array[number]/int", + 
value=build_segment([1, 2]), + ), + TestCase( + name="array[number]/float", + value=build_segment([1.0, 2.0]), + ), + TestCase( + name="array[number]/mixed", + value=build_segment([1, 2.0]), + ), + TestCase( + name="array[object]", + value=build_segment([{}, {"a": 1}]), + ), + TestCase( + name="none", + value=build_segment(None), + ), + ] + + for idx, c in enumerate(cases, 1): + fail_msg = f"test case {c.name} failed, index={idx}" + draft_var = WorkflowDraftVariable() + draft_var.set_value(c.value) + assert c.value == draft_var.get_value(), fail_msg + + def test_file_variable_preserves_all_fields(self): + """Test that File type variables preserve all fields during encoding/decoding.""" + tenant_id = "test_tenant_id" + + # Create a File with specific field values + test_file = File( + id="test_file_id", + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/test.jpg", + filename="test.jpg", + extension=".jpg", + mime_type="image/jpeg", + size=12345, # Specific size to test preservation + storage_key="test_storage_key", + ) + + # Create a FileSegment and WorkflowDraftVariable + file_segment = build_segment(test_file) + draft_var = WorkflowDraftVariable() + draft_var.set_value(file_segment) + + # Retrieve the value and verify all fields are preserved + retrieved_segment = draft_var.get_value() + retrieved_file = retrieved_segment.value + + # Verify all important fields are preserved + assert retrieved_file.id == test_file.id + assert retrieved_file.tenant_id == test_file.tenant_id + assert retrieved_file.type == test_file.type + assert retrieved_file.transfer_method == test_file.transfer_method + assert retrieved_file.remote_url == test_file.remote_url + assert retrieved_file.filename == test_file.filename + assert retrieved_file.extension == test_file.extension + assert retrieved_file.mime_type == test_file.mime_type + assert retrieved_file.size == test_file.size # This was the main issue being fixed + # Note: storage_key is not serialized in model_dump() so it won't be preserved + + # Verify the segments have the same type and the important fields match + assert file_segment.value_type == retrieved_segment.value_type + + def test_get_and_set_value(self): + draft_var = WorkflowDraftVariable() + int_var = IntegerSegment(value=1) + draft_var.set_value(int_var) + value = draft_var.get_value() + assert value == int_var diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py new file mode 100644 index 0000000000..f22500cfe4 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -0,0 +1,1238 @@ +import datetime +import unittest + +# Mock redis_client before importing dataset_service +from unittest.mock import Mock, call, patch + +import pytest + +from models.dataset import Dataset, Document +from services.dataset_service import DocumentService +from services.errors.document import DocumentIndexingError +from tests.unit_tests.conftest import redis_mock + + +class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): + """ + Comprehensive unit tests for DocumentService.batch_update_document_status method. + + This test suite covers all supported actions (enable, disable, archive, un_archive), + error conditions, edge cases, and validates proper interaction with Redis cache, + database operations, and async task triggers. 
+ """ + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_enable_documents_success(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test successful enabling of disabled documents. + + Verifies that: + 1. Only disabled documents are processed (already enabled documents are skipped) + 2. Document attributes are updated correctly (enabled=True, metadata cleared) + 3. Database changes are committed for each document + 4. Redis cache keys are set to prevent concurrent indexing + 5. Async indexing task is triggered for each enabled document + 6. Timestamp fields are properly updated + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock disabled document + mock_disabled_doc_1 = Mock(spec=Document) + mock_disabled_doc_1.id = "doc-1" + mock_disabled_doc_1.name = "disabled_document.pdf" + mock_disabled_doc_1.enabled = False + mock_disabled_doc_1.archived = False + mock_disabled_doc_1.indexing_status = "completed" + mock_disabled_doc_1.completed_at = datetime.datetime.now() + + mock_disabled_doc_2 = Mock(spec=Document) + mock_disabled_doc_2.id = "doc-2" + mock_disabled_doc_2.name = "disabled_document.pdf" + mock_disabled_doc_2.enabled = False + mock_disabled_doc_2.archived = False + mock_disabled_doc_2.indexing_status = "completed" + mock_disabled_doc_2.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock document retrieval to return disabled documents + mock_get_doc.side_effect = [mock_disabled_doc_1, mock_disabled_doc_2] + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to enable documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1", "doc-2"], action="enable", user=mock_user + ) + + # Verify document attributes were updated correctly + for mock_doc in [mock_disabled_doc_1, mock_disabled_doc_2]: + # Check that document was enabled + assert mock_doc.enabled == True + # Check that disable metadata was cleared + assert mock_doc.disabled_at is None + assert mock_doc.disabled_by is None + # Check that update timestamp was set + assert mock_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache operations + expected_cache_calls = [call("document_doc-1_indexing"), call("document_doc-2_indexing")] + redis_mock.get.assert_has_calls(expected_cache_calls) + + # Verify Redis cache was set to prevent concurrent indexing (600 seconds) + expected_setex_calls = [call("document_doc-1_indexing", 600, 1), call("document_doc-2_indexing", 600, 1)] + redis_mock.setex.assert_has_calls(expected_setex_calls) + + # Verify async tasks were triggered for indexing + expected_task_calls = [call("doc-1"), call("doc-2")] + mock_add_task.delay.assert_has_calls(expected_task_calls) + + # Verify database add counts (one add for one document) + assert mock_db.add.call_count == 2 + # Verify database commits (one commit for the batch operation) + assert mock_db.commit.call_count == 1 + + @patch("extensions.ext_database.db.session") + 
@patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_disable_documents_success(self, mock_datetime, mock_get_doc, mock_remove_task, mock_db): + """ + Test successful disabling of enabled and completed documents. + + Verifies that: + 1. Only completed and enabled documents can be disabled + 2. Document attributes are updated correctly (enabled=False, disable metadata set) + 3. User ID is recorded in disabled_by field + 4. Database changes are committed for each document + 5. Redis cache keys are set to prevent concurrent indexing + 6. Async task is triggered to remove documents from index + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock enabled document + mock_enabled_doc_1 = Mock(spec=Document) + mock_enabled_doc_1.id = "doc-1" + mock_enabled_doc_1.name = "enabled_document.pdf" + mock_enabled_doc_1.enabled = True + mock_enabled_doc_1.archived = False + mock_enabled_doc_1.indexing_status = "completed" + mock_enabled_doc_1.completed_at = datetime.datetime.now() + + mock_enabled_doc_2 = Mock(spec=Document) + mock_enabled_doc_2.id = "doc-2" + mock_enabled_doc_2.name = "enabled_document.pdf" + mock_enabled_doc_2.enabled = True + mock_enabled_doc_2.archived = False + mock_enabled_doc_2.indexing_status = "completed" + mock_enabled_doc_2.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock document retrieval to return enabled, completed documents + mock_get_doc.side_effect = [mock_enabled_doc_1, mock_enabled_doc_2] + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to disable documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1", "doc-2"], action="disable", user=mock_user + ) + + # Verify document attributes were updated correctly + for mock_doc in [mock_enabled_doc_1, mock_enabled_doc_2]: + # Check that document was disabled + assert mock_doc.enabled == False + # Check that disable metadata was set correctly + assert mock_doc.disabled_at == current_time.replace(tzinfo=None) + assert mock_doc.disabled_by == mock_user.id + # Check that update timestamp was set + assert mock_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache operations for indexing prevention + expected_setex_calls = [call("document_doc-1_indexing", 600, 1), call("document_doc-2_indexing", 600, 1)] + redis_mock.setex.assert_has_calls(expected_setex_calls) + + # Verify async tasks were triggered to remove from index + expected_task_calls = [call("doc-1"), call("doc-2")] + mock_remove_task.delay.assert_has_calls(expected_task_calls) + + # Verify database add counts (one add for one document) + assert mock_db.add.call_count == 2 + # Verify database commits (totally 1 for any batch operation) + assert mock_db.commit.call_count == 1 + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def 
test_batch_update_archive_documents_success(self, mock_datetime, mock_get_doc, mock_remove_task, mock_db): + """ + Test successful archiving of unarchived documents. + + Verifies that: + 1. Only unarchived documents are processed (already archived are skipped) + 2. Document attributes are updated correctly (archived=True, archive metadata set) + 3. User ID is recorded in archived_by field + 4. If documents are enabled, they are removed from the index + 5. Redis cache keys are set only for enabled documents being archived + 6. Database changes are committed for each document + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create unarchived enabled document + unarchived_doc = Mock(spec=Document) + # Manually set attributes to ensure they can be modified + unarchived_doc.id = "doc-1" + unarchived_doc.name = "unarchived_document.pdf" + unarchived_doc.enabled = True + unarchived_doc.archived = False + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to archive documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="archive", user=mock_user + ) + + # Verify document attributes were updated correctly + assert unarchived_doc.archived == True + assert unarchived_doc.archived_at == current_time.replace(tzinfo=None) + assert unarchived_doc.archived_by == mock_user.id + assert unarchived_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache was set (because document was enabled) + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify async task was triggered to remove from index (because enabled) + mock_remove_task.delay.assert_called_once_with("doc-1") + + # Verify database add + mock_db.add.assert_called_once() + # Verify database commit + mock_db.commit.assert_called_once() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_unarchive_documents_success(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test successful unarchiving of archived documents. + + Verifies that: + 1. Only archived documents are processed (already unarchived are skipped) + 2. Document attributes are updated correctly (archived=False, archive metadata cleared) + 3. If documents are enabled, they are added back to the index + 4. Redis cache keys are set only for enabled documents being unarchived + 5. 
Database changes are committed for each document + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock archived document + mock_archived_doc = Mock(spec=Document) + mock_archived_doc.id = "doc-3" + mock_archived_doc.name = "archived_document.pdf" + mock_archived_doc.enabled = True + mock_archived_doc.archived = True + mock_archived_doc.indexing_status = "completed" + mock_archived_doc.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = mock_archived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to unarchive documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-3"], action="un_archive", user=mock_user + ) + + # Verify document attributes were updated correctly + assert mock_archived_doc.archived == False + assert mock_archived_doc.archived_at is None + assert mock_archived_doc.archived_by is None + assert mock_archived_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache was set (because document is enabled) + redis_mock.setex.assert_called_once_with("document_doc-3_indexing", 600, 1) + + # Verify async task was triggered to add back to index (because enabled) + mock_add_task.delay.assert_called_once_with("doc-3") + + # Verify database add + mock_db.add.assert_called_once() + # Verify database commit + mock_db.commit.assert_called_once() + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_get_doc): + """ + Test that DocumentIndexingError is raised when documents are currently being indexed. + + Verifies that: + 1. The method checks Redis cache for active indexing operations + 2. DocumentIndexingError is raised if any document is being indexed + 3. Error message includes the document name for user feedback + 4. 
No further processing occurs when indexing is detected + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock enabled document + mock_enabled_doc = Mock(spec=Document) + mock_enabled_doc.id = "doc-1" + mock_enabled_doc.name = "enabled_document.pdf" + mock_enabled_doc.enabled = True + mock_enabled_doc.archived = False + mock_enabled_doc.indexing_status = "completed" + mock_enabled_doc.completed_at = datetime.datetime.now() + + # Set up mock to indicate document is being indexed + mock_get_doc.return_value = mock_enabled_doc + + # Reset module-level Redis mock, set to indexing status + redis_mock.reset_mock() + redis_mock.get.return_value = "indexing" + + # Verify that DocumentIndexingError is raised + with pytest.raises(DocumentIndexingError) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user + ) + + # Verify error message contains document name + assert "enabled_document.pdf" in str(exc_info.value) + assert "is being indexed" in str(exc_info.value) + + # Verify Redis cache was checked + redis_mock.get.assert_called_once_with("document_doc-1_indexing") + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_disable_non_completed_document_error(self, mock_get_doc): + """ + Test that DocumentIndexingError is raised when trying to disable non-completed documents. + + Verifies that: + 1. Only completed documents can be disabled + 2. DocumentIndexingError is raised for non-completed documents + 3. Error message indicates the document is not completed + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create a document that's not completed + non_completed_doc = Mock(spec=Document) + # Manually set attributes to ensure they can be modified + non_completed_doc.id = "doc-1" + non_completed_doc.name = "indexing_document.pdf" + non_completed_doc.enabled = True + non_completed_doc.indexing_status = "indexing" # Not completed + non_completed_doc.completed_at = None # Not completed + + mock_get_doc.return_value = non_completed_doc + + # Verify that DocumentIndexingError is raised + with pytest.raises(DocumentIndexingError) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="disable", user=mock_user + ) + + # Verify error message indicates document is not completed + assert "is not completed" in str(exc_info.value) + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_empty_document_list(self, mock_get_doc): + """ + Test batch operations with an empty document ID list. + + Verifies that: + 1. The method handles empty input gracefully + 2. No document operations are performed with empty input + 3. No errors are raised with empty input + 4. 
Method returns early without processing + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Call method with empty document list + result = DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=[], action="enable", user=mock_user + ) + + # Verify no document lookups were performed + mock_get_doc.assert_not_called() + + # Verify method returns None (early return) + assert result is None + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_document_not_found_skipped(self, mock_get_doc): + """ + Test behavior when some documents don't exist in the database. + + Verifies that: + 1. Non-existent documents are gracefully skipped + 2. Processing continues for existing documents + 3. No errors are raised for missing document IDs + 4. Method completes successfully despite missing documents + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock document service to return None (document not found) + mock_get_doc.return_value = None + + # Call method with non-existent document ID + # This should not raise an error, just skip the missing document + try: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["non-existent-doc"], action="enable", user=mock_user + ) + except Exception as e: + pytest.fail(f"Method should not raise exception for missing documents: {e}") + + # Verify document lookup was attempted + mock_get_doc.assert_called_once_with(mock_dataset.id, "non-existent-doc") + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_enable_already_enabled_document_skipped(self, mock_get_doc, mock_db): + """ + Test enabling documents that are already enabled. + + Verifies that: + 1. Already enabled documents are skipped (no unnecessary operations) + 2. No database commits occur for already enabled documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. 
Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock enabled document + mock_enabled_doc = Mock(spec=Document) + mock_enabled_doc.id = "doc-1" + mock_enabled_doc.name = "enabled_document.pdf" + mock_enabled_doc.enabled = True + mock_enabled_doc.archived = False + mock_enabled_doc.indexing_status = "completed" + mock_enabled_doc.completed_at = datetime.datetime.now() + + # Mock document that is already enabled + mock_get_doc.return_value = mock_enabled_doc # Already enabled + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to enable already enabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_archive_already_archived_document_skipped(self, mock_get_doc, mock_db): + """ + Test archiving documents that are already archived. + + Verifies that: + 1. Already archived documents are skipped (no unnecessary operations) + 2. No database commits occur for already archived documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock archived document + mock_archived_doc = Mock(spec=Document) + mock_archived_doc.id = "doc-3" + mock_archived_doc.name = "archived_document.pdf" + mock_archived_doc.enabled = True + mock_archived_doc.archived = True + mock_archived_doc.indexing_status = "completed" + mock_archived_doc.completed_at = datetime.datetime.now() + + # Mock document that is already archived + mock_get_doc.return_value = mock_archived_doc # Already archived + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to archive already archived document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-3"], action="archive", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_mixed_document_states_and_actions( + self, mock_datetime, mock_get_doc, mock_remove_task, mock_add_task, mock_db + ): + """ + Test batch operations on documents with mixed states and various scenarios. + + Verifies that: + 1. Each document is processed according to its current state + 2. 
Some documents may be skipped while others are processed + 3. Different async tasks are triggered based on document states + 4. Method handles mixed scenarios gracefully + 5. Database commits occur only for documents that were actually modified + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock documents with different states + mock_disabled_doc = Mock(spec=Document) + mock_disabled_doc.id = "doc-1" + mock_disabled_doc.name = "disabled_document.pdf" + mock_disabled_doc.enabled = False + mock_disabled_doc.archived = False + mock_disabled_doc.indexing_status = "completed" + mock_disabled_doc.completed_at = datetime.datetime.now() + + mock_enabled_doc = Mock(spec=Document) + mock_enabled_doc.id = "doc-2" + mock_enabled_doc.name = "enabled_document.pdf" + mock_enabled_doc.enabled = True + mock_enabled_doc.archived = False + mock_enabled_doc.indexing_status = "completed" + mock_enabled_doc.completed_at = datetime.datetime.now() + + mock_archived_doc = Mock(spec=Document) + mock_archived_doc.id = "doc-3" + mock_archived_doc.name = "archived_document.pdf" + mock_archived_doc.enabled = True + mock_archived_doc.archived = True + mock_archived_doc.indexing_status = "completed" + mock_archived_doc.completed_at = datetime.datetime.now() + + # Set up mixed document states + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mix of different document states + documents = [ + mock_disabled_doc, # Will be enabled + mock_enabled_doc, # Already enabled, will be skipped + mock_archived_doc, # Archived but enabled, will be skipped for enable action + ] + + mock_get_doc.side_effect = documents + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform enable operation on mixed state documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=mock_user + ) + + # Verify only the disabled document was processed + # (enabled and archived documents should be skipped for enable action) + + # Only one add should occur (for the disabled document that was enabled) + mock_db.add.assert_called_once() + # Only one commit should occur + mock_db.commit.assert_called_once() + + # Only one Redis setex should occur (for the document that was enabled) + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Only one async task should be triggered (for the document that was enabled) + mock_add_task.delay.assert_called_once_with("doc-1") + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_archive_disabled_document_no_index_removal( + self, mock_datetime, mock_get_doc, mock_remove_task, mock_db + ): + """ + Test archiving disabled documents (should not trigger index removal). + + Verifies that: + 1. Disabled documents can be archived + 2. Archive metadata is set correctly + 3. No index removal task is triggered (because document is disabled) + 4. No Redis cache key is set (because document is disabled) + 5. 
Database commit still occurs + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up disabled, unarchived document + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + disabled_unarchived_doc = Mock(spec=Document) + # Manually set attributes to ensure they can be modified + disabled_unarchived_doc.id = "doc-1" + disabled_unarchived_doc.name = "disabled_document.pdf" + disabled_unarchived_doc.enabled = False # Disabled + disabled_unarchived_doc.archived = False # Not archived + + mock_get_doc.return_value = disabled_unarchived_doc + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Archive the disabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="archive", user=mock_user + ) + + # Verify document was archived + assert disabled_unarchived_doc.archived == True + assert disabled_unarchived_doc.archived_at == current_time.replace(tzinfo=None) + assert disabled_unarchived_doc.archived_by == mock_user.id + + # Verify no Redis cache was set (document is disabled) + redis_mock.setex.assert_not_called() + + # Verify no index removal task was triggered (document is disabled) + mock_remove_task.delay.assert_not_called() + + # Verify database add still occurred + mock_db.add.assert_called_once() + # Verify database commit still occurred + mock_db.commit.assert_called_once() + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_invalid_action_error(self, mock_get_doc): + """ + Test that ValueError is raised when an invalid action is provided. + + Verifies that: + 1. Invalid actions are rejected with ValueError + 2. Error message includes the invalid action name + 3. No document processing occurs with invalid actions + 4. Method fails fast on invalid input + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock document + mock_doc = Mock(spec=Document) + mock_doc.id = "doc-1" + mock_doc.name = "test_document.pdf" + mock_doc.enabled = True + mock_doc.archived = False + + mock_get_doc.return_value = mock_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Test with invalid action + invalid_action = "invalid_action" + with pytest.raises(ValueError) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action=invalid_action, user=mock_user + ) + + # Verify error message contains the invalid action + assert invalid_action in str(exc_info.value) + assert "Invalid action" in str(exc_info.value) + + # Verify no Redis operations occurred + redis_mock.setex.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_disable_already_disabled_document_skipped( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test disabling documents that are already disabled. + + Verifies that: + 1. 
Already disabled documents are skipped (no unnecessary operations) + 2. No database commits occur for already disabled documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock disabled document + mock_disabled_doc = Mock(spec=Document) + mock_disabled_doc.id = "doc-1" + mock_disabled_doc.name = "disabled_document.pdf" + mock_disabled_doc.enabled = False # Already disabled + mock_disabled_doc.archived = False + mock_disabled_doc.indexing_status = "completed" + mock_disabled_doc.completed_at = datetime.datetime.now() + + # Mock document that is already disabled + mock_get_doc.return_value = mock_disabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to disable already disabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="disable", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # Verify no async tasks were triggered (document was skipped) + mock_add_task.delay.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_unarchive_already_unarchived_document_skipped( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test unarchiving documents that are already unarchived. + + Verifies that: + 1. Already unarchived documents are skipped (no unnecessary operations) + 2. No database commits occur for already unarchived documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. 
Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock unarchived document + mock_unarchived_doc = Mock(spec=Document) + mock_unarchived_doc.id = "doc-1" + mock_unarchived_doc.name = "unarchived_document.pdf" + mock_unarchived_doc.enabled = True + mock_unarchived_doc.archived = False # Already unarchived + mock_unarchived_doc.indexing_status = "completed" + mock_unarchived_doc.completed_at = datetime.datetime.now() + + # Mock document that is already unarchived + mock_get_doc.return_value = mock_unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to unarchive already unarchived document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="un_archive", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # Verify no async tasks were triggered (document was skipped) + mock_add_task.delay.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_unarchive_disabled_document_no_index_addition( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test unarchiving disabled documents (should not trigger index addition). + + Verifies that: + 1. Disabled documents can be unarchived + 2. Unarchive metadata is cleared correctly + 3. No index addition task is triggered (because document is disabled) + 4. No Redis cache key is set (because document is disabled) + 5. 
Database commit still occurs + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock archived but disabled document + mock_archived_disabled_doc = Mock(spec=Document) + mock_archived_disabled_doc.id = "doc-1" + mock_archived_disabled_doc.name = "archived_disabled_document.pdf" + mock_archived_disabled_doc.enabled = False # Disabled + mock_archived_disabled_doc.archived = True # Archived + mock_archived_disabled_doc.indexing_status = "completed" + mock_archived_disabled_doc.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = mock_archived_disabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Unarchive the disabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="un_archive", user=mock_user + ) + + # Verify document was unarchived + assert mock_archived_disabled_doc.archived == False + assert mock_archived_disabled_doc.archived_at is None + assert mock_archived_disabled_doc.archived_by is None + assert mock_archived_disabled_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify no Redis cache was set (document is disabled) + redis_mock.setex.assert_not_called() + + # Verify no index addition task was triggered (document is disabled) + mock_add_task.delay.assert_not_called() + + # Verify database add still occurred + mock_db.add.assert_called_once() + # Verify database commit still occurred + mock_db.commit.assert_called_once() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_async_task_error_handling(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test handling of async task errors during batch operations. + + Verifies that: + 1. Async task errors are properly handled + 2. Database operations complete successfully + 3. Redis cache operations complete successfully + 4. 
Method continues processing despite async task errors + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock disabled document + mock_disabled_doc = Mock(spec=Document) + mock_disabled_doc.id = "doc-1" + mock_disabled_doc.name = "disabled_document.pdf" + mock_disabled_doc.enabled = False + mock_disabled_doc.archived = False + mock_disabled_doc.indexing_status = "completed" + mock_disabled_doc.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = mock_disabled_doc + + # Mock async task to raise an exception + mock_add_task.delay.side_effect = Exception("Celery task error") + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Verify that async task error is propagated + with pytest.raises(Exception) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user + ) + + # Verify error message + assert "Celery task error" in str(exc_info.value) + + # Verify database operations completed successfully + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + # Verify Redis cache was set successfully + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify document was updated + assert mock_disabled_doc.enabled == True + assert mock_disabled_doc.disabled_at is None + assert mock_disabled_doc.disabled_by is None + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_large_document_list_performance(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test batch operations with a large number of documents. + + Verifies that: + 1. Method can handle large document lists efficiently + 2. All documents are processed correctly + 3. Database commits occur for each document + 4. Redis cache operations occur for each document + 5. Async tasks are triggered for each document + 6. 
Performance remains consistent with large inputs + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create large list of document IDs + document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents + + # Create mock documents + mock_documents = [] + for i in range(1, 101): + mock_doc = Mock(spec=Document) + mock_doc.id = f"doc-{i}" + mock_doc.name = f"document_{i}.pdf" + mock_doc.enabled = False # All disabled, will be enabled + mock_doc.archived = False + mock_doc.indexing_status = "completed" + mock_doc.completed_at = datetime.datetime.now() + mock_documents.append(mock_doc) + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.side_effect = mock_documents + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform batch enable operation + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=document_ids, action="enable", user=mock_user + ) + + # Verify all documents were processed + assert mock_get_doc.call_count == 100 + + # Verify all documents were updated + for mock_doc in mock_documents: + assert mock_doc.enabled == True + assert mock_doc.disabled_at is None + assert mock_doc.disabled_by is None + assert mock_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify database commits, one add for one document + assert mock_db.add.call_count == 100 + # Verify database commits, one commit for the batch operation + assert mock_db.commit.call_count == 1 + + # Verify Redis cache operations occurred for each document + assert redis_mock.setex.call_count == 100 + + # Verify async tasks were triggered for each document + assert mock_add_task.delay.call_count == 100 + + # Verify correct Redis cache keys were set + expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] + redis_mock.setex.assert_has_calls(expected_redis_calls) + + # Verify correct async task calls + expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] + mock_add_task.delay.assert_has_calls(expected_task_calls) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_mixed_document_states_complex_scenario( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test complex batch operations with documents in various states. + + Verifies that: + 1. Each document is processed according to its current state + 2. Some documents are skipped while others are processed + 3. Different actions trigger different async tasks + 4. Database commits occur only for modified documents + 5. Redis cache operations occur only for relevant documents + 6. 
Method handles complex mixed scenarios correctly
+        """
+        # Create mock dataset
+        mock_dataset = Mock(spec=Dataset)
+        mock_dataset.id = "dataset-123"
+        mock_dataset.tenant_id = "tenant-456"
+
+        # Create mock user
+        mock_user = Mock()
+        mock_user.id = "user-789"
+
+        # Create documents in various states
+        current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
+        mock_datetime.datetime.now.return_value = current_time
+        mock_datetime.UTC = datetime.UTC
+
+        # Document 1: Disabled, will be enabled
+        doc1 = Mock(spec=Document)
+        doc1.id = "doc-1"
+        doc1.name = "disabled_doc.pdf"
+        doc1.enabled = False
+        doc1.archived = False
+        doc1.indexing_status = "completed"
+        doc1.completed_at = datetime.datetime.now()
+
+        # Document 2: Already enabled, will be skipped
+        doc2 = Mock(spec=Document)
+        doc2.id = "doc-2"
+        doc2.name = "enabled_doc.pdf"
+        doc2.enabled = True
+        doc2.archived = False
+        doc2.indexing_status = "completed"
+        doc2.completed_at = datetime.datetime.now()
+
+        # Document 3: Already enabled and completed, will be skipped by the enable action
+        doc3 = Mock(spec=Document)
+        doc3.id = "doc-3"
+        doc3.name = "enabled_completed_doc.pdf"
+        doc3.enabled = True
+        doc3.archived = False
+        doc3.indexing_status = "completed"
+        doc3.completed_at = datetime.datetime.now()
+
+        # Document 4: Enabled and unarchived, will be skipped by the enable action
+        doc4 = Mock(spec=Document)
+        doc4.id = "doc-4"
+        doc4.name = "unarchived_doc.pdf"
+        doc4.enabled = True
+        doc4.archived = False
+        doc4.indexing_status = "completed"
+        doc4.completed_at = datetime.datetime.now()
+
+        # Document 5: Enabled but archived, will be skipped by the enable action
+        doc5 = Mock(spec=Document)
+        doc5.id = "doc-5"
+        doc5.name = "archived_doc.pdf"
+        doc5.enabled = True
+        doc5.archived = True
+        doc5.indexing_status = "completed"
+        doc5.completed_at = datetime.datetime.now()
+
+        # Document 6: Non-existent, will be skipped
+        doc6 = None
+
+        mock_get_doc.side_effect = [doc1, doc2, doc3, doc4, doc5, doc6]
+
+        # Reset module-level Redis mock
+        redis_mock.reset_mock()
+        redis_mock.get.return_value = None
+
+        # Perform the enable operation on documents in mixed states
+        DocumentService.batch_update_document_status(
+            dataset=mock_dataset,
+            document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"],
+            action="enable",  # Only doc1 is actually modified; the other documents are skipped or missing
+            user=mock_user,
+        )
+
+        # Verify document 1 was enabled
+        assert doc1.enabled == True
+        assert doc1.disabled_at is None
+        assert doc1.disabled_by is None
+
+        # Verify document 2 was skipped (already enabled)
+        assert doc2.enabled == True  # No change
+
+        # Verify document 3 was skipped (already enabled)
+        assert doc3.enabled == True
+
+        # Verify document 4 was skipped (already enabled)
+        assert doc4.enabled == True  # No change
+
+        # Verify document 5 was skipped (already enabled)
+        assert doc5.enabled == True  # No change
+
+        # Verify database commits occurred for processed documents
+        # Only doc1 should be added (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist)
+        assert mock_db.add.call_count == 1
+        assert mock_db.commit.call_count == 1
+
+        # Verify Redis cache operations occurred for processed documents
+        # Only doc1 should have Redis operations
+        assert redis_mock.setex.call_count == 1
+
+        # Verify async tasks were triggered for processed documents
+        # Only doc1 should trigger tasks
+        assert mock_add_task.delay.call_count == 1
+
+        # Verify correct Redis cache keys were set
+        expected_redis_calls = [call("document_doc-1_indexing", 600, 1)]
+        redis_mock.setex.assert_has_calls(expected_redis_calls)
+
+        # Verify correct async task calls
+        expected_task_calls = [call("doc-1")]
+        mock_add_task.delay.assert_has_calls(expected_task_calls)
diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
new file mode 100644
index 0000000000..15e1b7569f
--- /dev/null
+++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
@@ -0,0 +1,826 @@
+import datetime
+
+# Mock redis_client before importing dataset_service
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.entities.model_entities import ModelType
+from models.dataset import Dataset, ExternalKnowledgeBindings
+from services.dataset_service import DatasetService
+from services.errors.account import NoPermissionError
+from tests.unit_tests.conftest import redis_mock
+
+
+class TestDatasetServiceUpdateDataset:
+    """
+    Comprehensive unit tests for DatasetService.update_dataset method.
+
+    This test suite covers all supported scenarios including:
+    - External dataset updates
+    - Internal dataset updates with different indexing techniques
+    - Embedding model updates
+    - Permission checks
+    - Error conditions and edge cases
+    """
+
+    @patch("extensions.ext_database.db.session")
+    @patch("services.dataset_service.DatasetService.get_dataset")
+    @patch("services.dataset_service.DatasetService.check_dataset_permission")
+    @patch("services.dataset_service.datetime")
+    def test_update_external_dataset_success(self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db):
+        """
+        Test successful update of external dataset.
+
+        Verifies that:
+        1. External dataset attributes are updated correctly
+        2. External knowledge binding is updated when values change
+        3. Database changes are committed
+        4.
Permission check is performed + """ + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "external" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.retrieval_model = "old_model" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock external knowledge binding + mock_binding = Mock(spec=ExternalKnowledgeBindings) + mock_binding.external_knowledge_id = "old_knowledge_id" + mock_binding.external_knowledge_api_id = "old_api_id" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query + with patch("services.dataset_service.Session") as mock_session: + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = mock_binding + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "external_retrieval_model": "new_model", + "permission": "only_me", + "external_knowledge_id": "new_knowledge_id", + "external_knowledge_api_id": "new_api_id", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify dataset attributes were updated + assert mock_dataset.name == "new_name" + assert mock_dataset.description == "new_description" + assert mock_dataset.retrieval_model == "new_model" + + # Verify external knowledge binding was updated + assert mock_binding.external_knowledge_id == "new_knowledge_id" + assert mock_binding.external_knowledge_api_id == "new_api_id" + + # Verify database operations + mock_db.add.assert_any_call(mock_dataset) + mock_db.add.assert_any_call(mock_binding) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_knowledge_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge id is missing. 
+ """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_id + update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_api_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge api id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_api_id + update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge api id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.Session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_binding_not_found_error( + self, mock_check_permission, mock_get_dataset, mock_session, mock_db + ): + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query returning None + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = None + + # Test data + update_data = { + "name": "new_name", + "external_knowledge_id": "knowledge_id", + "external_knowledge_api_id": "api_id", + } + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge binding not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_basic_success( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test successful update of internal dataset with basic fields. + + Verifies that: + 1. Basic dataset attributes are updated correctly + 2. Filtered data excludes None values except description + 3. Timestamp fields are updated + 4. 
Database changes are committed + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.retrieval_model = "old_model" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify database update was called with correct filtered data + expected_filtered_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_economy( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset indexing technique to economy. + + Verifies that: + 1. Embedding model fields are cleared when switching to economy + 2. 
Vector index task is triggered with 'remove' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with embedding model fields cleared + expected_filtered_data = { + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "remove") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_not_found_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when dataset is not found. + """ + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval returning None + mock_get_dataset.return_value = None + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "Dataset not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_permission_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when user doesn't have permission. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock permission check to raise error + mock_check_permission.side_effect = NoPermissionError("No permission") + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect NoPermissionError + with pytest.raises(NoPermissionError): + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_keep_existing_embedding_model( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing embedding model. + + Verifies that: + 1. 
Existing embedding model settings are preserved when not provided in update + 2. No vector index task is triggered + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without embedding model fields + update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with existing embedding model preserved + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") + @patch("services.dataset_service.ModelManager") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_high_quality( + self, + mock_datetime, + mock_check_permission, + mock_get_dataset, + mock_task, + mock_model_manager, + mock_collection_binding, + mock_db, + ): + """ + Test updating internal dataset indexing technique to high_quality. + + Verifies that: + 1. Embedding model is validated and set + 2. Collection binding is retrieved + 3. 
Vector index task is triggered with 'add' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-ada-002" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-456" + + # Mock model manager + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + mock_collection_binding.return_value = mock_collection_binding_instance + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "retrieval_model": "new_model", + } + + # Call the method with current_user mock + with patch("services.dataset_service.current_user", mock_current_user): + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-ada-002", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-ada-002", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-456", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "add") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_internal_dataset_embedding_model_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when embedding model is not available. 
+ """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager to raise error + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.side_effect = Exception("No Embedding Model available") + mock_model_manager.return_value = mock_model_manager_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "invalid_provider", + "embedding_model": "invalid_model", + "retrieval_model": "new_model", + } + + # Call the method and expect ValueError + with pytest.raises(Exception) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "No Embedding Model available".lower() in str(context.value).lower() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_filter_none_values( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test that None values are filtered out except for description field. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with None values + update_data = { + "name": "new_name", + "description": None, # Should be included + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": None, # Should be filtered out + "embedding_model": None, # Should be filtered out + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with filtered data + expected_filtered_data = { + "name": "new_name", + "description": None, # Description should be included even if None + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]["updated_at"], + } + + actual_call_args = mock_db.query.return_value.filter_by.return_value.update.call_args[0][0] + # Remove timestamp for comparison as it's dynamic + del actual_call_args["updated_at"] + del expected_filtered_data["updated_at"] + + del actual_call_args["collection_binding_id"] + del actual_call_args["embedding_model"] + del actual_call_args["embedding_model_provider"] + + assert actual_call_args == expected_filtered_data + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + 
@patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_embedding_model_update( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset with new embedding model. + + Verifies that: + 1. Embedding model is updated when different from current + 2. Vector index task is triggered with 'update' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-3-small" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-789" + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager + with patch("services.dataset_service.ModelManager") as mock_model_manager: + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + with ( + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_collection_binding, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_collection_binding.return_value = mock_collection_binding_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-3-small") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-3-small", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-789", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "update") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + 
@patch("services.dataset_service.DatasetService.get_dataset")
+    @patch("services.dataset_service.DatasetService.check_dataset_permission")
+    @patch("services.dataset_service.datetime")
+    def test_update_internal_dataset_no_indexing_technique_change(
+        self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
+    ):
+        """
+        Test updating internal dataset without changing indexing technique.
+
+        Verifies that:
+        1. No vector index task is triggered when indexing technique doesn't change
+        2. Database update is performed normally
+        """
+        # Create mock dataset
+        mock_dataset = Mock(spec=Dataset)
+        mock_dataset.id = "dataset-123"
+        mock_dataset.provider = "vendor"
+        mock_dataset.indexing_technique = "high_quality"
+        mock_dataset.embedding_model_provider = "openai"
+        mock_dataset.embedding_model = "text-embedding-ada-002"
+        mock_dataset.collection_binding_id = "binding-123"
+
+        # Create mock user
+        mock_user = Mock()
+        mock_user.id = "user-789"
+
+        # Set up mock return values
+        current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
+        mock_datetime.datetime.now.return_value = current_time
+        mock_datetime.UTC = datetime.UTC
+
+        # Mock dataset retrieval
+        mock_get_dataset.return_value = mock_dataset
+
+        # Test data with same indexing technique
+        update_data = {
+            "name": "new_name",
+            "indexing_technique": "high_quality",  # Same as current
+            "retrieval_model": "new_model",
+        }
+
+        # Call the method
+        result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
+
+        # Verify database update was called with correct data
+        expected_filtered_data = {
+            "name": "new_name",
+            "indexing_technique": "high_quality",
+            "embedding_model_provider": "openai",
+            "embedding_model": "text-embedding-ada-002",
+            "collection_binding_id": "binding-123",
+            "retrieval_model": "new_model",
+            "updated_by": mock_user.id,
+            "updated_at": current_time.replace(tzinfo=None),
+        }
+
+        mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
+        mock_db.commit.assert_called_once()
+
+        # Verify the update was issued exactly once (indexing technique unchanged, so no vector index rebuild)
+        mock_db.query.return_value.filter_by.return_value.update.assert_called_once()
+
+        # Verify return value
+        assert result == mock_dataset
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
new file mode 100644
index 0000000000..8ae69c8d64
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
@@ -0,0 +1,222 @@
+import dataclasses
+import secrets
+from unittest import mock
+from unittest.mock import Mock, patch
+
+import pytest
+from sqlalchemy.orm import Session
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.variables.types import SegmentType
+from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.nodes import NodeType
+from models.enums import DraftVariableType
+from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
+from services.workflow_draft_variable_service import (
+    DraftVariableSaver,
+    VariableResetError,
+    WorkflowDraftVariableService,
+)
+
+
+class TestDraftVariableSaver:
+    def _get_test_app_id(self):
+        suffix = secrets.token_hex(6)
+        return f"test_app_id_{suffix}"
+
+    def test__should_variable_be_visible(self):
+        mock_session = mock.MagicMock(spec=Session)
+        test_app_id = self._get_test_app_id()
+        saver = DraftVariableSaver(
+            session=mock_session,
+            app_id=test_app_id,
node_id="test_node_id", + node_type=NodeType.START, + invoke_from=InvokeFrom.DEBUGGER, + node_execution_id="test_execution_id", + ) + assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False + assert saver._should_variable_be_visible("123", NodeType.START, "output") == True + + def test__normalize_variable_for_start_node(self): + @dataclasses.dataclass(frozen=True) + class TestCase: + name: str + input_node_id: str + input_name: str + expected_node_id: str + expected_name: str + + _NODE_ID = "1747228642872" + cases = [ + TestCase( + name="name with `sys.` prefix should return the system node_id", + input_node_id=_NODE_ID, + input_name="sys.workflow_id", + expected_node_id=SYSTEM_VARIABLE_NODE_ID, + expected_name="workflow_id", + ), + TestCase( + name="name without `sys.` prefix should return the original input node_id", + input_node_id=_NODE_ID, + input_name="start_input", + expected_node_id=_NODE_ID, + expected_name="start_input", + ), + TestCase( + name="dummy_variable should return the original input node_id", + input_node_id=_NODE_ID, + input_name="__dummy__", + expected_node_id=_NODE_ID, + expected_name="__dummy__", + ), + ] + + mock_session = mock.MagicMock(spec=Session) + test_app_id = self._get_test_app_id() + saver = DraftVariableSaver( + session=mock_session, + app_id=test_app_id, + node_id=_NODE_ID, + node_type=NodeType.START, + invoke_from=InvokeFrom.DEBUGGER, + node_execution_id="test_execution_id", + ) + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + node_id, name = saver._normalize_variable_for_start_node(c.input_name) + assert node_id == c.expected_node_id, fail_msg + assert name == c.expected_name, fail_msg + + +class TestWorkflowDraftVariableService: + def _get_test_app_id(self): + suffix = secrets.token_hex(6) + return f"test_app_id_{suffix}" + + def test_reset_conversation_variable(self): + """Test resetting a conversation variable""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + # Mock the _reset_conv_var method + expected_result = Mock(spec=WorkflowDraftVariable) + with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv: + result = service.reset_variable(mock_workflow, mock_variable) + + mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable) + assert result == expected_result + + def test_reset_node_variable_with_no_execution_id(self): + """Test resetting a node variable with no execution ID - should delete variable""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with no execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = None + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Should delete the variable and return None + mock_session.delete.assert_called_once_with(instance=mock_variable) + mock_session.flush.assert_called_once() 
+ assert result is None + + def test_reset_node_variable_with_missing_execution_record(self): + """Test resetting a node variable when execution record doesn't exist""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = "exec-id" + mock_variable.id = "var-id" + mock_variable.name = "test_var" + + # Mock session.scalars to return None (no execution record found) + mock_scalars = Mock() + mock_scalars.first.return_value = None + mock_session.scalars.return_value = mock_scalars + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Should delete the variable and return None + mock_session.delete.assert_called_once_with(instance=mock_variable) + mock_session.flush.assert_called_once() + assert result is None + + def test_reset_node_variable_with_valid_execution_record(self): + """Test resetting a node variable with valid execution record - should restore from execution""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + # Create mock variable with execution ID + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.NODE + mock_variable.node_execution_id = "exec-id" + mock_variable.id = "var-id" + mock_variable.name = "test_var" + mock_variable.node_id = "node-id" + mock_variable.value_type = SegmentType.STRING + + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.process_data_dict = {"test_var": "process_value"} + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock session.scalars to return the execution record + mock_scalars = Mock() + mock_scalars.first.return_value = mock_execution + mock_session.scalars.return_value = mock_scalars + + # Mock workflow methods + mock_node_config = {"type": "test_node"} + mock_workflow.get_node_config_by_id.return_value = mock_node_config + mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM + + result = service._reset_node_var(mock_workflow, mock_variable) + + # Verify variable.set_value was called with the correct value + mock_variable.set_value.assert_called_once() + # Verify last_edited_at was reset + assert mock_variable.last_edited_at is None + # Verify session.flush was called + mock_session.flush.assert_called() + + # Should return the updated variable + assert result == mock_variable + + def test_reset_system_variable_raises_error(self): + """Test that resetting a system variable raises an error""" + mock_session = Mock(spec=Session) + service = WorkflowDraftVariableService(mock_session) + mock_workflow = Mock(spec=Workflow) + mock_workflow.app_id = self._get_test_app_id() + + mock_variable = Mock(spec=WorkflowDraftVariable) + mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test + mock_variable.id = "var-id" + + with pytest.raises(VariableResetError) as exc_info: + service.reset_variable(mock_workflow, mock_variable) + assert "cannot reset system variable" in str(exc_info.value) + assert "variable_id=var-id" in str(exc_info.value) diff --git 
a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py
index f788a9756b..293ac253f5 100644
--- a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py
+++ b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py
@@ -1,3 +1,5 @@
+import json
+
 from werkzeug import Request
 from werkzeug.datastructures import Headers
 from werkzeug.test import EnvironBuilder
@@ -15,6 +17,59 @@ def test_oauth_convert_request_to_raw_data():
     request = Request(builder.get_environ())
     raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
 
-    assert b"GET /test HTTP/1.1" in raw_request_bytes
+    assert b"GET /test? HTTP/1.1" in raw_request_bytes
     assert b"Content-Type: application/json" in raw_request_bytes
     assert b"\r\n\r\n" in raw_request_bytes
+
+
+def test_oauth_convert_request_to_raw_data_with_query_params():
+    oauth_handler = OAuthHandler()
+    builder = EnvironBuilder(
+        method="GET",
+        path="/test",
+        query_string="code=abc123&state=xyz789",
+        headers=Headers({"Content-Type": "application/json"}),
+    )
+    request = Request(builder.get_environ())
+    raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
+
+    assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes
+    assert b"Content-Type: application/json" in raw_request_bytes
+    assert b"\r\n\r\n" in raw_request_bytes
+
+
+def test_oauth_convert_request_to_raw_data_with_post_body():
+    oauth_handler = OAuthHandler()
+    builder = EnvironBuilder(
+        method="POST",
+        path="/test",
+        data="param1=value1&param2=value2",
+        headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}),
+    )
+    request = Request(builder.get_environ())
+    raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
+
+    assert b"POST /test? HTTP/1.1" in raw_request_bytes
+    assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes
+    assert b"\r\n\r\n" in raw_request_bytes
+    assert b"param1=value1&param2=value2" in raw_request_bytes
+
+
+def test_oauth_convert_request_to_raw_data_with_json_body():
+    oauth_handler = OAuthHandler()
+    json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"}
+    builder = EnvironBuilder(
+        method="POST",
+        path="/test",
+        data=json.dumps(json_data),
+        headers=Headers({"Content-Type": "application/json"}),
+    )
+    request = Request(builder.get_environ())
+    raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
+
+    assert b"POST /test?
HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/json" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + assert b'"code": "abc123"' in raw_request_bytes + assert b'"state": "xyz789"' in raw_request_bytes + assert b'"grant_type": "authorization_code"' in raw_request_bytes diff --git a/api/uv.lock b/api/uv.lock index a03929510e..66bfdcef36 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1284,6 +1284,7 @@ dev = [ { name = "coverage" }, { name = "dotenv-linter" }, { name = "faker" }, + { name = "hypothesis" }, { name = "lxml-stubs" }, { name = "mypy" }, { name = "pandas-stubs" }, @@ -1461,6 +1462,7 @@ dev = [ { name = "coverage", specifier = "~=7.2.4" }, { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, + { name = "hypothesis", specifier = ">=6.131.15" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.16.0" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, @@ -2556,6 +2558,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007 }, ] +[[package]] +name = "hypothesis" +version = "6.131.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/6f/1e291f80627f3e043b19a86f9f6b172b910e3575577917d3122a6558410d/hypothesis-6.131.15.tar.gz", hash = "sha256:11849998ae5eecc8c586c6c98e47677fcc02d97475065f62768cfffbcc15ef7a", size = 436596, upload_time = "2025-05-07T23:04:25.127Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/c7/78597bcec48e1585ea9029deb2bf2341516e90dd615a3db498413d68a4cc/hypothesis-6.131.15-py3-none-any.whl", hash = "sha256:e02e67e9f3cfd4cd4a67ccc03bf7431beccc1a084c5e90029799ddd36ce006d7", size = 501128, upload_time = "2025-05-07T23:04:22.045Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -5241,6 +5256,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/37/c3/6eeb6034408dac0fa653d126c9204ade96b819c936e136c5e8a6897eee9c/socksio-1.0.0-py3-none-any.whl", hash = "sha256:95dc1f15f9b34e8d7b16f06d74b8ccf48f609af32ab33c608d08761c5dcbb1f3", size = 12763 }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload_time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload_time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "soupsieve" version = "2.7" diff --git a/docker/.env.example b/docker/.env.example index 5a2a426338..275da8e2e4 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -798,6 +798,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 HTTP_REQUEST_NODE_SSL_VERIFY=True +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + # SSRF Proxy server HTTP URL 
 SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128
 # SSRF Proxy server HTTPS URL
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 5f13060658..335bda89f4 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -355,6 +355,7 @@ x-shared-env: &shared-api-worker-env
   HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
   HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
   HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
+  RESPECT_XFORWARD_HEADERS_ENABLED: ${RESPECT_XFORWARD_HEADERS_ENABLED:-false}
   SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128}
   SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128}
   LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
diff --git a/web/.env.example b/web/.env.example
index 78b4f33e8c..c30064ffed 100644
--- a/web/.env.example
+++ b/web/.env.example
@@ -56,3 +56,5 @@ NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true
 NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
 NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true
 
+# The maximum depth of the workflow node tree
+NEXT_PUBLIC_MAX_TREE_DEPTH=50
diff --git a/web/app/(commonLayout)/datasets/Datasets.tsx b/web/app/(commonLayout)/datasets/Datasets.tsx
index 28461e8617..2d4848e92e 100644
--- a/web/app/(commonLayout)/datasets/Datasets.tsx
+++ b/web/app/(commonLayout)/datasets/Datasets.tsx
@@ -81,7 +81,7 @@ const Datasets = ({
       currentContainer?.removeEventListener('scroll', onScroll)
       onScroll.cancel()
     }
-  }, [onScroll])
+  }, [containerRef, onScroll])
 
   return (