diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 93ecac48f2..022f71bfb4 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -1,6 +1,6 @@ #!/bin/bash -npm add -g pnpm@10.11.1 +npm add -g pnpm@10.13.1 cd web && pnpm install pipx install uv @@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc source /home/vscode/.bashrc + diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a9580a3ba3..d684fe9144 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -8,13 +8,15 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :) required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true @@ -42,20 +44,22 @@ body: attributes: label: Steps to reproduce description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks. - placeholder: Having detailed steps helps us reproduce the bug. + placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them. validations: required: true - type: textarea attributes: label: ✔️ Expected Behavior - placeholder: What were you expecting? + description: Describe what you expected to happen. + placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here. validations: - required: false + required: true - type: textarea attributes: label: ❌ Actual Behavior - placeholder: What happened instead? + description: Describe what actually happened. + placeholder: What happened instead? Please do not copy and paste the steps to reproduce here. validations: required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 6877c382c4..c1666d24cf 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,11 @@ blank_issues_enabled: false contact_links: + - name: "\U0001F4A1 Model Providers & Plugins" + url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" + about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details. 
+ - name: "\U0001F4AC Documentation Issues" + url: "https://github.com/langgenius/dify-docs/issues/new" + about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue. - name: "\U0001F4E7 Discussions" url: https://github.com/langgenius/dify/discussions/categories/general - about: General discussions and request help from the community + about: General discussions and seek help from the community diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml deleted file mode 100644 index 8fdbc0fb9a..0000000000 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: "📚 Documentation Issue" -description: Report issues in our documentation -labels: - - documentation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: textarea - attributes: - label: Provide a description of requested docs changes - placeholder: Briefly describe which document needs to be corrected and why. - validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index b1952c63a9..bd293e2442 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -8,11 +8,11 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml deleted file mode 100644 index f9c2dfb7d2..0000000000 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: "🌐 Localization/Translation issue" -description: Report incorrect translations. [please use English :)] -labels: - - translation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. 
- required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: input - attributes: - label: Dify version - description: Hover over system tray icon or look at Settings - validations: - required: true - - type: input - attributes: - label: Utility with translation issue - placeholder: Some area - description: Please input here the utility with the translation issue - validations: - required: true - - type: input - attributes: - label: 🌐 Language affected - placeholder: "German" - validations: - required: true - - type: textarea - attributes: - label: ❌ Actual phrase(s) - placeholder: What is there? Please include a screenshot as that is extremely helpful. - validations: - required: true - - type: textarea - attributes: - label: ✔️ Expected phrase(s) - placeholder: What was expected? - validations: - required: true - - type: textarea - attributes: - label: ℹ Why is the current translation wrong - placeholder: Why do you feel this is incorrect? - validations: - required: true diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml new file mode 100644 index 0000000000..5e290c5d02 --- /dev/null +++ b/.github/workflows/autofix.yml @@ -0,0 +1,27 @@ +name: autofix.ci +on: + workflow_call: + pull_request: + push: + branches: [ "main" ] +permissions: + contents: read + +jobs: + autofix: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + # Use uv to ensure we have the same ruff version in CI and locally. + - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f + - run: | + cd api + uv sync --dev + # Fix lint errors + uv run ruff check --fix-only . + # Format code + uv run ruff format . 
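Editor's aside, not part of this patch: the new autofix.ci workflow above runs ruff through uv in CI. A minimal local companion could drive the same commands so contributors see identical fixes before pushing. It assumes `uv` is on PATH and that it is run from the repository root with `api/` as the project directory, exactly as the workflow's run step does; the script itself is hypothetical.

```python
"""Illustrative local companion to the autofix.ci workflow above (not part of
this diff): apply the same ruff fix and format commands against api/.
Assumes `uv` is installed and the script is invoked from the repository root."""
import subprocess
from pathlib import Path

API_DIR = Path("api")


def run(*cmd: str) -> None:
    # check=True keeps local behaviour aligned with CI: any failure aborts.
    subprocess.run(list(cmd), cwd=API_DIR, check=True)


if __name__ == "__main__":
    run("uv", "sync", "--dev")
    run("uv", "run", "ruff", "check", "--fix-only", ".")
    run("uv", "run", "ruff", "format", ".")
```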
+ + - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 + diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index cc735ae67c..b933560a5e 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,6 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" + - "build/**" tags: - "*" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index b06ab9653e..a283f8d5ca 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -28,7 +28,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: | api/** @@ -75,7 +75,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: web/** @@ -113,7 +113,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: | docker/generate_docker_compose @@ -144,7 +144,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: | **.sh @@ -152,13 +152,15 @@ jobs: **.yml **Dockerfile dev/** + .editorconfig - name: Super-linter - uses: super-linter/super-linter/slim@v7 + uses: super-linter/super-linter/slim@v8 if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning - DEFAULT_BRANCH: main + DEFAULT_BRANCH: origin/main + EDITORCONFIG_FILE_NAME: editorconfig-checker.json FILTER_REGEX_INCLUDE: pnpm-lock.yaml GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} IGNORE_GENERATED_FILES: true @@ -168,16 +170,6 @@ jobs: # FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck # VALIDATE_GITHUB_ACTIONS: true VALIDATE_DOCKERFILE_HADOLINT: true + VALIDATE_EDITORCONFIG: true VALIDATE_XML: true VALIDATE_YAML: true - - - name: EditorConfig checks - uses: super-linter/super-linter/slim@v7 - env: - DEFAULT_BRANCH: main - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - IGNORE_GENERATED_FILES: true - IGNORE_GITIGNORED_FILES: true - # EditorConfig validation - VALIDATE_EDITORCONFIG: true - EDITORCONFIG_FILE_NAME: editorconfig-checker.json diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 37cfdc5c1e..c3f8fdbaf6 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -27,7 +27,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: web/** diff --git a/README.md b/README.md index 1dc7e2dd98..2909e0e6cf 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ README in বাংলা

-Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. +Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. ## Quick start @@ -65,7 +65,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
-The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: +The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: ```bash cd dify @@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Using Terraform for Deployment @@ -261,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media ## Security disclosure -To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. +To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with detailed answer. ## License -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. +This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions. diff --git a/README_AR.md b/README_AR.md index d93bca8646..e959ca0f78 100644 --- a/README_AR.md +++ b/README_AR.md @@ -188,6 +188,7 @@ docker compose up -d - [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts) - [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### استخدام Terraform للتوزيع diff --git a/README_BN.md b/README_BN.md index 3efee3684d..29d7374ea5 100644 --- a/README_BN.md +++ b/README_BN.md @@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 নতুন! 
YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + #### টেরাফর্ম ব্যবহার করে ডিপ্লয় diff --git a/README_CN.md b/README_CN.md index 21e27429ec..486a368c09 100644 --- a/README_CN.md +++ b/README_CN.md @@ -194,9 +194,9 @@ docker compose up -d 如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 -#### 使用 Helm Chart 部署 +#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 -使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。 +使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) @@ -204,6 +204,10 @@ docker compose up -d - [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + + + #### 使用 Terraform 部署 使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 diff --git a/README_DE.md b/README_DE.md index 20c313035e..fce52c34c2 100644 --- a/README_DE.md +++ b/README_DE.md @@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform für die Bereitstellung verwenden diff --git a/README_ES.md b/README_ES.md index e4b7df6686..6fd6dfcee8 100644 --- a/README_ES.md +++ b/README_ES.md @@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop - [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts) - [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uso de Terraform para el despliegue diff --git a/README_FR.md b/README_FR.md index 8fd17fb7c3..b2209fb495 100644 --- a/README_FR.md +++ b/README_FR.md @@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau - [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts) - [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOUVEAU ! 
Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Utilisation de Terraform pour le déploiement diff --git a/README_JA.md b/README_JA.md index a3ee81e1f2..c658225f90 100644 --- a/README_JA.md +++ b/README_JA.md @@ -202,6 +202,7 @@ docker compose up -d - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraformを使用したデプロイ diff --git a/README_KL.md b/README_KL.md index 3e5ab1a74f..bfafcc7407 100644 --- a/README_KL.md +++ b/README_KL.md @@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform atorlugu pilersitsineq diff --git a/README_KR.md b/README_KR.md index 3c504900e1..282117e776 100644 --- a/README_KR.md +++ b/README_KR.md @@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform을 사용한 배포 diff --git a/README_PT.md b/README_PT.md index fb5f3662ae..576f6b48f7 100644 --- a/README_PT.md +++ b/README_PT.md @@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts] - [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts) - [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Usando o Terraform para Implantação diff --git a/README_SI.md b/README_SI.md index 647069a220..7ded001d86 100644 --- a/README_SI.md +++ b/README_SI.md @@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases. - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uporaba Terraform za uvajanje diff --git a/README_TR.md b/README_TR.md index f52335646a..6e94e54fa0 100644 --- a/README_TR.md +++ b/README_TR.md @@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' - [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes) - [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s) +- [🚀 YENİ! 
YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Dağıtım için Terraform Kullanımı diff --git a/README_TW.md b/README_TW.md index 71082ff893..6e3e22b5c1 100644 --- a/README_TW.md +++ b/README_TW.md @@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 -如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。 +如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 - [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify) - [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes) - [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) ### 使用 Terraform 進行部署 diff --git a/README_VI.md b/README_VI.md index 58d8434fff..51314e6de5 100644 --- a/README_VI.md +++ b/README_VI.md @@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có - [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Sử dụng Terraform để Triển khai diff --git a/api/.env.example b/api/.env.example index baa9c382c8..80b1c12cd8 100644 --- a/api/.env.example +++ b/api/.env.example @@ -5,17 +5,22 @@ SECRET_KEY= # Console API base URL -CONSOLE_API_URL=http://127.0.0.1:5001 -CONSOLE_WEB_URL=http://127.0.0.1:3000 +CONSOLE_API_URL=http://localhost:5001 +CONSOLE_WEB_URL=http://localhost:3000 # Service API base URL -SERVICE_API_URL=http://127.0.0.1:5001 +SERVICE_API_URL=http://localhost:5001 # Web APP base URL -APP_WEB_URL=http://127.0.0.1:3000 +APP_WEB_URL=http://localhost:3000 # Files URL -FILES_URL=http://127.0.0.1:5001 +FILES_URL=http://localhost:5001 + +# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. +# Set this to the internal Docker service URL for proper plugin file access. 
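Editor's aside on the new setting described in the comment above: INTERNAL_FILES_URL gives the plugin daemon an address reachable inside the Docker network, while FILES_URL stays the externally visible one; the FileAccessConfig added later in this diff states that it falls back to FILES_URL when unset. A minimal sketch of that fallback selection follows; the `Settings` class and `files_base_url` helper are stand-ins for illustration, not Dify's actual config objects.

```python
"""Illustrative sketch of the INTERNAL_FILES_URL fallback: when the internal
URL is left empty, callers fall back to FILES_URL. Names here are stand-ins."""
from dataclasses import dataclass


@dataclass
class Settings:
    FILES_URL: str = "http://localhost:5001"
    INTERNAL_FILES_URL: str = ""  # e.g. "http://api:5001" inside the Docker network


def files_base_url(settings: Settings, internal: bool) -> str:
    # Plugin-daemon traffic prefers the internal URL, but only if it is set.
    if internal and settings.INTERNAL_FILES_URL:
        return settings.INTERNAL_FILES_URL
    return settings.FILES_URL


print(files_base_url(Settings(INTERNAL_FILES_URL="http://api:5001"), internal=True))
print(files_base_url(Settings(), internal=True))  # falls back to FILES_URL
```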
+# Example: INTERNAL_FILES_URL=http://api:5001 +INTERNAL_FILES_URL=http://127.0.0.1:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 @@ -49,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD= # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 - +CELERY_BACKEND=redis # PostgreSQL database configuration DB_USERNAME=postgres DB_PASSWORD=difyai123456 @@ -133,12 +138,14 @@ SUPABASE_API_KEY=your-access-key SUPABASE_URL=your-server-url # CORS configuration -WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* +WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,* +CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* # Vector database configuration -# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. VECTOR_STORE=weaviate +# Prefix used to create collection name in vector database +VECTOR_INDEX_NAME_PREFIX=Vector_index # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 @@ -444,6 +451,19 @@ MAX_VARIABLE_SIZE=204800 # hybrid: Save new data to object storage, read from both object storage and RDBMS WORKFLOW_NODE_EXECUTION_STORAGE=rdbms +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 @@ -451,6 +471,16 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration CELERY_BEAT_SCHEDULER_TIME=1 +# Celery schedule tasks configuration +ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false +ENABLE_CLEAN_UNUSED_DATASETS_TASK=false +ENABLE_CREATE_TIDB_SERVERLESS_TASK=false +ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false +ENABLE_CLEAN_MESSAGES=false +ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false +ENABLE_DATASETS_QUEUE_MONITOR=false +ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true + # Position configuration POSITION_TOOL_PINS= POSITION_TOOL_INCLUDES= @@ -477,6 +507,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 +OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 CREATE_TIDB_SERVICE_JOB_ENABLED=false @@ -487,6 +519,8 @@ LOGIN_LOCKOUT_DURATION=86400 # Enable OpenTelemetry ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= 
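Editor's aside, not part of this patch: the four new `*_REPOSITORY` variables above (mirrored by the RepositoryConfig fields added to api/configs/feature/__init__.py later in this diff) are dotted module paths pointing at repository classes. Dotted paths like these are typically resolved with importlib; the sketch below shows that general pattern only and is not Dify's actual repository factory. The demo call uses a stdlib class so it runs anywhere; the Dify default values from above would work the same way inside the api package.

```python
"""Illustrative resolution of a dotted-path setting such as
CORE_WORKFLOW_EXECUTION_REPOSITORY. This mirrors the usual importlib pattern;
it is not Dify's actual factory code."""
import importlib
from typing import Any


def import_from_path(dotted_path: str) -> Any:
    # "package.module.ClassName" -> the ClassName attribute of package.module
    module_path, _, attr_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Stdlib stand-in so the snippet runs outside the Dify api package:
print(import_from_path("collections.OrderedDict").__name__)
# Inside the api package this would be resolved from configuration, e.g.
# import_from_path("core.repositories.sqlalchemy_workflow_execution_repository"
#                  ".SQLAlchemyWorkflowExecutionRepository")
```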
OTEL_EXPORTER_OTLP_PROTOCOL= diff --git a/api/Dockerfile b/api/Dockerfile index 7e4997507f..8c7a1717b9 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -47,6 +47,8 @@ RUN \ curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ # For Security expat libldap-2.5-0 perl libsqlite3-0 zlib1g \ + # install fonts to support the use of tools like pypdfium2 + fonts-noto-cjk \ # install a package to improve the accuracy of guessing mime type and file extension media-types \ # install libmagic to support the use of python-magic guess MIMETYPE diff --git a/api/README.md b/api/README.md index 9308d5dc44..6ab923070e 100644 --- a/api/README.md +++ b/api/README.md @@ -74,7 +74,12 @@ 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash - uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion + uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin + ``` + + Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: + ```bash + uv run celery -A app.celery beat ``` ## Testing diff --git a/api/commands.py b/api/commands.py index 86769847c1..c2e62ec261 100644 --- a/api/commands.py +++ b/api/commands.py @@ -2,19 +2,22 @@ import base64 import json import logging import secrets -from typing import Optional +from typing import Any, Optional import click from flask import current_app +from pydantic import TypeAdapter from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config from constants.languages import languages +from core.plugin.entities.plugin import ToolProviderID from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.models.document import Document +from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params from events.app_event import app_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel +from models.tools import ToolOAuthSystemClient from services.account_service import AccountService, RegisterService, TenantService from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration @@ -46,7 +50,7 @@ def reset_password(email, new_password, password_confirm): click.echo(click.style("Passwords do not match.", fg="red")) return - account = db.session.query(Account).filter(Account.email == email).one_or_none() + account = db.session.query(Account).where(Account.email == email).one_or_none() if not account: click.echo(click.style("Account not found for email: {}".format(email), fg="red")) @@ -85,7 +89,7 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style("New emails do not match.", fg="red")) return - account = db.session.query(Account).filter(Account.email == email).one_or_none() + account = db.session.query(Account).where(Account.email == email).one_or_none() if not account: 
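Editor's aside on a change repeated throughout the API code in this diff: `Query.filter(...)` calls are rewritten as `Query.where(...)`. In SQLAlchemy 1.4+, `Query.where()` is a synonym for `Query.filter()`, so the rewrites are behavior-preserving and simply match the 2.0-style `Select.where()` spelling used elsewhere. The model and in-memory engine below are stand-ins to show the equivalence, not Dify's models.

```python
"""Illustrative note on the .filter() -> .where() rewrites: in SQLAlchemy 1.4+
Query.where() is a synonym for Query.filter(). Model/engine are stand-ins."""
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "accounts"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Account(id=1, email="someone@example.com"))
    session.commit()
    # Legacy style, as in the removed lines:
    a1 = session.query(Account).filter(Account.email == "someone@example.com").one_or_none()
    # Equivalent, as in the added lines:
    a2 = session.query(Account).where(Account.email == "someone@example.com").one_or_none()
    # 2.0-style select() also uses where():
    a3 = session.execute(select(Account).where(Account.id == 1)).scalar_one_or_none()
    assert a1 is a2 is a3  # same identity-mapped instance either way
```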
click.echo(click.style("Account not found for email: {}".format(email), fg="red")) @@ -132,8 +136,8 @@ def reset_encrypt_key_pair(): tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() - db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() + db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() + db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() click.echo( @@ -168,7 +172,7 @@ def migrate_annotation_vector_database(): per_page = 50 apps = ( db.session.query(App) - .filter(App.status == "normal") + .where(App.status == "normal") .order_by(App.created_at.desc()) .limit(per_page) .offset((page - 1) * per_page) @@ -188,7 +192,7 @@ def migrate_annotation_vector_database(): try: click.echo("Creating app annotation index: {}".format(app.id)) app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() ) if not app_annotation_setting: @@ -198,13 +202,13 @@ def migrate_annotation_vector_database(): # get dataset_collection_binding info dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .first() ) if not dataset_collection_binding: click.echo("App annotation collection binding not found: {}".format(app.id)) continue - annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() + annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -301,7 +305,7 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -328,7 +332,7 @@ def migrate_knowledge_vector_database(): if dataset.collection_binding_id: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .where(DatasetCollectionBinding.id == dataset.collection_binding_id) .one_or_none() ) if dataset_collection_binding: @@ -363,7 +367,7 @@ def migrate_knowledge_vector_database(): dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset.id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -377,7 +381,7 @@ def migrate_knowledge_vector_database(): for dataset_document in dataset_documents: segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == "completed", DocumentSegment.enabled == True, @@ -464,7 +468,7 @@ def convert_to_agent_apps(): app_id = str(i.id) if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == 
app_id).first() if app is not None: apps.append(app) @@ -479,7 +483,7 @@ def convert_to_agent_apps(): db.session.commit() # update conversation mode to agent - db.session.query(Conversation).filter(Conversation.app_id == app.id).update( + db.session.query(Conversation).where(Conversation.app_id == app.id).update( {Conversation.mode: AppMode.AGENT_CHAT.value} ) @@ -556,7 +560,7 @@ def old_metadata_migration(): try: stmt = ( select(DatasetDocument) - .filter(DatasetDocument.doc_metadata.is_not(None)) + .where(DatasetDocument.doc_metadata.is_not(None)) .order_by(DatasetDocument.created_at.desc()) ) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -574,7 +578,7 @@ def old_metadata_migration(): else: dataset_metadata = ( db.session.query(DatasetMetadata) - .filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) + .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) .first() ) if not dataset_metadata: @@ -598,7 +602,7 @@ def old_metadata_migration(): else: dataset_metadata_binding = ( db.session.query(DatasetMetadataBinding) # type: ignore - .filter( + .where( DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id, @@ -713,7 +717,7 @@ where sites.id is null limit 1000""" continue try: - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: print(f"App {app_id} not found") continue @@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) else: click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) + + +@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_tool_oauth_client(provider, client_params): + """ + Setup system tool oauth client + """ + provider_id = ToolProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = ( + db.session.query(ToolOAuthSystemClient) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = ToolOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client 
params setup successfully. id: {oauth_client.id}", fg="green")) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index df15b92c35..9f1646ea7d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -31,6 +31,15 @@ class SecurityConfig(BaseSettings): description="Duration in minutes for which a password reset token remains valid", default=5, ) + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a change email token remains valid", + default=5, + ) + + OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a owner transfer token remains valid", + default=5, + ) LOGIN_DISABLED: bool = Field( description="Whether to disable login checks", @@ -237,6 +246,13 @@ class FileAccessConfig(BaseSettings): default="", ) + INTERNAL_FILES_URL: str = Field( + description="Internal base URL for file access within Docker network," + " used for plugin daemon and internal service communication." + " Falls back to FILES_URL if not specified.", + default="", + ) + FILES_ACCESS_TIMEOUT: int = Field( description="Expiration time in seconds for file access URLs", default=300, @@ -530,6 +546,33 @@ class WorkflowNodeExecutionConfig(BaseSettings): ) +class RepositoryConfig(BaseSettings): + """ + Configuration for repository implementations + """ + + CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", + ) + + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " + "Specify as a module path", + default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_RUN_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowRun operations. 
Specify as a module path", + default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository", + ) + + class AuthConfig(BaseSettings): """ Configuration for authentication and OAuth @@ -580,6 +623,16 @@ class AuthConfig(BaseSettings): default=86400, ) + CHANGE_EMAIL_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying change email after exceeding the rate limit.", + default=86400, + ) + + OWNER_TRANSFER_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying owner transfer after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ @@ -779,6 +832,41 @@ class CeleryBeatConfig(BaseSettings): ) +class CeleryScheduleTasksConfig(BaseSettings): + ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field( + description="Enable clean embedding cache task", + default=False, + ) + ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field( + description="Enable clean unused datasets task", + default=False, + ) + ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field( + description="Enable create tidb service job task", + default=False, + ) + ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field( + description="Enable update tidb service job status task", + default=False, + ) + ENABLE_CLEAN_MESSAGES: bool = Field( + description="Enable clean messages task", + default=False, + ) + ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( + description="Enable mail clean document notify task", + default=False, + ) + ENABLE_DATASETS_QUEUE_MONITOR: bool = Field( + description="Enable queue monitor task", + default=False, + ) + ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field( + description="Enable check upgradable plugin task", + default=True, + ) + + class PositionConfig(BaseSettings): POSITION_PROVIDER_PINS: str = Field( description="Comma-separated list of pinned model providers", @@ -896,6 +984,7 @@ class FeatureConfig( MultiModalTransferConfig, PositionConfig, RagEtlConfig, + RepositoryConfig, SecurityConfig, ToolConfig, UpdateConfig, @@ -907,5 +996,6 @@ class FeatureConfig( # hosted services config HostedServiceConfig, CeleryBeatConfig, + CeleryScheduleTasksConfig, ): pass diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 427602676f..587ea55ca7 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -85,6 +85,11 @@ class VectorStoreConfig(BaseSettings): default=False, ) + VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field( + description="Prefix used to create collection name in vector database", + default="Vector_index", + ) + class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( @@ -162,6 +167,11 @@ class DatabaseConfig(BaseSettings): default=3600, ) + SQLALCHEMY_POOL_USE_LIFO: bool = Field( + description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.", + default=False, + ) + SQLALCHEMY_POOL_PRE_PING: bool = Field( description="If True, enables connection pool pre-ping feature to check connections.", default=False, @@ -199,13 +209,14 @@ class DatabaseConfig(BaseSettings): "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, "connect_args": connect_args, + "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO, } class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( description="Backend for Celery task results. 
Options: 'database', 'redis'.", - default="database", + default="redis", ) CELERY_BROKER_URL: Optional[str] = Field( diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py index 1b88ddcfe6..7572a696ce 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -12,6 +12,16 @@ class OTelConfig(BaseSettings): default=False, ) + OTLP_TRACE_ENDPOINT: str = Field( + description="OTLP trace endpoint", + default="", + ) + + OTLP_METRIC_ENDPOINT: str = Field( + description="OTLP metric endpoint", + default="", + ) + OTLP_BASE_ENDPOINT: str = Field( description="OTLP base endpoint", default="http://localhost:4318", diff --git a/api/constants/__init__.py b/api/constants/__init__.py index a84de0a451..9e052320ac 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,6 +1,7 @@ from configs import dify_config HIDDEN_VALUE = "[__HIDDEN__]" +UNKNOWN_VALUE = "[__UNKNOWN__]" UUID_NIL = "00000000-0000-0000-0000-000000000000" DEFAULT_FILE_NUMBER_LIMITS = 3 diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index dbdcdc46ce..e25f92399c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -56,6 +56,7 @@ from .app import ( conversation, conversation_variables, generator, + mcp_server, message, model_config, ops_trace, diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index f5257fae79..8a55197fb6 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource): parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f"App '{args['app_id']}' is not found") @@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource): with Session(db.engine) as session: recommended_app = session.execute( - select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) + select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) ).scalar_one_or_none() if not recommended_app: @@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): with Session(db.engine) as session: recommended_app = session.execute( - select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) + select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) ).scalar_one_or_none() if not recommended_app: return {"result": "success"}, 204 with Session(db.engine) as session: - app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() + app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() if app: app.is_public = False with Session(db.engine) as session: installed_apps = session.execute( - select(InstalledApp).filter( + select(InstalledApp).where( InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 47c93a15c6..d7500c415c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource): _get_resource(resource_id, 
current_user.current_tenant_id, self.resource_model) keys = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .all() ) return {"items": keys} @@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource): current_key_count = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .count() ) @@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource): key = ( db.session.query(ApiToken) - .filter( + .where( getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, @@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource): if key is None: flask_restful.abort(404, message="API key not found") - db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 860166a61a..9fe32dde6d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -151,6 +151,7 @@ class AppApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 70d6216497..b5b6d1f75b 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,4 +1,4 @@ -from datetime import UTC, datetime +from datetime import datetime import pytz # pip install pytz from flask_login import current_user @@ -19,6 +19,7 @@ from fields.conversation_fields import ( conversation_pagination_fields, conversation_with_summary_pagination_fields, ) +from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString from libs.login import login_required from models import Conversation, EndUser, Message, MessageAnnotation @@ -48,7 +49,7 @@ class CompletionConversationApi(Resource): query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") if args["keyword"]: - query = query.join(Message, Message.conversation_id == Conversation.id).filter( + query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( Message.query.ilike("%{}%".format(args["keyword"])), Message.answer.ilike("%{}%".format(args["keyword"])), @@ -120,7 +121,7 @@ class CompletionConversationDetailApi(Resource): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) @@ -180,7 +181,7 @@ class ChatConversationApi(Resource): Message.conversation_id == Conversation.id, ) .join(subquery, subquery.c.conversation_id == Conversation.id) - .filter( + .where( or_( Message.query.ilike(keyword_filter), Message.answer.ilike(keyword_filter), @@ -285,7 +286,7 @@ 
class ChatConversationDetailApi(Resource): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) @@ -307,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps//chat-conversati def _get_conversation(app_model, conversation_id): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) @@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id): raise NotFound("Conversation Not Exists.") if not conversation.read_at: - conversation.read_at = datetime.now(UTC).replace(tzinfo=None) + conversation.read_at = naive_utc_now() conversation.read_account_id = current_user.id db.session.commit() diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py new file mode 100644 index 0000000000..2344fd5acb --- /dev/null +++ b/api/controllers/console/app/mcp_server.py @@ -0,0 +1,119 @@ +import json +from enum import StrEnum + +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import NotFound + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from extensions.ext_database import db +from fields.app_fields import app_server_fields +from libs.login import login_required +from models.model import AppMCPServer + + +class AppMCPServerStatus(StrEnum): + ACTIVE = "active" + INACTIVE = "inactive" + + +class AppMCPServerController(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_server_fields) + def get(self, app_model): + server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + return server + + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_server_fields) + def post(self, app_model): + if not current_user.is_editor: + raise NotFound() + parser = reqparse.RequestParser() + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("parameters", type=dict, required=True, location="json") + args = parser.parse_args() + + description = args.get("description") + if not description: + description = app_model.description or "" + + server = AppMCPServer( + name=app_model.name, + description=description, + parameters=json.dumps(args["parameters"], ensure_ascii=False), + status=AppMCPServerStatus.ACTIVE, + app_id=app_model.id, + tenant_id=current_user.current_tenant_id, + server_code=AppMCPServer.generate_server_code(16), + ) + db.session.add(server) + db.session.commit() + return server + + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_server_fields) + def put(self, app_model): + if not current_user.is_editor: + raise NotFound() + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("parameters", type=dict, required=True, location="json") + parser.add_argument("status", type=str, required=False, location="json") + args = 
parser.parse_args() + server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() + if not server: + raise NotFound() + + description = args.get("description") + if description is None: + pass + elif not description: + server.description = app_model.description or "" + else: + server.description = description + + server.parameters = json.dumps(args["parameters"], ensure_ascii=False) + if args["status"]: + if args["status"] not in [status.value for status in AppMCPServerStatus]: + raise ValueError("Invalid status") + server.status = args["status"] + db.session.commit() + return server + + +class AppMCPServerRefreshController(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_server_fields) + def get(self, server_id): + if not current_user.is_editor: + raise NotFound() + server = ( + db.session.query(AppMCPServer) + .where(AppMCPServer.id == server_id) + .where(AppMCPServer.tenant_id == current_user.current_tenant_id) + .first() + ) + if not server: + raise NotFound() + server.server_code = AppMCPServer.generate_server_code(16) + db.session.commit() + return server + + +api.add_resource(AppMCPServerController, "/apps//server") +api.add_resource(AppMCPServerRefreshController, "/apps//server/refresh") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a15..5e79e8dece 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -5,6 +5,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +import services from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, @@ -27,7 +28,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import AppMode, Conversation, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -55,7 +56,7 @@ class ChatMessageListApi(Resource): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .first() ) @@ -65,7 +66,7 @@ class ChatMessageListApi(Resource): if args["first_id"]: first_message = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .where(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .first() ) @@ -74,7 +75,7 @@ class ChatMessageListApi(Resource): history_messages = ( db.session.query(Message) - .filter( + .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, Message.id != first_message.id, @@ -86,7 +87,7 @@ class ChatMessageListApi(Resource): else: history_messages = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id) + .where(Message.conversation_id == conversation.id) 
.order_by(Message.created_at.desc()) .limit(args["limit"]) .all() @@ -97,7 +98,7 @@ class ChatMessageListApi(Resource): current_page_first_message = history_messages[-1] rest_count = ( db.session.query(Message) - .filter( + .where( Message.conversation_id == conversation.id, Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id, @@ -124,33 +125,16 @@ class MessageFeedbackApi(Resource): parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args["message_id"]) - - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() - - if not message: - raise NotFound("Message Not Exists.") - - feedback = message.admin_feedback - - if not args["rating"] and feedback: - db.session.delete(feedback) - elif args["rating"] and feedback: - feedback.rating = args["rating"] - elif not args["rating"] and not feedback: - raise ValueError("rating cannot be None when feedback not exists") - else: - feedback = MessageFeedback( - app_id=app_model.id, - conversation_id=message.conversation_id, - message_id=message.id, - rating=args["rating"], - from_source="admin", - from_account_id=current_user.id, + try: + MessageService.create_feedback( + app_model=app_model, + message_id=str(args["message_id"]), + user=current_user, + rating=args.get("rating"), + content=None, ) - db.session.add(feedback) - - db.session.commit() + except services.errors.message.MessageNotExistsError: + raise NotFound("Message Not Exists.") return {"result": "success"} @@ -183,7 +167,7 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() + count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() return {"count": count} @@ -230,7 +214,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f30e3e893c..029138fb6b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -42,7 +42,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config original_app_model_config = ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() ) if original_app_model_config is None: raise ValueError("Original app model config not found") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 3c3a359eeb..03418f1dd2 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,5 +1,3 @@ -from datetime import UTC, datetime - from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -10,6 +8,7 @@ from 
controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import Site @@ -50,7 +49,7 @@ class AppSite(Resource): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise NotFound @@ -77,7 +76,7 @@ class AppSite(Resource): setattr(site, attr_name, value) site.updated_by = current_user.id - site.updated_at = datetime.now(UTC).replace(tzinfo=None) + site.updated_at = naive_utc_now() db.session.commit() return site @@ -94,14 +93,14 @@ class AppSiteAccessTokenReset(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise NotFound site.code = Site.generate_code(16) site.updated_by = current_user.id - site.updated_at = datetime.now(UTC).replace(tzinfo=None) + site.updated_at = naive_utc_now() db.session.commit() return site diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 86aed77412..32b64d10c5 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -2,6 +2,7 @@ from datetime import datetime from decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user from flask_restful import Resource, reqparse @@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode +from models import AppMode, Message class DailyMessageStatistic(Resource): @@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource): parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT - DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - COUNT(DISTINCT messages.conversation_id) AS conversation_count -FROM - messages -WHERE - app_id = :app_id""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc + stmt = ( + sa.select( + sa.func.date( + sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz")) + ).label("date"), + sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), + ) + .select_from(Message) + .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) + ) + if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - - sql_query += " AND created_at >= :start" - arg_dict["start"] = start_datetime_utc + stmt = 
stmt.where(Message.created_at >= start_datetime_utc) if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + stmt = stmt.where(Message.created_at < end_datetime_utc) - sql_query += " AND created_at < :end" - arg_dict["end"] = end_datetime_utc - - sql_query += " GROUP BY date ORDER BY date" + stmt = stmt.group_by("date").order_by("date") response_data = [] - with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) - for i in rs: - response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) + rs = conn.execute(stmt, {"tz": account.timezone}) + for row in rs: + response_data.append({"date": str(row.date), "conversation_count": row.conversation_count}) return jsonify({"data": response_data}) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 00d6fa3cbf..ba93f82756 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -68,13 +68,18 @@ def _create_pagination_parser(): return parser +def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: + value_type = workflow_draft_var.value_type + return value_type.exposed_type().value + + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda model: model.get_variable_type()), "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.value, + "value_type": v.value_type.exposed_type().value, "value": v.value, # Do not track edited for env vars. 
"edited": False, diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 03b60610aa..132dc1f96b 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -11,7 +11,7 @@ from models import App, AppMode def _load_app_model(app_id: str) -> Optional[App]: app_model = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) return app_model @@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ raise AppNotFoundError() app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.CHANNEL: - raise AppNotFoundError() if mode is not None: if isinstance(mode, list): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 1795563ff7..2562fb5eb8 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,5 +1,3 @@ -import datetime - from flask import request from flask_restful import Resource, reqparse @@ -7,6 +5,7 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.helper import StrLen, email, extract_remote_ip, timezone from models.account import AccountStatus from services.account_service import AccountService, RegisterService @@ -65,7 +64,7 @@ class ActivateApi(Resource): account.timezone = args["timezone"] account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index b40934dbf5..8c5e23de58 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,7 +27,19 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Too many password reset emails have been sent. Please try again in 1 minutes." + description = "Too many password reset emails have been sent. Please try again in 1 minute." + code = 429 + + +class EmailChangeRateLimitExceededError(BaseHTTPException): + error_code = "email_change_rate_limit_exceeded" + description = "Too many email change emails have been sent. Please try again in 1 minute." + code = 429 + + +class OwnerTransferRateLimitExceededError(BaseHTTPException): + error_code = "owner_transfer_rate_limit_exceeded" + description = "Too many owner transfer emails have been sent. Please try again in 1 minute." code = 429 @@ -65,3 +77,39 @@ class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" description = "Too many failed password reset attempts. Please try again in 24 hours." code = 429 + + +class EmailChangeLimitError(BaseHTTPException): + error_code = "email_change_limit" + description = "Too many failed email change attempts. Please try again in 24 hours." 
+ code = 429 + + +class EmailAlreadyInUseError(BaseHTTPException): + error_code = "email_already_in_use" + description = "A user with this email already exists." + code = 400 + + +class OwnerTransferLimitError(BaseHTTPException): + error_code = "owner_transfer_limit" + description = "Too many failed owner transfer attempts. Please try again in 24 hours." + code = 429 + + +class NotOwnerError(BaseHTTPException): + error_code = "not_owner" + description = "You are not the owner of the workspace." + code = 400 + + +class CannotTransferOwnerToSelfError(BaseHTTPException): + error_code = "cannot_transfer_owner_to_self" + description = "You cannot transfer ownership to yourself." + code = 400 + + +class MemberNotInTenantError(BaseHTTPException): + error_code = "member_not_in_tenant" + description = "The member is not in the workspace." + code = 400 diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 395367c9e2..d0a4f3ff6d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,4 @@ import logging -from datetime import UTC, datetime from typing import Optional import requests @@ -13,6 +12,7 @@ from configs import dify_config from constants.languages import languages from events.tenant_event import tenant_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account @@ -110,7 +110,7 @@ class OAuthCallback(Resource): if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() try: diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 7b0d9373cf..39f8ab5787 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,4 +1,3 @@ -import datetime import json from flask import request @@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService @@ -30,7 +30,7 @@ class DataSourceApi(Resource): # get workspace data source integrates data_source_integrates = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.disabled == False, ) @@ -88,7 +88,7 @@ class DataSourceApi(Resource): if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.add(data_source_binding) db.session.commit() else: @@ -97,7 +97,7 @@ class DataSourceApi(Resource): if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = 
naive_utc_now() db.session.add(data_source_binding) db.session.commit() else: @@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource): page_id = str(page_id) with Session(db.engine) as session: data_source_binding = session.execute( - select(DataSourceOauthBinding).filter( + select(DataSourceOauthBinding).where( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1611214cb3..f551bc2432 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -211,10 +211,6 @@ class DatasetApi(Resource): else: data["embedding_available"] = True - if data.get("permission") == "partial_members": - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({"partial_member_list": part_users_list}) - return data, 200 @setup_required @@ -416,7 +412,7 @@ class DatasetIndexingEstimateApi(Resource): file_ids = args["info_list"]["file_info_list"]["file_ids"] file_details = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) .all() ) @@ -521,14 +517,14 @@ class DatasetIndexingStatusApi(Resource): dataset_id = str(dataset_id) documents = ( db.session.query(Document) - .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) .all() ) documents_status = [] for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -537,7 +533,7 @@ class DatasetIndexingStatusApi(Resource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) # Create a dictionary with document attributes and additional fields @@ -572,7 +568,7 @@ class DatasetApiKeyApi(Resource): def get(self): keys = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .all() ) return {"items": keys} @@ -588,7 +584,7 @@ class DatasetApiKeyApi(Resource): current_key_count = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .count() ) @@ -624,7 +620,7 @@ class DatasetApiDeleteApi(Resource): key = ( db.session.query(ApiToken) - .filter( + .where( ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, @@ -635,7 +631,7 @@ class DatasetApiDeleteApi(Resource): if key is None: flask_restful.abort(404, message="API key not found") - db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() return {"result": 
"success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b2fcf3ce7b..d14b208a4b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,6 +1,5 @@ import logging from argparse import ArgumentTypeError -from datetime import UTC, datetime from typing import cast from flask import request @@ -49,6 +48,7 @@ from fields.document_fields import ( document_status_fields, document_with_segments_fields, ) +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService @@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource): # get the latest process rule dataset_process_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) .one_or_none() @@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource): if search: search = f"%{search}%" - query = query.filter(Document.name.like(search)) + query = query.where(Document.name.like(search)) if sort.startswith("-"): sort_logic = desc @@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource): for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) document.completed_segments = completed_segments @@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource): file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .first() ) @@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): file_id = data_source_info["upload_file_id"] file_detail = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .first() ) @@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) # Create a dictionary with document attributes and additional fields @@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource): completed_segments = ( db.session.query(DocumentSegment) - 
.filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment", @@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") .count() ) @@ -750,7 +750,7 @@ class DocumentProcessingApi(DocumentResource): raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id - document.paused_at = datetime.now(UTC).replace(tzinfo=None) + document.paused_at = naive_utc_now() document.is_paused = True db.session.commit() @@ -830,7 +830,7 @@ class DocumentMetadataApi(DocumentResource): document.doc_metadata[key] = value document.doc_type = doc_type - document.updated_at = datetime.now(UTC).replace(tzinfo=None) + document.updated_at = naive_utc_now() db.session.commit() return {"result": "success", "message": "Document metadata updated."}, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 48142dbe73..b3704ce8b1 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource): query = ( select(DocumentSegment) - .filter( + .where( DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id, ) @@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource): ) if status_list: - query = query.filter(DocumentSegment.status.in_(status_list)) + query = query.where(DocumentSegment.status.in_(status_list)) if hit_count_gte is not None: - query = query.filter(DocumentSegment.hit_count >= hit_count_gte) + query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) if args["enabled"].lower() != "all": if args["enabled"].lower() == "true": - query = query.filter(DocumentSegment.enabled == True) + query = query.where(DocumentSegment.enabled == True) elif args["enabled"].lower() == "false": - query = query.filter(DocumentSegment.enabled == False) + query = query.where(DocumentSegment.enabled == False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + 
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .first() ) if not child_chunk: @@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .first() ) if not child_chunk: diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 2f00a84de6..cb68bb5e81 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = "high_quality_dataset_only" - description = "Current operation only supports 'high-quality' datasets." - code = 400 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." 
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 4200a51709..fcdc91ec67 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -4,7 +4,7 @@ from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.website_service import WebsiteService +from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService class WebsiteCrawlApi(Resource): @@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource): parser.add_argument("url", type=str, required=True, nullable=True, location="json") parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() - WebsiteService.document_create_args_validate(args) - # crawl url + + # Create typed request and validate try: - result = WebsiteService.crawl_url(args) + api_request = WebsiteCrawlApiRequest.from_args(args) + except ValueError as e: + raise WebsiteCrawlError(str(e)) + + # Crawl URL using typed request + try: + result = WebsiteService.crawl_url(api_request) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 @@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource): "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" ) args = parser.parse_args() - # get crawl status + + # Create typed request and validate try: - result = WebsiteService.get_crawl_status(job_id, args["provider"]) + api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) + except ValueError as e: + raise WebsiteCrawlError(str(e)) + + # Get crawl status using typed request + try: + result = WebsiteService.get_crawl_status_typed(api_request) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 4367da1162..4842fefc57 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,5 +1,4 @@ import logging -from datetime import UTC, datetime from flask_login import current_user from flask_restful import reqparse @@ -27,6 +26,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper +from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) + installed_app.last_used_at = naive_utc_now() db.session.commit() try: @@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource): args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) + installed_app.last_used_at = naive_utc_now() db.session.commit() try: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 9d0c08564e..ffdf73c368 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,5 +1,4 @@ import logging -from 
datetime import UTC, datetime from typing import Any from flask import request @@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService @@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource): if app_id: installed_apps = ( db.session.query(InstalledApp) - .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) + .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) .all() ) else: - installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() + installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ @@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource): parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == args["app_id"]).first() if app is None: raise NotFound("App not found") @@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource): installed_app = ( db.session.query(InstalledApp) - .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .first() ) @@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource): tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(UTC).replace(tzinfo=None), + last_used_at=naive_utc_now(), ) db.session.add(new_installed_app) db.session.commit() diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index afbd78bd5b..de97fb149e 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -28,7 +28,7 @@ def installed_app_required(view=None): installed_app = ( db.session.query(InstalledApp) - .filter( + .where( InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id ) .first() diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 072e904caf..ef814dd738 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -21,7 +21,7 @@ def plugin_permission_required( with Session(db.engine) as session: permission = ( session.query(TenantPluginPermission) - .filter( + .where( TenantPluginPermission.tenant_id == tenant_id, ) .first() diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 
a9dbf44456..5cd2e0cd2d 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,13 +1,21 @@ -import datetime - import pytz from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language from controllers.console import api +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailChangeLimitError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, +) +from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, @@ -18,15 +26,18 @@ from controllers.console.workspace.error import ( from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_enabled, + enable_change_email, enterprise_license_required, only_edition_cloud, setup_required, ) from extensions.ext_database import db from fields.member_fields import account_fields -from libs.helper import TimestampField, timezone +from libs.datetime_utils import naive_utc_now +from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.login import login_required from models import AccountIntegrate, InvitationCode +from models.account import Account from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -57,7 +68,7 @@ class AccountInitApi(Resource): # check invitation code invitation_code = ( db.session.query(InvitationCode) - .filter( + .where( InvitationCode.code == args["invitation_code"], InvitationCode.status == "unused", ) @@ -68,7 +79,7 @@ class AccountInitApi(Resource): raise InvalidInvitationCodeError() invitation_code.status = "used" - invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + invitation_code.used_at = naive_utc_now() invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id @@ -76,7 +87,7 @@ class AccountInitApi(Resource): account.timezone = args["timezone"] account.interface_theme = "light" account.status = "active" - account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() return {"result": "success"} @@ -217,7 +228,7 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" @@ -369,6 +380,134 @@ class EducationAutoCompleteApi(Resource): return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) +class ChangeEmailSendEmailApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + 
parser.add_argument("phase", type=str, required=False, location="json") + parser.add_argument("token", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + account = None + user_email = args["email"] + if args["phase"] is not None and args["phase"] == "new_email": + if args["token"] is None: + raise InvalidTokenError() + + reset_data = AccountService.get_change_email_data(args["token"]) + if reset_data is None: + raise InvalidTokenError() + user_email = reset_data.get("email", "") + + if user_email != current_user.email: + raise InvalidEmailError() + else: + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + if account is None: + raise AccountNotFound() + + token = AccountService.send_change_email_email( + account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"] + ) + return {"result": "success", "data": token} + + +class ChangeEmailCheckApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"]) + if is_change_email_error_rate_limit: + raise EmailChangeLimitError() + + token_data = AccountService.get_change_email_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_change_email_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_change_email_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_change_email_token( + user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={} + ) + + AccountService.reset_change_email_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class ChangeEmailResetApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("new_email", type=email, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + reset_data = AccountService.get_change_email_data(args["token"]) + if not reset_data: + raise InvalidTokenError() + + AccountService.revoke_change_email_token(args["token"]) + + if not AccountService.check_email_unique(args["new_email"]): + raise EmailAlreadyInUseError() + + old_email = reset_data.get("old_email", "") + if current_user.email != old_email: + raise AccountNotFound() + + updated_account = AccountService.update_account(current_user, 
email=args["new_email"]) + + return updated_account + + +class CheckEmailUnique(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + args = parser.parse_args() + if not AccountService.check_email_unique(args["email"]): + raise EmailAlreadyInUseError() + return {"result": "success"} + + # Register API resources api.add_resource(AccountInitApi, "/account/init") api.add_resource(AccountProfileApi, "/account/profile") @@ -385,5 +524,10 @@ api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") api.add_resource(EducationVerifyApi, "/account/education/verify") api.add_resource(EducationApi, "/account/education") api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") +# Change email +api.add_resource(ChangeEmailSendEmailApi, "/account/change-email") +api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity") +api.add_resource(ChangeEmailResetApi, "/account/change-email/reset") +api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 8b70ca62b9..4427d1ff72 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -13,12 +13,6 @@ class CurrentPasswordIncorrectError(BaseHTTPException): code = 400 -class ProviderRequestFailedError(BaseHTTPException): - error_code = "provider_request_failed" - description = None - code = 400 - - class InvalidInvitationCodeError(BaseHTTPException): error_code = "invalid_invitation_code" description = "Invalid invitation code." 
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 48225ac90d..f7424923b9 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,22 +1,34 @@ from urllib import parse +from flask import request from flask_login import current_user from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api -from controllers.console.error import WorkspaceMembersLimitExceeded +from controllers.console.auth.error import ( + CannotTransferOwnerToSelfError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, + MemberNotInTenantError, + NotOwnerError, + OwnerTransferLimitError, +) +from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + is_allow_transfer_owner, setup_required, ) from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields +from libs.helper import extract_remote_ip from libs.login import login_required from models.account import Account, TenantAccountRole -from services.account_service import RegisterService, TenantService +from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService @@ -96,7 +108,7 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): - member = db.session.query(Account).filter(Account.id == str(member_id)).first() + member = db.session.query(Account).where(Account.id == str(member_id)).first() if member is None: abort(404) else: @@ -156,8 +168,146 @@ class DatasetOperatorMemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +class SendOwnerTransferEmailApi(Resource): + """Send owner transfer email.""" + + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + email = current_user.email + + token = AccountService.send_owner_transfer_email( + account=current_user, + email=email, + language=language, + workspace_name=current_user.current_tenant.name, + ) + + return {"result": "success", "data": token} + + +class OwnerTransferCheckApi(Resource): + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + 
user_email = current_user.email + + is_owner_transfer_error_rate_limit = AccountService.is_owner_transfer_error_rate_limit(user_email) + if is_owner_transfer_error_rate_limit: + raise OwnerTransferLimitError() + + token_data = AccountService.get_owner_transfer_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_owner_transfer_error_rate_limit(user_email) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_owner_transfer_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={}) + + AccountService.reset_owner_transfer_error_rate_limit(user_email) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class OwnerTransfer(Resource): + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self, member_id): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + if current_user.id == str(member_id): + raise CannotTransferOwnerToSelfError() + + transfer_token_data = AccountService.get_owner_transfer_data(args["token"]) + if not transfer_token_data: + raise InvalidTokenError() + + if transfer_token_data.get("email") != current_user.email: + raise InvalidEmailError() + + AccountService.revoke_owner_transfer_token(args["token"]) + + member = db.session.get(Account, str(member_id)) + if not member: + abort(404) + else: + member_account = member + if not TenantService.is_member(member_account, current_user.current_tenant): + raise MemberNotInTenantError() + + try: + assert member is not None, "Member not found" + TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user) + + AccountService.send_new_owner_transfer_notify_email( + account=member, + email=member.email, + workspace_name=current_user.current_tenant.name, + ) + + AccountService.send_old_owner_transfer_notify_email( + account=current_user, + email=current_user.email, + workspace_name=current_user.current_tenant.name, + new_owner_email=member.email, + ) + + except Exception as e: + raise ValueError(str(e)) + + return {"result": "success"} + + api.add_resource(MemberListApi, "/workspaces/current/members") api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") +# owner transfer +api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email") +api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check") +api.add_resource(OwnerTransfer, "/workspaces/current/members//owner-transfer") diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index c0a4734828..09846d5c94 100644 --- a/api/controllers/console/workspace/plugin.py +++ 
b/api/controllers/console/workspace/plugin.py @@ -12,7 +12,8 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError from libs.login import login_required -from models.account import TenantPluginPermission +from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission +from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService @@ -534,6 +535,114 @@ class PluginFetchDynamicSelectOptionsApi(Resource): return jsonable_encoder({"options": options}) +class PluginChangePreferencesApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + user = current_user + if not user.is_admin_or_owner: + raise Forbidden() + + req = reqparse.RequestParser() + req.add_argument("permission", type=dict, required=True, location="json") + req.add_argument("auto_upgrade", type=dict, required=True, location="json") + args = req.parse_args() + + tenant_id = user.current_tenant_id + + permission = args["permission"] + + install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) + debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone")) + + auto_upgrade = args["auto_upgrade"] + + strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting( + auto_upgrade.get("strategy_setting", "fix_only") + ) + upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0) + upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude")) + exclude_plugins = auto_upgrade.get("exclude_plugins", []) + include_plugins = auto_upgrade.get("include_plugins", []) + + # set permission + set_permission_result = PluginPermissionService.change_permission( + tenant_id, + install_permission, + debug_permission, + ) + if not set_permission_result: + return jsonable_encoder({"success": False, "message": "Failed to set permission"}) + + # set auto upgrade strategy + set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy( + tenant_id, + strategy_setting, + upgrade_time_of_day, + upgrade_mode, + exclude_plugins, + include_plugins, + ) + if not set_auto_upgrade_strategy_result: + return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"}) + + return jsonable_encoder({"success": True}) + + +class PluginFetchPreferencesApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id = current_user.current_tenant_id + + permission = PluginPermissionService.get_permission(tenant_id) + permission_dict = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + if permission: + permission_dict["install_permission"] = permission.install_permission + permission_dict["debug_permission"] = permission.debug_permission + + auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id) + auto_upgrade_dict = { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED, + "upgrade_time_of_day": 0, + "upgrade_mode": 
TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + } + + if auto_upgrade: + auto_upgrade_dict = { + "strategy_setting": auto_upgrade.strategy_setting, + "upgrade_time_of_day": auto_upgrade.upgrade_time_of_day, + "upgrade_mode": auto_upgrade.upgrade_mode, + "exclude_plugins": auto_upgrade.exclude_plugins, + "include_plugins": auto_upgrade.include_plugins, + } + + return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) + + +class PluginAutoUpgradeExcludePluginApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + # exclude one single plugin + tenant_id = current_user.current_tenant_id + + req = reqparse.RequestParser() + req.add_argument("plugin_id", type=str, required=True, location="json") + args = req.parse_args() + + return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) + + api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") api.add_resource(PluginListApi, "/workspaces/current/plugin/list") api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions") @@ -560,3 +669,7 @@ api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permissi api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") + +api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch") +api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change") +api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2b1379bfb2..c4d1ef70d8 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,25 +1,52 @@ import io +from urllib.parse import urlparse -from flask import send_file +from flask import make_response, redirect, request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse -from sqlalchemy.orm import Session +from flask_restful import ( + Resource, + reqparse, +) from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) +from core.mcp.auth.auth_flow import auth, handle_callback +from core.mcp.auth.auth_provider import OAuthClientProvider +from core.mcp.error import MCPAuthError, MCPError +from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db -from libs.helper import alphanumeric, uuid_value +from core.plugin.entities.plugin import ToolProviderID +from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import CredentialType +from libs.helper import StrLen, alphanumeric, uuid_value from libs.login import login_required +from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService 
from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService +from services.tools.tools_transform_service import ToolTransformService from services.tools.workflow_tools_manage_service import WorkflowToolManageService +def is_valid_url(url: str) -> bool: + if not url: + return False + + try: + parsed = urlparse(url) + return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] + except Exception: + return False + + class ToolProviderListApi(Resource): @setup_required @login_required @@ -34,7 +61,7 @@ class ToolProviderListApi(Resource): req.add_argument( "type", type=str, - choices=["builtin", "model", "api", "workflow"], + choices=["builtin", "model", "api", "workflow", "mcp"], required=False, nullable=True, location="args", @@ -71,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource): user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) + return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) class ToolBuiltinProviderDeleteApi(Resource): @@ -80,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() + tenant_id = user.current_tenant_id + req = reqparse.RequestParser() + req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = req.parse_args() + + return BuiltinToolManageService.delete_builtin_tool_provider( + tenant_id, + provider, + args["credential_id"], + ) + + +class ToolBuiltinProviderAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + user = current_user + user_id = user.id tenant_id = user.current_tenant_id - return BuiltinToolManageService.delete_builtin_tool_provider( - user_id, - tenant_id, - provider, + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + if args["type"] not in CredentialType.values(): + raise ValueError(f"Invalid credential type: {args['type']}") + + return BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credentials=args["credentials"], + name=args["name"], + api_type=CredentialType.of(args["type"]), ) @@ -108,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource): tenant_id = user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() - with Session(db.engine) as session: - result = BuiltinToolManageService.update_builtin_tool_provider( - session=session, - 
user_id=user_id, - tenant_id=tenant_id, - provider_name=provider, - credentials=args["credentials"], - ) - session.commit() + result = BuiltinToolManageService.update_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credential_id=args["credential_id"], + credentials=args.get("credentials", None), + name=args.get("name", ""), + ) return result @@ -131,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): def get(self, provider): tenant_id = current_user.current_tenant_id - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=tenant_id, - provider_name=provider, + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) ) @@ -326,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider): + def get(self, provider, credential_type): user = current_user - tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_provider_credentials_schema( + provider, CredentialType.of(credential_type), tenant_id + ) + ) class ToolApiProviderSchemaApi(Resource): @@ -568,15 +631,12 @@ class ToolApiListApi(Resource): @account_initialization_required def get(self): user = current_user - - user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder( [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, tenant_id, ) ] @@ -613,20 +673,373 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +class ToolPluginOAuthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + + # todo check permission + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + tenant_id = user.current_tenant_id + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + + oauth_handler = OAuthHandler() + context_id = OAuthProxyService.create_proxy_context( + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + ) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + authorization_url_response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=user.id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, + ) + response = make_response(jsonable_encoder(authorization_url_response)) + response.set_cookie( + "context_id", + context_id, + httponly=True, + samesite="Lax", + max_age=OAuthProxyService.__MAX_AGE__, + ) + return response + + +class ToolOAuthCallback(Resource): + @setup_required + def get(self, provider): + context_id = request.cookies.get("context_id") + if not context_id: + raise Forbidden("context_id not found") + + context = OAuthProxyService.use_proxy_context(context_id) + if context is None: + raise Forbidden("Invalid context_id") + + tool_provider = ToolProviderID(provider) + plugin_id = 
tool_provider.plugin_id + provider_name = tool_provider.provider_name + user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + + oauth_handler = OAuthHandler() + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + credentials_response = oauth_handler.get_credentials( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, + request=request, + ) + + credentials = credentials_response.credentials + expires_at = credentials_response.expires_at + + if not credentials: + raise Exception("the plugin credentials failed") + + # add credentials to database + BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credentials=dict(credentials), + expires_at=expires_at, + api_type=CredentialType.OAUTH2, + ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") + + +class ToolBuiltinProviderSetDefaultApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + return BuiltinToolManageService.set_default_provider( + tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + ) + + +class ToolOAuthCustomClient(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") + args = parser.parse_args() + + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + return BuiltinToolManageService.save_custom_oauth_client_params( + tenant_id=user.current_tenant_id, + provider=provider, + client_params=args.get("client_params", {}), + enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + ) + + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) + ) + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider): + return jsonable_encoder( + BuiltinToolManageService.delete_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) + ) + + +class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( + tenant_id=current_user.current_tenant_id, provider_name=provider + ) + ) + + +class ToolBuiltinProviderGetCredentialInfoApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tenant_id = current_user.current_tenant_id + + return jsonable_encoder( + 
BuiltinToolManageService.get_builtin_tool_provider_credential_info( + tenant_id=tenant_id, + provider=provider, + ) + ) + + +class ToolProviderMCPApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("server_url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") + parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + user = current_user + if not is_valid_url(args["server_url"]): + raise ValueError("Server URL is not valid.") + return jsonable_encoder( + MCPToolManageService.create_mcp_provider( + tenant_id=user.current_tenant_id, + server_url=args["server_url"], + name=args["name"], + icon=args["icon"], + icon_type=args["icon_type"], + icon_background=args["icon_background"], + user_id=user.id, + server_identifier=args["server_identifier"], + ) + ) + + @setup_required + @login_required + @account_initialization_required + def put(self): + parser = reqparse.RequestParser() + parser.add_argument("server_url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") + parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + if not is_valid_url(args["server_url"]): + if "[__HIDDEN__]" in args["server_url"]: + pass + else: + raise ValueError("Server URL is not valid.") + MCPToolManageService.update_mcp_provider( + tenant_id=current_user.current_tenant_id, + provider_id=args["provider_id"], + server_url=args["server_url"], + name=args["name"], + icon=args["icon"], + icon_type=args["icon_type"], + icon_background=args["icon_background"], + server_identifier=args["server_identifier"], + ) + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self): + parser = reqparse.RequestParser() + parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) + return {"result": "success"} + + +class ToolMCPAuthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json") + args = parser.parse_args() + provider_id = args["provider_id"] + tenant_id = current_user.current_tenant_id + provider = 
MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) + if not provider: + raise ValueError("provider not found") + try: + with MCPClient( + provider.decrypted_server_url, + provider_id, + tenant_id, + authed=False, + authorization_code=args["authorization_code"], + for_list=True, + ): + MCPToolManageService.update_mcp_provider_credentials( + mcp_provider=provider, + credentials=provider.decrypted_credentials, + authed=True, + ) + return {"result": "success"} + + except MCPAuthError: + auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) + return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"]) + except MCPError as e: + MCPToolManageService.update_mcp_provider_credentials( + mcp_provider=provider, + credentials={}, + authed=False, + ) + raise ValueError(f"Failed to connect to MCP server: {e}") from e + + +class ToolMCPDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider_id): + user = current_user + provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id) + return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) + + +class ToolMCPListAllApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user = current_user + tenant_id = user.current_tenant_id + + tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) + + return [tool.to_dict() for tool in tools] + + +class ToolMCPUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider_id): + tenant_id = current_user.current_tenant_id + tools = MCPToolManageService.list_mcp_tool_from_remote_server( + tenant_id=tenant_id, + provider_id=provider_id, + ) + return jsonable_encoder(tools) + + +class ToolMCPCallbackApi(Resource): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("code", type=str, required=True, nullable=False, location="args") + parser.add_argument("state", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + state_key = args["state"] + authorization_code = args["code"] + handle_callback(state_key, authorization_code) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") + + # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") +# tool oauth +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") +api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") +api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") + # builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") +api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" +) +api.add_resource( + ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin//credential/info" +) api.add_resource( 
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credentials_schema", + "/workspaces/current/tool-provider/builtin//credential/schema/", +) +api.add_resource( + ToolBuiltinProviderGetOauthClientSchemaApi, + "/workspaces/current/tool-provider/builtin//oauth/client-schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") @@ -647,8 +1060,15 @@ api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provid api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") +# mcp tool provider +api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/") +api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp") +api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/") +api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth") +api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback") + api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") +api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp") api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") - api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index ca122772de..d862dac373 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -235,3 +235,29 @@ def email_password_login_enabled(view): abort(403) return decorated + + +def enable_change_email(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_change_email: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated + + +def is_allow_transfer_owner(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_features(current_user.current_tenant_id) + if features.is_allow_transfer_workspace: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index f1a15793c7..15f93d2774 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return tool_file, 201 - api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 327e9ce834..5dfe41eb6b 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource): provider=payload.provider, tool_name=payload.tool, tool_parameters=payload.tool_parameters, + credential_id=payload.credential_id, ), ) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 50408e0929..b533614d4d 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -22,7 +22,7 @@ def 
get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: user_id = "DEFAULT-USER" if user_id == "DEFAULT-USER": - user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() + user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first() if not user_model: user_model = EndUser( tenant_id=tenant_id, @@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: else: user_model = AccountService.load_user(user_id) if not user_model: - user_model = session.query(EndUser).filter(EndUser.id == user_id).first() + user_model = session.query(EndUser).where(EndUser.id == user_id).first() if not user_model: raise ValueError("user not found") except Exception: @@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None): try: tenant_model = ( db.session.query(Tenant) - .filter( + .where( Tenant.id == tenant_id, ) .first() diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index f3a9312dd0..9e7b3d4f29 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view): if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py new file mode 100644 index 0000000000..1b3e0a5621 --- /dev/null +++ b/api/controllers/mcp/__init__.py @@ -0,0 +1,8 @@ +from flask import Blueprint + +from libs.external_api import ExternalApi + +bp = Blueprint("mcp", __name__, url_prefix="/mcp") +api = ExternalApi(bp) + +from . 
import mcp diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py new file mode 100644 index 0000000000..87d678796f --- /dev/null +++ b/api/controllers/mcp/mcp.py @@ -0,0 +1,104 @@ +from flask_restful import Resource, reqparse +from pydantic import ValidationError + +from controllers.console.app.mcp_server import AppMCPServerStatus +from controllers.mcp import api +from core.app.app_config.entities import VariableEntity +from core.mcp import types +from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler +from core.mcp.types import ClientNotification, ClientRequest +from core.mcp.utils import create_mcp_error_response +from extensions.ext_database import db +from libs import helper +from models.model import App, AppMCPServer, AppMode + + +class MCPAppApi(Resource): + def post(self, server_code): + def int_or_str(value): + if isinstance(value, (int, str)): + return value + else: + return None + + parser = reqparse.RequestParser() + parser.add_argument("jsonrpc", type=str, required=True, location="json") + parser.add_argument("method", type=str, required=True, location="json") + parser.add_argument("params", type=dict, required=False, location="json") + parser.add_argument("id", type=int_or_str, required=False, location="json") + args = parser.parse_args() + + request_id = args.get("id") + + server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() + if not server: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") + ) + + if server.status != AppMCPServerStatus.ACTIVE: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") + ) + + app = db.session.query(App).where(App.id == server.app_id).first() + if not app: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") + ) + + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + workflow = app.workflow + if workflow is None: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable") + ) + + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app.app_model_config + if app_model_config is None: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable") + ) + + features_dict = app_model_config.to_dict() + user_input_form = features_dict.get("user_input_form", []) + converted_user_input_form: list[VariableEntity] = [] + try: + for item in user_input_form: + variable_type = item.get("type", "") or list(item.keys())[0] + variable = item[variable_type] + converted_user_input_form.append( + VariableEntity( + type=variable_type, + variable=variable.get("variable"), + description=variable.get("description") or "", + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options") or [], + ) + ) + except ValidationError as e: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") + ) + + try: + request: ClientRequest | ClientNotification = ClientRequest.model_validate(args) + except ValidationError as e: + try: + notification = ClientNotification.model_validate(args) + request = notification + except 
ValidationError as e: + return helper.compact_generate_response( + create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") + ) + + mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form) + response = mcp_server_handler.handle() + return helper.compact_generate_response(response) + + +api.add_resource(MCPAppApi, "/server//mcp") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 1d9890199d..7762672494 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,5 +1,6 @@ import logging +from flask import request from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -23,6 +24,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value @@ -111,6 +113,10 @@ class ChatApi(Resource): args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + streaming = args["response_mode"] == "streaming" try: diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index e752dfee30..c157b39f6b 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -16,7 +16,7 @@ class AppSiteApi(Resource): @marshal_with(fields.site_fields) def get(self, app_model: App): """Retrieve app site info.""" - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index efb4acc5fb..370ff911b4 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,9 +1,10 @@ import logging from dateutil.parser import isoparse +from flask import request from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -23,6 +24,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from extensions.ext_database import db @@ -30,7 +32,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser -from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService @@ -63,7 +65,15 @@ class WorkflowRunDetailApi(Resource): if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: raise NotWorkflowAppError() - workflow_run = 
db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + # Use repository to get workflow run + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + workflow_run = workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=workflow_run_id, + ) return workflow_run @@ -82,7 +92,9 @@ class WorkflowRunApi(Resource): parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id streaming = args.get("response_mode") == "streaming" try: diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d571b21a0a..ac85c0b38d 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource): # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) search = request.args.get("keyword", default=None, type=str) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == 
dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource): if search: search = f"%{search}%" - query = query.filter(Document.name.like(search)) + query = query.where(Document.name.like(search)) query = query.order_by(desc(Document.created_at), desc(Document.position)) @@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # get documents @@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) # Create a dictionary with document attributes and additional fields diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index 5ff5e08c72..ecc47b40a1 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = "high_quality_dataset_only" - description = "Current operation only supports 'high-quality' datasets." - code = 400 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." 
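Editor's note on the recurring query change in these service-API controllers: the patch swaps legacy `Query.filter(...)` calls for `Query.where(...)`. In SQLAlchemy 1.4+, `Query.where()` is documented as a synonym for `Query.filter()`, so the rewrite is behavior-preserving. The sketch below uses a hypothetical stand-in `Dataset` model (not Dify's actual model) to show that both spellings, as well as the 2.0-style `select()` used in some of the rewritten controllers, return the same row.

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


# Hypothetical stand-in for the Dataset model queried by these controllers.
class Dataset(Base):
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Dataset(id="d1", tenant_id="t1"))
    session.commit()

    # Legacy 1.x spelling used before this patch:
    old = session.query(Dataset).filter(Dataset.tenant_id == "t1", Dataset.id == "d1").first()

    # Spelling applied by the patch; Query.where() is an alias of Query.filter():
    new = session.query(Dataset).where(Dataset.tenant_id == "t1", Dataset.id == "d1").first()

    # 2.0-style equivalent used in some rewritten controllers (e.g. web/passport.py):
    newer = session.scalar(select(Dataset).where(Dataset.tenant_id == "t1", Dataset.id == "d1"))

    assert old is new is newer  # the session's identity map returns the same instance
```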
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 403b7f0a0c..31f862dc8f 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document @@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource): tenant_id = str(tenant_id) page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document @@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = 
db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py index 6382b63ea9..3b4721b5b0 100644 --- a/api/controllers/service_api/dataset/upload_file.py +++ b/api/controllers/service_api/dataset/upload_file.py @@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document @@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource): data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("UploadFile not found.") else: diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 5b919a68d4..da81cc8bc3 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,6 +1,6 @@ import time from collections.abc import Callable -from datetime import UTC, datetime, timedelta +from datetime import timedelta from enum import Enum from functools import wraps from typing import Optional @@ -15,6 +15,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog @@ -43,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorated_view(*args, **kwargs): api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).filter(App.id == api_token.app_id).first() + app_model = db.session.query(App).where(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") @@ -53,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -61,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) - 
.filter(Tenant.id == api_token.tenant_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role.in_(["owner"])) - .filter(Tenant.status == TenantStatus.NORMAL) + .where(Tenant.id == api_token.tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role.in_(["owner"])) + .where(Tenant.status == TenantStatus.NORMAL) .one_or_none() ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).filter(Account.id == ta.account_id).first() + account = db.session.query(Account).where(Account.id == ta.account_id).first() # Login admin if account: account.current_tenant = tenant @@ -212,15 +213,15 @@ def validate_dataset_token(view=None): api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == api_token.tenant_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role.in_(["owner"])) - .filter(Tenant.status == TenantStatus.NORMAL) + .where(Tenant.id == api_token.tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role.in_(["owner"])) + .where(Tenant.status == TenantStatus.NORMAL) .one_or_none() ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).filter(Account.id == ta.account_id).first() + account = db.session.query(Account).where(Account.id == ta.account_id).first() # Login admin if account: account.current_tenant = tenant @@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None): if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - current_time = datetime.now(UTC).replace(tzinfo=None) + current_time = naive_utc_now() cutoff_time = current_time - timedelta(minutes=1) with Session(db.engine, expire_on_commit=False) as session: update_stmt = ( @@ -292,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] end_user = ( db.session.query(EndUser) - .filter( + .where( EndUser.tenant_id == app_model.tenant_id, EndUser.app_id == app_model.id, EndUser.session_id == user_id, @@ -319,7 +320,7 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 10c3cdcf0e..acd3a8b539 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta from flask import request from flask_restful import Resource +from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -42,17 +43,17 @@ class PassportResource(Resource): raise WebAppAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) if not site: raise NotFound() # get 
app from db and check if it is normal and enable_site - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.scalar(select(App).where(App.id == site.app_id)) if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() if user_id: - end_user = ( - db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() + end_user = db.session.scalar( + select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) ) if end_user: @@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) if not site: raise NotFound() - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.scalar(select(App).where(App.id == site.app_id)) if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() @@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: end_user = None if end_user_id: - end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if session_id: - end_user = ( - db.session.query(EndUser) - .filter( + end_user = db.session.scalar( + select(EndUser).where( EndUser.session_id == session_id, EndUser.tenant_id == app_model.tenant_id, EndUser.app_id == app_model.id, ) - .first() ) if not end_user: if not session_id: @@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): user_id = token_decoded.get("user_id") end_user = None if user_id: - end_user = ( - db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() + end_user = db.session.scalar( + select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) ) if not end_user: @@ -224,6 +223,8 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() + existing_count = db.session.scalar( + select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id) + ) if existing_count == 0: return session_id diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea3..3c133499b7 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource): def get(self, app_model, end_user): """Retrieve app site info.""" # get site - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 154bddfc5c..ae6f14a689 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -3,6 +3,7 @@ from functools import wraps from flask import request from flask_restful import Resource +from sqlalchemy import select from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError @@ -48,8 +49,8 @@ def 
decode_jwt_token():
         decoded = PassportService().verify(tk)
         app_code = decoded.get("app_code")
         app_id = decoded.get("app_id")
-        app_model = db.session.query(App).filter(App.id == app_id).first()
-        site = db.session.query(Site).filter(Site.code == app_code).first()
+        app_model = db.session.scalar(select(App).where(App.id == app_id))
+        site = db.session.scalar(select(Site).where(Site.code == app_code))
         if not app_model:
             raise NotFound()
         if not app_code or not site:
@@ -57,7 +58,7 @@ def decode_jwt_token():
         if app_model.enable_site is False:
             raise BadRequest("Site is disabled.")
         end_user_id = decoded.get("end_user_id")
-        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
+        end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
         if not end_user:
             raise NotFound()
 
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 6998e4d29a..1f3c218d59 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -3,6 +3,8 @@ import logging
 import uuid
 from typing import Optional, Union, cast
 
+from sqlalchemy import select
+
 from core.agent.entities import AgentEntity, AgentToolEntity
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -97,7 +99,7 @@ class BaseAgentRunner(AppRunner):
         # get how many agent thoughts have been created
         self.agent_thought_count = (
             db.session.query(MessageAgentThought)
-            .filter(
+            .where(
                 MessageAgentThought.message_id == self.message.id,
             )
             .count()
@@ -161,10 +163,14 @@ class BaseAgentRunner(AppRunner):
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options] if parameter.options else []
 
-            message_tool.parameters["properties"][parameter.name] = {
-                "type": parameter_type,
-                "description": parameter.llm_description or "",
-            }
+            message_tool.parameters["properties"][parameter.name] = (
+                {
+                    "type": parameter_type,
+                    "description": parameter.llm_description or "",
+                }
+                if parameter.input_schema is None
+                else parameter.input_schema
+            )
 
             if len(enum) > 0:
                 message_tool.parameters["properties"][parameter.name]["enum"] = enum
@@ -254,10 +260,14 @@ class BaseAgentRunner(AppRunner):
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options] if parameter.options else []
 
-            prompt_tool.parameters["properties"][parameter.name] = {
-                "type": parameter_type,
-                "description": parameter.llm_description or "",
-            }
+            prompt_tool.parameters["properties"][parameter.name] = (
+                {
+                    "type": parameter_type,
+                    "description": parameter.llm_description or "",
+                }
+                if parameter.input_schema is None
+                else parameter.input_schema
+            )
 
             if len(enum) > 0:
                 prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
@@ -326,7 +336,7 @@ class BaseAgentRunner(AppRunner):
         Save agent thought
         """
         updated_agent_thought = (
-            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
+            db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
         )
         if not updated_agent_thought:
             raise ValueError("agent thought not found")
@@ -409,12 +419,15 @@ class BaseAgentRunner(AppRunner):
             if isinstance(prompt_message, SystemPromptMessage):
                 result.append(prompt_message)
 
-        messages: list[Message] = (
-            db.session.query(Message)
-            .filter(
-                Message.conversation_id == self.message.conversation_id,
+        messages = (
+            (
+                db.session.execute(
+                    select(Message)
+                    .where(Message.conversation_id == self.message.conversation_id)
+                    .order_by(Message.created_at.desc())
+                )
             )
-            .order_by(Message.created_at.desc())
+            .scalars()
             .all()
         )
 
@@ -483,7 +496,7 @@ class BaseAgentRunner(AppRunner):
         return result
 
     def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
-        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
+        files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
         if not files:
             return UserPromptMessage(content=message.query)
         if message.app_model_config:
diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index 143a3a51aa..a31c1050bd 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
     tool_name: str
     tool_parameters: dict[str, Any] = Field(default_factory=dict)
     plugin_unique_identifier: str | None = None
+    credential_id: str | None = None
 
 
 class AgentPromptEntity(BaseModel):
diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py
index 9c722baa23..a3438fc2c7 100644
--- a/api/core/agent/plugin_entities.py
+++ b/api/core/agent/plugin_entities.py
@@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter):
         APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
         MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
         TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
+        ANY = CommonParameterType.ANY.value
 
         # deprecated, should not use.
         SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
@@ -85,7 +86,7 @@ class AgentStrategyEntity(BaseModel):
     description: I18nObject = Field(..., description="The description of the agent strategy")
     output_schema: Optional[dict] = None
     features: Optional[list[AgentFeature]] = None
-
+    meta_version: Optional[str] = None
 
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py
index ead81a7a0e..a52a1dfd7a 100644
--- a/api/core/agent/strategy/base.py
+++ b/api/core/agent/strategy/base.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
 
 from core.agent.entities import AgentInvokeMessage
 from core.agent.plugin_entities import AgentStrategyParameter
+from core.plugin.entities.request import InvokeCredentials
 
 
 class BaseAgentStrategy(ABC):
@@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
         conversation_id: Optional[str] = None,
         app_id: Optional[str] = None,
         message_id: Optional[str] = None,
+        credentials: Optional[InvokeCredentials] = None,
     ) -> Generator[AgentInvokeMessage, None, None]:
         """
         Invoke the agent strategy.
         """
-        yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
+        yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
 
     def get_parameters(self) -> Sequence[AgentStrategyParameter]:
         """
@@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
         conversation_id: Optional[str] = None,
         app_id: Optional[str] = None,
         message_id: Optional[str] = None,
+        credentials: Optional[InvokeCredentials] = None,
     ) -> Generator[AgentInvokeMessage, None, None]:
         pass
diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py
index 79b074cf95..04661581a7 100644
--- a/api/core/agent/strategy/plugin.py
+++ b/api/core/agent/strategy/plugin.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
 
 from core.agent.entities import AgentInvokeMessage
 from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
 from core.agent.strategy.base import BaseAgentStrategy
+from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
 from core.plugin.impl.agent import PluginAgentClient
 from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -15,10 +16,12 @@ class PluginAgentStrategy(BaseAgentStrategy):
 
     tenant_id: str
     declaration: AgentStrategyEntity
+    meta_version: str | None = None
 
-    def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
+    def __init__(self, tenant_id: str, declaration: AgentStrategyEntity, meta_version: str | None):
         self.tenant_id = tenant_id
         self.declaration = declaration
+        self.meta_version = meta_version
 
     def get_parameters(self) -> Sequence[AgentStrategyParameter]:
         return self.declaration.parameters
@@ -38,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
         conversation_id: Optional[str] = None,
         app_id: Optional[str] = None,
         message_id: Optional[str] = None,
+        credentials: Optional[InvokeCredentials] = None,
     ) -> Generator[AgentInvokeMessage, None, None]:
         """
         Invoke the agent strategy.
@@ -56,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy):
             conversation_id=conversation_id,
             app_id=app_id,
             message_id=message_id,
+            context=PluginInvokeContext(credentials=credentials or InvokeCredentials()),
         )
diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
index 590b944c0d..8887d2500c 100644
--- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
@@ -39,6 +39,7 @@ class AgentConfigManager:
                         "provider_id": tool["provider_id"],
                         "tool_name": tool["tool_name"],
                         "tool_parameters": tool.get("tool_parameters", {}),
+                        "credential_id": tool.get("credential_id", None),
                     }
 
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))
diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
deleted file mode 100644
index 7a57bb3658..0000000000
--- a/api/core/app/apps/README.md
+++ /dev/null
@@ -1,48 +0,0 @@
-## Guidelines for Database Connection Management in App Runner and Task Pipeline
-
-Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
-
-Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
-
-Examples:
-
-1. Creating a new record:
-
-   ```python
-   app = App(id=1)
-   db.session.add(app)
-   db.session.commit()
-   db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close
-
-   # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
-
-   db.session.close()
-
-   return app.id
-   ```
-
-2. Fetching a record from the table:
-
-   ```python
-   app = db.session.query(App).filter(App.id == app_id).first()
-
-   created_at = app.created_at
-
-   db.session.close()
-
-   # Handle tasks (include long-running).
-
-   ```
-
-3. Updating a table field:
-
-   ```python
-   app = db.session.query(App).filter(App.id == app_id).first()
-
-   app.updated_at = time.utcnow()
-   db.session.commit()
-   db.session.close()
-
-   return app_id
-   ```
-
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 7877408cef..610a5bb278 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
 
 import contexts
 from configs import dify_config
@@ -17,16 +18,17 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
 from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
 from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
+from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
 from core.workflow.repositories.draft_variable_repository import (
     DraftVariableSaverFactory,
 )
@@ -112,7 +114,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             query = query.replace("\x00", "")
         inputs = args["inputs"]
 
-        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
+        extras = {
+            "auto_generate_conversation_name": args.get("auto_generate_name", False),
+
**extract_external_trace_id_from_args(args), + } # get conversation conversation = None @@ -183,14 +188,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -260,14 +265,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -343,14 +348,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -482,21 +487,52 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) - # chatbot app - runner = AdvancedChatAppRunner( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - dialogue_count=self._dialogue_count, - variable_loader=variable_loader, + with Session(db.engine, expire_on_commit=False) as session: + workflow = session.scalar( + select(Workflow).where( + Workflow.tenant_id == 
application_generate_entity.app_config.tenant_id, + Workflow.app_id == application_generate_entity.app_config.app_id, + Workflow.id == application_generate_entity.app_config.workflow_id, + ) ) + if workflow is None: + raise ValueError("Workflow not found") + # Determine system_user_id based on invocation source + is_external_api_call = application_generate_entity.invoke_from in { + InvokeFrom.WEB_APP, + InvokeFrom.SERVICE_API, + } + + if is_external_api_call: + # For external API calls, use end user's session ID + end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) + system_user_id = end_user.session_id if end_user else "" + else: + # For internal calls, use the original user ID + system_user_id = application_generate_entity.user_id + + app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id)) + if app is None: + raise ValueError("App not found") + + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + dialogue_count=self._dialogue_count, + variable_loader=variable_loader, + workflow=workflow, + system_user_id=system_user_id, + app=app, + ) + + try: runner.run() except GenerateTaskStoppedError: pass diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 840a3c9d3b..a75e17af64 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -9,21 +9,29 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + AppGenerateEntity, + InvokeFrom, +) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent, ) +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.moderation.base import ModerationError +from core.moderation.input_moderation import InputModeration +from core.variables.variables import VariableUnion from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from models import Workflow from models.enums import UserFrom -from models.model import App, Conversation, EndUser, Message +from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) @@ -36,42 +44,38 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def __init__( self, + *, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, 
dialogue_count: int, variable_loader: VariableLoader, + workflow: Workflow, + system_user_id: str, + app: App, ) -> None: - super().__init__(queue_manager, variable_loader) + super().__init__( + queue_manager=queue_manager, + variable_loader=variable_loader, + app_id=application_generate_entity.app_config.app_id, + ) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message self._dialogue_count = dialogue_count - - def _get_app_id(self) -> str: - return self.application_generate_entity.app_config.app_id + self._workflow = workflow + self.system_user_id = system_user_id + self._app = app def run(self) -> None: app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) - if not workflow: - raise ValueError("Workflow not initialized") - - user_id = None - if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = self.application_generate_entity.user_id - workflow_callbacks: list[WorkflowCallback] = [] if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) @@ -79,14 +83,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if self.application_generate_entity.single_iteration_run: # if only single iteration run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), ) elif self.application_generate_entity.single_loop_run: # if only single loop run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), ) @@ -97,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # moderation if self.handle_input_moderation( - app_record=app_record, + app_record=self._app, app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, @@ -107,7 +111,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # annotation reply if self.handle_annotation_reply( - app_record=app_record, + app_record=self._app, message=self.message, query=query, app_generate_entity=self.application_generate_entity, @@ -127,7 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) - for variable in workflow.conversation_variables + for variable in self._workflow.conversation_variables ] session.add_all(db_conversation_variables) # Convert database entities to variables. @@ -136,38 +140,40 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() # Create a variable pool. 
- system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: self.conversation.id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, - } + system_inputs = SystemVariable( + query=query, + files=files, + conversation_id=self.conversation.id, + user_id=self.system_user_id, + dialogue_count=self._dialogue_count, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_run_id, + ) # init variable pool variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + environment_variables=self._workflow.environment_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), ) # init graph - graph = self._init_graph(graph_config=workflow.graph_dict) + graph = self._init_graph(graph_config=self._workflow.graph_dict) db.session.close() # RUN WORKFLOW workflow_entry = WorkflowEntry( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), + tenant_id=self._workflow.tenant_id, + app_id=self._workflow.app_id, + workflow_id=self._workflow.id, + workflow_type=WorkflowType.value_of(self._workflow.type), graph=graph, - graph_config=workflow.graph_dict, + graph_config=self._workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT @@ -238,3 +244,51 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._publish_event(QueueTextChunkEvent(text=text)) self._publish_event(QueueStopEvent(stopped_by=stopped_by)) + + def query_app_annotations_to_reply( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: + """ + Query app annotations to reply + :param app_record: app record + :param message: message + :param query: query + :param user_id: user id + :param invoke_from: invoke from + :return: + """ + annotation_reply_feature = AnnotationReplyFeature() + return annotation_reply_feature.query( + app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from + ) + + def moderation_for_inputs( + self, + *, + app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: Mapping[str, Any], + query: str | None = None, + message_id: str, + ) -> tuple[bool, Mapping[str, Any], str]: + """ + Process sensitive_word_avoidance. 
+ :param app_id: app id + :param tenant_id: tenant id + :param app_generate_entity: app generate entity + :param inputs: inputs + :param query: query + :param message_id: message id + :return: + """ + moderation_feature = InputModeration() + return moderation_feature.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_generate_entity.app_config, + inputs=dict(inputs), + query=query or "", + message_id=message_id, + trace_manager=app_generate_entity.trace_manager, + ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 4c52fc3e83..dc27076a4d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,6 +1,7 @@ import logging import time -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping +from contextlib import contextmanager from threading import Thread from typing import Any, Optional, Union @@ -15,6 +16,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueAgentLogEvent, QueueAnnotationReplyEvent, @@ -44,6 +46,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, + WorkflowQueueMessage, ) from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -52,6 +55,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + PingStreamResponse, StreamResponse, WorkflowTaskState, ) @@ -61,12 +65,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db @@ -116,16 +120,16 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, - }, + workflow_system_variables=SystemVariable( + query=message.query, + files=application_generate_entity.files, + 
conversation_id=conversation.id, + user_id=user_session_id, + dialogue_count=dialogue_count, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_run_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), @@ -162,7 +166,6 @@ class AdvancedChatAppGenerateTaskPipeline: Process generate task pipeline. :return: """ - # start generate conversation name thread self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query ) @@ -254,15 +257,12 @@ class AdvancedChatAppGenerateTaskPipeline: yield response start_listener_time = time.time() - # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: if not tts_publisher: break audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: - # release cpu - # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) continue if audio_trunk.status == "finish": @@ -276,403 +276,617 @@ class AdvancedChatAppGenerateTaskPipeline: if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + @contextmanager + def _database_session(self): + """Context manager for database sessions.""" + with Session(db.engine, expire_on_commit=False) as session: + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + + def _ensure_workflow_initialized(self) -> None: + """Fluent validation for workflow state.""" + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + """Fluent validation for graph runtime state.""" + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + return graph_runtime_state + + def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: + """Handle ping events.""" + yield self._base_task_pipeline._ping_stream_response() + + def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: + """Handle error events.""" + with self._database_session() as session: + err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) + yield self._base_task_pipeline._error_to_stream_response(err) + + def _handle_workflow_started_event( + self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle workflow started events.""" + # Override graph runtime state - this is a side effect but necessary + graph_runtime_state = event.graph_runtime_state + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() + self._workflow_run_id = workflow_execution.id_ + + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + + message.workflow_run_id = workflow_execution.id_ + workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_start_resp + + def 
_handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle node retry events.""" + self._ensure_workflow_initialized() + + with self._database_session() as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_retry_resp: + yield node_retry_resp + + def _handle_node_started_event( + self, event: QueueNodeStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node started events.""" + self._ensure_workflow_initialized() + + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + + node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_start_resp: + yield node_start_resp + + def _handle_node_succeeded_event( + self, event: QueueNodeSucceededEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node succeeded events.""" + # Record files if it's an answer node or end node + if event.node_type in [NodeType.ANSWER, NodeType.END]: + self._recorded_files.extend( + self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) + ) + + with self._database_session() as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + self._save_output_for_event(event, workflow_node_execution.id) + + if node_finish_resp: + yield node_finish_resp + + def _handle_node_failed_events( + self, + event: Union[ + QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent + ], + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle various node failure events.""" + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event) + + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if isinstance(event, QueueNodeExceptionEvent): + self._save_output_for_event(event, workflow_node_execution.id) + + if node_finish_resp: + yield node_finish_resp + + def _handle_text_chunk_event( + self, + event: QueueTextChunkEvent, + *, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle text chunk events.""" + delta_text = event.text + if delta_text is None: + return + + # Handle output moderation chunk + should_direct_answer = self._handle_output_moderation_chunk(delta_text) + if should_direct_answer: + return + + # Only publish tts message at text chunk streaming + if tts_publisher and queue_message: + 
tts_publisher.publish(queue_message) + + self._task_state.answer += delta_text + yield self._message_cycle_manager.message_to_stream_response( + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector + ) + + def _handle_parallel_branch_started_event( + self, event: QueueParallelBranchRunStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch started events.""" + self._ensure_workflow_initialized() + + parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_start_resp + + def _handle_parallel_branch_finished_events( + self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch finished events.""" + self._ensure_workflow_initialized() + + parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_finish_resp + + def _handle_iteration_start_event( + self, event: QueueIterationStartEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration start events.""" + self._ensure_workflow_initialized() + + iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_start_resp + + def _handle_iteration_next_event( + self, event: QueueIterationNextEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration next events.""" + self._ensure_workflow_initialized() + + iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_next_resp + + def _handle_iteration_completed_event( + self, event: QueueIterationCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration completed events.""" + self._ensure_workflow_initialized() + + iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_finish_resp + + def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop start events.""" + self._ensure_workflow_initialized() + + loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_start_resp + + def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop next events.""" + self._ensure_workflow_initialized() + + loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_next_resp + + def 
_handle_loop_completed_event( + self, event: QueueLoopCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle loop completed events.""" + self._ensure_workflow_initialized() + + loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_finish_resp + + def _handle_workflow_succeeded_event( + self, + event: QueueWorkflowSucceededEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow succeeded events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_finish_resp + self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + + def _handle_workflow_partial_success_event( + self, + event: QueueWorkflowPartialSuccessEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow partial success events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_finish_resp + self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + + def _handle_workflow_failed_event( + self, + event: QueueWorkflowFailedEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow failed events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + 
total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + status=WorkflowExecutionStatus.FAILED, + error_message=event.error, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) + err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) + + yield workflow_finish_resp + yield self._base_task_pipeline._error_to_stream_response(err) + + def _handle_stop_event( + self, + event: QueueStopEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle stop events.""" + if self._workflow_run_id and graph_runtime_state: + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowExecutionStatus.STOPPED, + error_message=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + + yield workflow_finish_resp + elif event.stopped_by in ( + QueueStopEvent.StopBy.INPUT_MODERATION, + QueueStopEvent.StopBy.ANNOTATION_REPLY, + ): + # When hitting input-moderation or annotation-reply, the workflow will not start + with self._database_session() as session: + # Save message + self._save_message(session=session) + + yield self._message_end_to_stream_response() + + def _handle_advanced_chat_message_end_event( + self, + event: QueueAdvancedChatMessageEndEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle advanced chat message end events.""" + self._ensure_graph_runtime_initialized(graph_runtime_state) + + output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + self._task_state.answer + ) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=output_moderation_answer, + reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, + ) + + # Save message + with self._database_session() as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + + yield self._message_end_to_stream_response() + + def _handle_retriever_resources_event( + self, event: QueueRetrieverResourcesEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle retriever resources events.""" + 
self._message_cycle_manager.handle_retriever_resources(event) + + with self._database_session() as session: + message = self._get_message(session=session) + message.message_metadata = self._task_state.metadata.model_dump_json() + return + yield # Make this a generator + + def _handle_annotation_reply_event( + self, event: QueueAnnotationReplyEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle annotation reply events.""" + self._message_cycle_manager.handle_annotation_reply(event) + + with self._database_session() as session: + message = self._get_message(session=session) + message.message_metadata = self._task_state.metadata.model_dump_json() + return + yield # Make this a generator + + def _handle_message_replace_event( + self, event: QueueMessageReplaceEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle message replace events.""" + yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason) + + def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle agent log events.""" + yield self._workflow_response_converter.handle_agent_log( + task_id=self._application_generate_entity.task_id, event=event + ) + + def _get_event_handlers(self) -> dict[type, Callable]: + """Get mapping of event types to their handlers using fluent pattern.""" + return { + # Basic events + QueuePingEvent: self._handle_ping_event, + QueueErrorEvent: self._handle_error_event, + QueueTextChunkEvent: self._handle_text_chunk_event, + # Workflow events + QueueWorkflowStartedEvent: self._handle_workflow_started_event, + QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, + QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, + QueueWorkflowFailedEvent: self._handle_workflow_failed_event, + # Node events + QueueNodeRetryEvent: self._handle_node_retry_event, + QueueNodeStartedEvent: self._handle_node_started_event, + QueueNodeSucceededEvent: self._handle_node_succeeded_event, + # Parallel branch events + QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event, + # Iteration events + QueueIterationStartEvent: self._handle_iteration_start_event, + QueueIterationNextEvent: self._handle_iteration_next_event, + QueueIterationCompletedEvent: self._handle_iteration_completed_event, + # Loop events + QueueLoopStartEvent: self._handle_loop_start_event, + QueueLoopNextEvent: self._handle_loop_next_event, + QueueLoopCompletedEvent: self._handle_loop_completed_event, + # Control events + QueueStopEvent: self._handle_stop_event, + # Message events + QueueRetrieverResourcesEvent: self._handle_retriever_resources_event, + QueueAnnotationReplyEvent: self._handle_annotation_reply_event, + QueueMessageReplaceEvent: self._handle_message_replace_event, + QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event, + QueueAgentLogEvent: self._handle_agent_log_event, + } + + def _dispatch_event( + self, + event: Any, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + ) -> Generator[StreamResponse, None, None]: + """Dispatch events using elegant pattern matching.""" + handlers = self._get_event_handlers() + event_type = type(event) + + # Direct handler lookup + if handler := handlers.get(event_type): + 
yield from handler( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle node failure events with isinstance check + if isinstance( + event, + ( + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ), + ): + yield from self._handle_node_failed_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle parallel branch finished events with isinstance check + if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): + yield from self._handle_parallel_branch_finished_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # For unhandled events, we continue (original behavior) + return + def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ - Process stream response. - :return: + Process stream response using elegant Fluent Python patterns. + Maintains exact same functionality as original 57-if-statement version. """ - # init fake graph runtime state + # Initialize graph runtime state graph_runtime_state: Optional[GraphRuntimeState] = None for queue_message in self._base_task_pipeline._queue_manager.listen(): event = queue_message.event - if isinstance(event, QueuePingEvent): - yield self._base_task_pipeline._ping_stream_response() - elif isinstance(event, QueueErrorEvent): - with Session(db.engine, expire_on_commit=False) as session: - err = self._base_task_pipeline._handle_error( - event=event, session=session, message_id=self._message_id - ) - session.commit() - yield self._base_task_pipeline._error_to_stream_response(err) - break - elif isinstance(event, QueueWorkflowStartedEvent): - # override graph runtime state - graph_runtime_state = event.graph_runtime_state + match event: + case QueueWorkflowStartedEvent(): + graph_runtime_state = event.graph_runtime_state + yield from self._handle_workflow_started_event(event) - with Session(db.engine, expire_on_commit=False) as session: - # init workflow run - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() - self._workflow_run_id = workflow_execution.id_ - message = self._get_message(session=session) - if not message: - raise ValueError(f"Message not found: {self._message_id}") - message.workflow_run_id = workflow_execution.id_ - workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() - - yield workflow_start_resp - elif isinstance( - event, - QueueNodeRetryEvent, - ): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( - workflow_execution_id=self._workflow_run_id, event=event - ) - node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - 
session.commit() - - if node_retry_resp: - yield node_retry_resp - elif isinstance(event, QueueNodeStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=self._workflow_run_id, event=event - ) - - node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - - if node_start_resp: - yield node_start_resp - elif isinstance(event, QueueNodeSucceededEvent): - # Record files if it's an answer node or end node - if event.node_type in [NodeType.ANSWER, NodeType.END]: - self._recorded_files.extend( - self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) + case QueueTextChunkEvent(): + yield from self._handle_text_chunk_event( + event, tts_publisher=tts_publisher, queue_message=queue_message ) - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( - event=event + case QueueErrorEvent(): + yield from self._handle_error_event(event) + break + + case QueueWorkflowFailedEvent(): + yield from self._handle_workflow_failed_event( + event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager ) + break - node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, + case QueueStopEvent(): + yield from self._handle_stop_event( + event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager ) - session.commit() - self._save_output_for_event(event, workflow_node_execution.id) + break - if node_finish_resp: - yield node_finish_resp - elif isinstance( - event, - QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - ): - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( - event=event - ) - - node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - if isinstance(event, QueueNodeExceptionEvent): - self._save_output_for_event(event, workflow_node_execution.id) - - if node_finish_resp: - yield node_finish_resp - elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - parallel_start_resp = ( - self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) - - yield parallel_start_resp - elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - parallel_finish_resp = ( - self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) - - yield parallel_finish_resp - elif isinstance(event, QueueIterationStartEvent): - if not 
self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield iter_start_resp - elif isinstance(event, QueueIterationNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield iter_next_resp - elif isinstance(event, QueueIterationCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield iter_finish_resp - elif isinstance(event, QueueLoopStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield loop_start_resp - elif isinstance(event, QueueLoopNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield loop_next_resp - elif isinstance(event, QueueLoopCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield loop_finish_resp - elif isinstance(event, QueueWorkflowSucceededEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("workflow run not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - ) - - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - - yield workflow_finish_resp - self._base_task_pipeline._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = 
self._workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - - yield workflow_finish_resp - self._base_task_pipeline._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueWorkflowFailedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.FAILED, - error_message=event.error, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) - err = self._base_task_pipeline._handle_error( - event=err_event, session=session, message_id=self._message_id - ) - - yield workflow_finish_resp - yield self._base_task_pipeline._error_to_stream_response(err) - break - elif isinstance(event, QueueStopEvent): - if self._workflow_run_id and graph_runtime_state: - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.STOPPED, - error_message=event.get_stop_reason(), - conversation_id=self._conversation_id, + # Handle all other events through elegant dispatch + case _: + if responses := list( + self._dispatch_event( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, trace_manager=trace_manager, + queue_message=queue_message, ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - # Save message - self._save_message(session=session, graph_runtime_state=graph_runtime_state) - session.commit() + ): + yield from responses - yield workflow_finish_resp - elif event.stopped_by in ( - QueueStopEvent.StopBy.INPUT_MODERATION, - QueueStopEvent.StopBy.ANNOTATION_REPLY, - ): - # When hitting input-moderation or annotation-reply, the workflow will not start - with Session(db.engine, expire_on_commit=False) as session: - # Save message - self._save_message(session=session) - session.commit() - - yield self._message_end_to_stream_response() - break - elif isinstance(event, QueueRetrieverResourcesEvent): - 
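The hunks above swap the pipeline's long isinstance()/elif chain for Python 3.10 match/case on the hot-path events, with everything else falling through to a handler lookup. A minimal, runnable sketch of that dispatch shape (the event classes and handler bodies here are illustrative stand-ins, not Dify's):

```python
from dataclasses import dataclass


@dataclass
class QueuePingEvent:
    pass


@dataclass
class QueueTextChunkEvent:
    text: str


@dataclass
class QueueErrorEvent:
    error: Exception


def process(events):
    # hot-path events are matched directly; anything else goes through a handler table
    handlers = {QueuePingEvent: lambda e: "ping"}
    for event in events:
        match event:
            case QueueTextChunkEvent():
                yield event.text
            case QueueErrorEvent():
                yield f"error: {event.error}"
                break
            case _:
                if handler := handlers.get(type(event)):
                    yield handler(event)


print(list(process([QueuePingEvent(), QueueTextChunkEvent(text="hi"), QueueErrorEvent(error=ValueError("x"))])))
```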
self._message_cycle_manager.handle_retriever_resources(event) - - with Session(db.engine, expire_on_commit=False) as session: - message = self._get_message(session=session) - message.message_metadata = self._task_state.metadata.model_dump_json() - session.commit() - elif isinstance(event, QueueAnnotationReplyEvent): - self._message_cycle_manager.handle_annotation_reply(event) - - with Session(db.engine, expire_on_commit=False) as session: - message = self._get_message(session=session) - message.message_metadata = self._task_state.metadata.model_dump_json() - session.commit() - elif isinstance(event, QueueTextChunkEvent): - delta_text = event.text - if delta_text is None: - continue - - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) - if should_direct_answer: - continue - - # only publish tts message at text chunk streaming - if tts_publisher: - tts_publisher.publish(queue_message) - - self._task_state.answer += delta_text - yield self._message_cycle_manager.message_to_stream_response( - answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector - ) - elif isinstance(event, QueueMessageReplaceEvent): - # published by moderation - yield self._message_cycle_manager.message_replace_to_stream_response( - answer=event.text, reason=event.reason - ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( - self._task_state.answer - ) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_cycle_manager.message_replace_to_stream_response( - answer=output_moderation_answer, - reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, - ) - - # Save message - with Session(db.engine, expire_on_commit=False) as session: - self._save_message(session=session, graph_runtime_state=graph_runtime_state) - session.commit() - - yield self._message_end_to_stream_response() - elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_response_converter.handle_agent_log( - task_id=self._application_generate_entity.task_id, event=event - ) - else: - continue - - # publish None when task finished if tts_publisher: tts_publisher.publish(None) @@ -744,7 +958,6 @@ class AdvancedChatAppGenerateTaskPipeline: """ if self._base_task_pipeline._output_moderation_handler: if self._base_task_pipeline._output_moderation_handler.should_direct_output(): - # stop subscribe new token when output moderation should direct output self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() self._base_task_pipeline._queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index edea6199d3..8665bc9d11 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from 
core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 71328f6d1b..39d6ba39f5 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(AgentChatAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") @@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() if conversation_result is None: raise ValueError("Conversation not found") - message_result = db.session.query(Message).filter(Message.id == message.id).first() + message_result = db.session.query(Message).where(Message.id == message.id).first() if message_result is None: raise ValueError("Message not found") db.session.close() diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 0ba33fbe0d..9da0bae56a 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -169,7 +169,3 @@ class AppQueueManager: raise TypeError( "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed." 
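The repository-wide .filter() → .where() change in these hunks is a rename rather than a behavior change: since SQLAlchemy 1.4, Query.where() is a synonym for Query.filter() and mirrors the 2.0-style select() API. A small self-contained illustration (the App model and engine below are toy stand-ins):

```python
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class App(Base):
    __tablename__ = "apps"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String(50))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(App(id=1, name="demo"))
    session.commit()
    # legacy Query API: .where() behaves exactly like .filter()
    app_record = session.query(App).where(App.id == 1).first()
    # 2.0-style equivalent with select()
    app_record_2 = session.scalar(select(App).where(App.id == 1))
    assert app_record is not None and app_record_2 is not None
```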
) - - -class GenerateTaskStoppedError(Exception): - pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index a3f0cf7f9f..6e8c261a6a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -38,69 +38,6 @@ _logger = logging.getLogger(__name__) class AppRunner: - def get_pre_calculate_rest_tokens( - self, - app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: Mapping[str, str], - files: Sequence["File"], - query: Optional[str] = None, - ) -> int: - """ - Get pre calculate rest tokens - :param app_record: app record - :param model_config: model config entity - :param prompt_template_entity: prompt template entity - :param inputs: inputs - :param files: files - :param query: query - :return: - """ - # Invoke model - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - - max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") - ) or 0 - - if model_context_tokens is None: - return -1 - - if max_tokens is None: - max_tokens = 0 - - # get prompt messages without memory and context - prompt_messages, stop = self.organize_prompt_messages( - app_record=app_record, - model_config=model_config, - prompt_template_entity=prompt_template_entity, - inputs=inputs, - files=files, - query=query, - ) - - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens - if rest_tokens < 0: - raise InvokeBadRequestError( - "Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size." 
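The get_pre_calculate_rest_tokens() helper removed above reduces to a small context-budget calculation; a stripped-down sketch of that arithmetic (prompt assembly and model calls omitted, numbers illustrative):

```python
def pre_calculate_rest_tokens(context_size: int | None, max_tokens: int, prompt_tokens: int) -> int:
    # budget left for generation after reserving max_tokens and counting the prompt
    if context_size is None:
        return -1
    rest = context_size - max_tokens - prompt_tokens
    if rest < 0:
        raise ValueError(
            "Query or prefix prompt is too long: reduce the prompt, shrink max_tokens, "
            "or switch to a model with a larger context window."
        )
    return rest


assert pre_calculate_rest_tokens(8192, 1024, 3000) == 4168
assert pre_calculate_rest_tokens(None, 1024, 3000) == -1
```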
- ) - - return rest_tokens - def recalc_llm_max_tokens( self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] ): @@ -181,7 +118,7 @@ class AppRunner: else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a28c106ce9..0c76cc39ae 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -11,10 +11,11 @@ from configs import dify_config from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 39597fc036..894d7906d5 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(ChatAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 966a6f1d66..9356bd1cea 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -10,10 +10,11 @@ from pydantic import ValidationError from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from 
core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom @@ -247,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): """ message = ( db.session.query(Message) - .filter( + .where( Message.id == message_id, Message.app_id == app_model.id, Message.from_source == ("api" if isinstance(user, EndUser) else "console"), diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 80fdd0b80e..50d2a0036c 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(CompletionAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/exc.py b/api/core/app/apps/exc.py new file mode 100644 index 0000000000..4187118b9b --- /dev/null +++ b/api/core/app/apps/exc.py @@ -0,0 +1,2 @@ +class GenerateTaskStoppedError(Exception): + pass diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index e84d59209d..7dd9904eeb 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,12 +1,12 @@ import json import logging from collections.abc import Generator -from datetime import UTC, datetime from typing import Optional, Union, cast from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -24,6 +24,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile @@ -84,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): if conversation: app_model_config = ( db.session.query(AppModelConfig) - .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .first() ) @@ -150,13 +151,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): introduction = self._get_conversation_introduction(application_generate_entity) # get conversation name - if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): - query = application_generate_entity.query or "New conversation" - else: - query = next(iter(application_generate_entity.inputs.values()), "New conversation") - if isinstance(query, int): - query = str(query) - query = query or "New conversation" + query = application_generate_entity.query or "New conversation" conversation_name = (query[:20] + "…") if 
len(query) > 20 else query if not conversation: @@ -183,7 +178,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.commit() db.session.refresh(conversation) else: - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() message = Message( @@ -258,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param conversation_id: conversation id :return: conversation """ - conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError("Conversation not exists") @@ -271,7 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param message_id: message id :return: message """ - message = db.session.query(Message).filter(Message.id == message_id).first() + message = db.session.query(Message).where(Message.id == message_id).first() if message is None: raise MessageNotExistsError("Message not exists") diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 363c3c82bb..8507f23f17 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,5 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 40a1e272a7..4c36f63c71 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -7,13 +7,15 @@ from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy.orm import sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -21,10 +23,10 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import 
SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -123,6 +125,10 @@ class WorkflowAppGenerator(BaseAppGenerator): ) inputs: Mapping[str, Any] = args["inputs"] + + extras = { + **extract_external_trace_id_from_args(args), + } workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( @@ -142,6 +148,7 @@ class WorkflowAppGenerator(BaseAppGenerator): call_depth=call_depth, trace_manager=trace_manager, workflow_execution_id=workflow_run_id, + extras=extras, ) contexts.plugin_tool_providers.set({}) @@ -156,14 +163,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -306,16 +313,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -390,16 +395,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( 
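These hunks route repository construction through DifyCoreRepositoryFactory instead of naming SQLAlchemyWorkflowExecutionRepository / SQLAlchemyWorkflowNodeExecutionRepository directly. A minimal sketch of that factory indirection (the Protocol, constructor arguments, and factory internals below are assumptions for illustration, not Dify's actual signatures):

```python
from dataclasses import dataclass
from typing import Protocol


class WorkflowExecutionRepository(Protocol):
    def save(self, execution_id: str) -> None: ...


@dataclass
class SQLAlchemyWorkflowExecutionRepository:
    session_factory: object
    app_id: str

    def save(self, execution_id: str) -> None:
        print(f"saving workflow run {execution_id} for app {self.app_id}")


class CoreRepositoryFactory:
    """Callers name the factory, not the concrete SQLAlchemy class."""

    @staticmethod
    def create_workflow_execution_repository(*, session_factory: object, app_id: str) -> WorkflowExecutionRepository:
        # a real factory could pick the implementation from configuration
        return SQLAlchemyWorkflowExecutionRepository(session_factory=session_factory, app_id=app_id)


repo = CoreRepositoryFactory.create_workflow_execution_repository(session_factory=object(), app_id="app-1")
repo.save("run-1")
```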
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -443,17 +446,44 @@ class WorkflowAppGenerator(BaseAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): - try: - # workflow app - runner = WorkflowAppRunner( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - workflow_thread_pool_id=workflow_thread_pool_id, - variable_loader=variable_loader, + with Session(db.engine, expire_on_commit=False) as session: + workflow = session.scalar( + select(Workflow).where( + Workflow.tenant_id == application_generate_entity.app_config.tenant_id, + Workflow.app_id == application_generate_entity.app_config.app_id, + Workflow.id == application_generate_entity.app_config.workflow_id, + ) ) + if workflow is None: + raise ValueError("Workflow not found") + # Determine system_user_id based on invocation source + is_external_api_call = application_generate_entity.invoke_from in { + InvokeFrom.WEB_APP, + InvokeFrom.SERVICE_API, + } + + if is_external_api_call: + # For external API calls, use end user's session ID + end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) + system_user_id = end_user.session_id if end_user else "" + else: + # For internal calls, use the original user ID + system_user_id = application_generate_entity.user_id + + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, + workflow=workflow, + system_user_id=system_user_id, + ) + + try: runner.run() - except GenerateTaskStoppedError: + except GenerateTaskStoppedError as e: + logger.warning(f"Task stopped: {str(e)}") pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -469,8 +499,6 @@ class WorkflowAppGenerator(BaseAppGenerator): except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - db.session.close() def _handle_response( self, diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 349b8eb51b..40fc03afb7 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,5 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 07aeb57fa3..4f4c1460ae 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,13 +11,11 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry -from 
extensions.ext_database import db from models.enums import UserFrom -from models.model import App, EndUser -from models.workflow import WorkflowType +from models.workflow import Workflow, WorkflowType logger = logging.getLogger(__name__) @@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): def __init__( self, + *, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, + workflow: Workflow, + system_user_id: str, ) -> None: - """ - :param application_generate_entity: application generate entity - :param queue_manager: application queue manager - :param workflow_thread_pool_id: workflow thread pool id - """ - super().__init__(queue_manager, variable_loader) + super().__init__( + queue_manager=queue_manager, + variable_loader=variable_loader, + app_id=application_generate_entity.app_config.app_id, + ) self.application_generate_entity = application_generate_entity self.workflow_thread_pool_id = workflow_thread_pool_id - - def _get_app_id(self) -> str: - return self.application_generate_entity.app_config.app_id + self._workflow = workflow + self._sys_user_id = system_user_id def run(self) -> None: """ @@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) - user_id = None - if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = self.application_generate_entity.user_id - - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() - if not app_record: - raise ValueError("App not found") - - workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) - if not workflow: - raise ValueError("Workflow not initialized") - - db.session.close() - workflow_callbacks: list[WorkflowCallback] = [] if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) @@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): if self.application_generate_entity.single_iteration_run: # if only single iteration run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) elif self.application_generate_entity.single_loop_run: # if only single loop run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=self.application_generate_entity.single_loop_run.inputs, ) @@ -95,32 +76,33 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. 
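In the hunk that follows, the system_inputs dict keyed by SystemVariableKey is replaced with a typed SystemVariable object. A minimal pydantic sketch of that shape (only the field names come from the diff; types and defaults are assumptions):

```python
from collections.abc import Sequence
from typing import Optional

from pydantic import BaseModel


class SystemVariable(BaseModel):
    files: Sequence[str] = ()
    user_id: Optional[str] = None
    app_id: Optional[str] = None
    workflow_id: Optional[str] = None
    workflow_execution_id: Optional[str] = None


system_inputs = SystemVariable(
    files=[],
    user_id="end-user-session-id",
    app_id="app-1",
    workflow_id="wf-1",
    workflow_execution_id="run-1",
)
print(system_inputs.model_dump())
```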
- system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, - } + + system_inputs = SystemVariable( + files=files, + user_id=self._sys_user_id, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, - environment_variables=workflow.environment_variables, + environment_variables=self._workflow.environment_variables, conversation_variables=[], ) # init graph - graph = self._init_graph(graph_config=workflow.graph_dict) + graph = self._init_graph(graph_config=self._workflow.graph_dict) # RUN WORKFLOW workflow_entry = WorkflowEntry( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), + tenant_id=self._workflow.tenant_id, + app_id=self._workflow.app_id, + workflow_id=self._workflow.id, + workflow_type=WorkflowType.value_of(self._workflow.type), graph=graph, - graph_config=workflow.graph_dict, + graph_config=self._workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2a85cd5e3d..e31a316c56 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,9 +1,9 @@ import logging import time -from collections.abc import Generator -from typing import Optional, Union +from collections.abc import Callable, Generator +from contextlib import contextmanager +from typing import Any, Optional, Union -from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import ( WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, QueueIterationCompletedEvent, @@ -39,11 +40,13 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, + WorkflowQueueMessage, ) from core.app.entities.task_entities import ( ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, + PingStreamResponse, StreamResponse, TextChunkStreamResponse, WorkflowAppBlockingResponse, @@ -55,10 +58,11 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import 
SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models.account import Account @@ -68,7 +72,6 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowRun, ) logger = logging.getLogger(__name__) @@ -109,13 +112,13 @@ class WorkflowAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, - }, + workflow_system_variables=SystemVariable( + files=application_generate_entity.files, + user_id=user_session_id, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_execution_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), @@ -248,322 +251,500 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + @contextmanager + def _database_session(self): + """Context manager for database sessions.""" + with Session(db.engine, expire_on_commit=False) as session: + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + + def _ensure_workflow_initialized(self) -> None: + """Fluent validation for workflow state.""" + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + """Fluent validation for graph runtime state.""" + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + return graph_runtime_state + + def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: + """Handle ping events.""" + yield self._base_task_pipeline._ping_stream_response() + + def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: + """Handle error events.""" + err = self._base_task_pipeline._handle_error(event=event) + yield self._base_task_pipeline._error_to_stream_response(err) + + def _handle_workflow_started_event( + self, event: QueueWorkflowStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle workflow started events.""" + # init workflow run + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() + self._workflow_run_id = workflow_execution.id_ + start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + yield start_resp + + def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle node retry events.""" + self._ensure_workflow_initialized() + + with self._database_session() as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, + event=event, + ) + response = 
self._workflow_response_converter.workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + + def _handle_node_started_event( + self, event: QueueNodeStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node started events.""" + self._ensure_workflow_initialized() + + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_start_response: + yield node_start_response + + def _handle_node_succeeded_event( + self, event: QueueNodeSucceededEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node succeeded events.""" + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) + node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + self._save_output_for_event(event, workflow_node_execution.id) + + if node_success_response: + yield node_success_response + + def _handle_node_failed_events( + self, + event: Union[ + QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent + ], + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle various node failure events.""" + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( + event=event, + ) + node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if isinstance(event, QueueNodeExceptionEvent): + self._save_output_for_event(event, workflow_node_execution.id) + + if node_failed_response: + yield node_failed_response + + def _handle_parallel_branch_started_event( + self, event: QueueParallelBranchRunStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch started events.""" + self._ensure_workflow_initialized() + + parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_start_resp + + def _handle_parallel_branch_finished_events( + self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch finished events.""" + self._ensure_workflow_initialized() + + parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_finish_resp + + def _handle_iteration_start_event( + self, event: QueueIterationStartEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration start events.""" + self._ensure_workflow_initialized() + + 
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_start_resp + + def _handle_iteration_next_event( + self, event: QueueIterationNextEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration next events.""" + self._ensure_workflow_initialized() + + iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_next_resp + + def _handle_iteration_completed_event( + self, event: QueueIterationCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration completed events.""" + self._ensure_workflow_initialized() + + iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_finish_resp + + def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop start events.""" + self._ensure_workflow_initialized() + + loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_start_resp + + def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop next events.""" + self._ensure_workflow_initialized() + + loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_next_resp + + def _handle_loop_completed_event( + self, event: QueueLoopCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle loop completed events.""" + self._ensure_workflow_initialized() + + loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_finish_resp + + def _handle_workflow_succeeded_event( + self, + event: QueueWorkflowSucceededEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow succeeded events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, 
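Several of the new handlers here wrap their persistence in the _database_session() helper, which commits on clean exit and rolls back on error. A self-contained sketch of that commit-or-rollback context manager (the in-memory engine and table are illustrative; the pipeline itself uses db.engine):

```python
from contextlib import contextmanager

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")


@contextmanager
def database_session():
    # commit if the block completes, roll back and re-raise if it throws
    with Session(engine, expire_on_commit=False) as session:
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise


with database_session() as session:
    session.execute(text("CREATE TABLE IF NOT EXISTS app_logs (msg TEXT)"))
    session.execute(text("INSERT INTO app_logs VALUES ('workflow finished')"))
# committed automatically on clean exit from the with-block
```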
+ task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_finish_resp + + def _handle_workflow_partial_success_event( + self, + event: QueueWorkflowPartialSuccessEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow partial success events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_finish_resp + + def _handle_workflow_failed_and_stop_events( + self, + event: Union[QueueWorkflowFailedEvent, QueueStopEvent], + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow failed and stop events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + status=WorkflowExecutionStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowExecutionStatus.STOPPED, + error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_finish_resp + + def _handle_text_chunk_event( + self, + event: QueueTextChunkEvent, + *, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle text chunk events.""" + delta_text = event.text + if delta_text is None: + return + + # only publish tts message at text chunk streaming + if tts_publisher and queue_message: + tts_publisher.publish(queue_message) + + yield self._text_chunk_to_stream_response(delta_text, 
from_variable_selector=event.from_variable_selector) + + def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle agent log events.""" + yield self._workflow_response_converter.handle_agent_log( + task_id=self._application_generate_entity.task_id, event=event + ) + + def _get_event_handlers(self) -> dict[type, Callable]: + """Get mapping of event types to their handlers using fluent pattern.""" + return { + # Basic events + QueuePingEvent: self._handle_ping_event, + QueueErrorEvent: self._handle_error_event, + QueueTextChunkEvent: self._handle_text_chunk_event, + # Workflow events + QueueWorkflowStartedEvent: self._handle_workflow_started_event, + QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, + QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, + # Node events + QueueNodeRetryEvent: self._handle_node_retry_event, + QueueNodeStartedEvent: self._handle_node_started_event, + QueueNodeSucceededEvent: self._handle_node_succeeded_event, + # Parallel branch events + QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event, + # Iteration events + QueueIterationStartEvent: self._handle_iteration_start_event, + QueueIterationNextEvent: self._handle_iteration_next_event, + QueueIterationCompletedEvent: self._handle_iteration_completed_event, + # Loop events + QueueLoopStartEvent: self._handle_loop_start_event, + QueueLoopNextEvent: self._handle_loop_next_event, + QueueLoopCompletedEvent: self._handle_loop_completed_event, + # Agent events + QueueAgentLogEvent: self._handle_agent_log_event, + } + + def _dispatch_event( + self, + event: Any, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + ) -> Generator[StreamResponse, None, None]: + """Dispatch events using elegant pattern matching.""" + handlers = self._get_event_handlers() + event_type = type(event) + + # Direct handler lookup + if handler := handlers.get(event_type): + yield from handler( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle node failure events with isinstance check + if isinstance( + event, + ( + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ), + ): + yield from self._handle_node_failed_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle parallel branch finished events with isinstance check + if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): + yield from self._handle_parallel_branch_finished_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle workflow failed and stop events with isinstance check + if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)): + yield from self._handle_workflow_failed_and_stop_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # For unhandled events, 
we continue (original behavior) + return + def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ - Process stream response. - :return: + Process stream response using elegant Fluent Python patterns. + Maintains exact same functionality as original 44-if-statement version. """ + # Initialize graph runtime state graph_runtime_state = None for queue_message in self._base_task_pipeline._queue_manager.listen(): event = queue_message.event - if isinstance(event, QueuePingEvent): - yield self._base_task_pipeline._ping_stream_response() - elif isinstance(event, QueueErrorEvent): - err = self._base_task_pipeline._handle_error(event=event) - yield self._base_task_pipeline._error_to_stream_response(err) - break - elif isinstance(event, QueueWorkflowStartedEvent): - # override graph runtime state - graph_runtime_state = event.graph_runtime_state + match event: + case QueueWorkflowStartedEvent(): + graph_runtime_state = event.graph_runtime_state + yield from self._handle_workflow_started_event(event) - # init workflow run - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() - self._workflow_run_id = workflow_execution.id_ - start_resp = self._workflow_response_converter.workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - - yield start_resp - elif isinstance( - event, - QueueNodeRetryEvent, - ): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( - workflow_execution_id=self._workflow_run_id, - event=event, - ) - response = self._workflow_response_converter.workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() - - if response: - yield response - elif isinstance(event, QueueNodeStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=self._workflow_run_id, event=event - ) - node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - - if node_start_response: - yield node_start_response - elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( - event=event - ) - node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - - self._save_output_for_event(event, workflow_node_execution.id) - - if node_success_response: - yield node_success_response - elif isinstance( - event, - QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - ): - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( - event=event, - ) - node_failed_response = 
self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - if isinstance(event, QueueNodeExceptionEvent): - self._save_output_for_event(event, workflow_node_execution.id) - - if node_failed_response: - yield node_failed_response - - elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - parallel_start_resp = ( - self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) - - yield parallel_start_resp - - elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - parallel_finish_resp = ( - self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) - - yield parallel_finish_resp - - elif isinstance(event, QueueIterationStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield iter_start_resp - - elif isinstance(event, QueueIterationNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield iter_next_resp - - elif isinstance(event, QueueIterationCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield iter_finish_resp - - elif isinstance(event, QueueLoopStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield loop_start_resp - - elif isinstance(event, QueueLoopNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield loop_next_resp - - elif isinstance(event, QueueLoopCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - - yield loop_finish_resp - - elif 
isinstance(event, QueueWorkflowSucceededEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, + case QueueTextChunkEvent(): + yield from self._handle_text_chunk_event( + event, tts_publisher=tts_publisher, queue_message=queue_message ) - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + case QueueErrorEvent(): + yield from self._handle_error_event(event) + break - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() - - yield workflow_finish_resp - elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() - - yield workflow_finish_resp - elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowExecutionStatus.STOPPED, - error_message=event.error - if isinstance(event, QueueWorkflowFailedEvent) - else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) - - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() - - yield workflow_finish_resp - elif isinstance(event, 
QueueTextChunkEvent): - delta_text = event.text - if delta_text is None: - continue - - # only publish tts message at text chunk streaming - if tts_publisher: - tts_publisher.publish(queue_message) - - yield self._text_chunk_to_stream_response( - delta_text, from_variable_selector=event.from_variable_selector - ) - elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_response_converter.handle_agent_log( - task_id=self._application_generate_entity.task_id, event=event - ) - else: - continue + # Handle all other events through elegant dispatch + case _: + if responses := list( + self._dispatch_event( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + ): + yield from responses if tts_publisher: tts_publisher.publish(None) def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: - workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) - assert workflow_run is not None invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API @@ -576,10 +757,10 @@ class WorkflowAppGenerateTaskPipeline: return workflow_app_log = WorkflowAppLog() - workflow_app_log.tenant_id = workflow_run.tenant_id - workflow_app_log.app_id = workflow_run.app_id - workflow_app_log.workflow_id = workflow_run.workflow_id - workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id + workflow_app_log.app_id = self._application_generate_entity.app_config.app_id + workflow_app_log.workflow_id = workflow_execution.workflow_id + workflow_app_log.workflow_run_id = workflow_execution.id_ workflow_app_log.created_from = created_from.value workflow_app_log.created_by_role = self._created_by_role workflow_app_log.created_by = self._user_id diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 17b9ac5827..948ea95e63 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,8 +1,7 @@ from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, @@ -62,20 +61,23 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db -from models.model import App from models.workflow import Workflow -class WorkflowBasedAppRunner(AppRunner): - def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: - self.queue_manager = queue_manager +class WorkflowBasedAppRunner: + def __init__( + self, + *, + queue_manager: AppQueueManager, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + app_id: str, + ) -> None: + self._queue_manager = queue_manager 
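
The dispatch refactor above pairs a type-keyed handler registry (_get_event_handlers) with a match statement in the listen loop, replacing the long if/elif chain. A minimal, self-contained sketch of that dispatch shape, using illustrative event and handler names rather than the actual Dify queue classes:

from collections.abc import Callable, Generator, Iterable
from dataclasses import dataclass


@dataclass
class PingEvent:
    pass


@dataclass
class TextChunkEvent:
    text: str


@dataclass
class ErrorEvent:
    message: str


class StreamPipeline:
    def _handlers(self) -> dict[type, Callable]:
        # Direct type -> handler lookup; bound methods act as the handlers.
        return {
            PingEvent: self._on_ping,
            TextChunkEvent: self._on_text,
        }

    def _on_ping(self, event: PingEvent) -> Generator[str, None, None]:
        yield "ping"

    def _on_text(self, event: TextChunkEvent) -> Generator[str, None, None]:
        yield event.text

    def process(self, events: Iterable[object]) -> Generator[str, None, None]:
        for event in events:
            match event:
                case ErrorEvent():
                    # Errors terminate the stream, mirroring the break on QueueErrorEvent.
                    yield f"error: {event.message}"
                    break
                case _:
                    if handler := self._handlers().get(type(event)):
                        yield from handler(event)
                    # Unknown events are silently skipped.


print(list(StreamPipeline().process([PingEvent(), TextChunkEvent("hi"), ErrorEvent("boom")])))
# -> ['ping', 'hi', 'error: boom']

The direct dict lookup keeps the common path a single hash access, while the wildcard arm leaves room for isinstance-based fallbacks such as the node-failure, parallel-branch, and stop events handled above.
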
self._variable_loader = variable_loader - - def _get_app_id(self) -> str: - raise NotImplementedError("not implemented") + self._app_id = app_id def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: """ @@ -166,7 +168,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) @@ -263,7 +265,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) @@ -692,21 +694,5 @@ class WorkflowBasedAppRunner(AppRunner): ) ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow - def _publish_event(self, event: AppQueueEvent) -> None: - self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) + self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 83fd3debad..54dc69302a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -26,7 +26,7 @@ class AnnotationReplyFeature: :return: """ annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() ) if not annotation_setting: diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 3c8c7bb5a2..888434798a 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :return: """ agent_thought: Optional[MessageAgentThought] = ( - db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() + db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() ) if agent_thought: diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py index e4b4168d08..df62776977 100644 --- a/api/core/app/task_pipeline/exc.py +++ b/api/core/app/task_pipeline/exc.py @@ -10,8 +10,3 @@ class RecordNotFoundError(TaskPipilineError): class WorkflowRunNotFoundError(RecordNotFoundError): def __init__(self, workflow_run_id: str): super().__init__("WorkflowRun", workflow_run_id) - - -class WorkflowNodeExecutionNotFoundError(RecordNotFoundError): - def __init__(self, workflow_node_execution_id: str): - super().__init__("WorkflowNodeExecution", workflow_node_execution_id) diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2343081eaf..824da0b934 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -81,7 +81,7 @@ 
class MessageCycleManager: def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() if not conversation: return @@ -140,7 +140,7 @@ class MessageCycleManager: :param event: event :return: """ - message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() + message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first() if message_file and message_file.url is not None: # get tool file id diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index a3a7b4b812..c55ba5e0fe 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler: for document in documents: if document.metadata is not None: document_id = document.metadata["document_id"] - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: _logger.warning( "Expected DatasetDocument record to exist, but none was found, document_id=%s", @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunk = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.document_id == dataset_document.id, @@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler: if child_chunk: segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == child_chunk.segment_id) + .where(DocumentSegment.id == child_chunk.segment_id) .update( {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False ) ) else: - query = db.session.query(DocumentSegment).filter( + query = db.session.query(DocumentSegment).where( DocumentSegment.index_node_id == document.metadata["doc_id"] ) if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index b071bfa5b1..fbd62437e6 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -14,6 +14,7 @@ class CommonParameterType(StrEnum): APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" + ANY = "any" # Dynamic select parameter # Once you are not sure about the available options until authorization is done @@ -21,6 +22,9 @@ class CommonParameterType(StrEnum): DYNAMIC_SELECT = "dynamic-select" # TOOL_SELECTOR = "tool-selector" + # MCP object and array type parameters + ARRAY = "array" + OBJECT = "object" class AppSelectorScope(StrEnum): diff --git a/api/core/entities/provider_configuration.py 
b/api/core/entities/provider_configuration.py index 66d8d0f414..af5c18e267 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel): provider_record = ( db.session.query(Provider) - .filter( + .where( Provider.tenant_id == self.tenant_id, Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_name.in_(provider_names), @@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel): provider_model_record = ( db.session.query(ProviderModel) - .filter( + .where( ProviderModel.tenant_id == self.tenant_id, ProviderModel.provider_name.in_(provider_names), ProviderModel.model_name == model, @@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel): return ( db.session.query(ProviderModelSetting) - .filter( + .where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), @@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel): return ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), @@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel): load_balancing_config_count = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), @@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel): model_setting = ( db.session.query(ProviderModelSetting) - .filter( + .where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), @@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = ( db.session.query(TenantPreferredModelProvider) - .filter( + .where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(provider_names), ) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 53acdf075f..2099a9e34c 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool): # get api_based_extension api_based_extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) @@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool): # get api_based_extension api_based_extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index ada19ef8ce..f8c050c2ac 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -7,6 +7,7 @@ from core.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + 
TextPromptMessageContent, VideoPromptMessageContent, ) from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes @@ -44,11 +45,44 @@ def to_prompt_message_content( *, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> PromptMessageContentUnionTypes: + """ + Convert a file to prompt message content. + + This function converts files to their appropriate prompt message content types. + For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the + corresponding message content with proper encoding/URL. + + For unsupported file types, instead of raising an error, it returns a + TextPromptMessageContent with a descriptive message about the file. + + Args: + f: The file to convert + image_detail_config: Optional detail configuration for image files + + Returns: + PromptMessageContentUnionTypes: The appropriate message content type + + Raises: + ValueError: If file extension or mime_type is missing + """ if f.extension is None: raise ValueError("Missing file extension") if f.mime_type is None: raise ValueError("Missing file mime_type") + prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { + FileType.IMAGE: ImagePromptMessageContent, + FileType.AUDIO: AudioPromptMessageContent, + FileType.VIDEO: VideoPromptMessageContent, + FileType.DOCUMENT: DocumentPromptMessageContent, + } + + # Check if file type is supported + if f.type not in prompt_class_map: + # For unsupported file types, return a text description + return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") + + # Process supported file types params = { "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", @@ -58,17 +92,7 @@ def to_prompt_message_content( if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - try: - return prompt_class_map[f.type].model_validate(params) - except KeyError: - raise ValueError(f"file type {f.type} is not supported") + return prompt_class_map[f.type].model_validate(params) def download(f: File, /): diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 73fabdb11b..335ad2266a 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -21,7 +21,9 @@ def get_signed_file_url(upload_file_id: str) -> str: def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: - url = f"{dify_config.FILES_URL}/files/upload/for-plugin" + # Plugin access should use internal URL for Docker network communication + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + url = f"{base_url}/files/upload/for-plugin" if user_id is None: user_id = "DEFAULT-USER" diff --git a/api/core/file/models.py b/api/core/file/models.py index aa3b5f629c..f61334e7bc 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -51,7 +51,7 @@ class File(BaseModel): # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. 
related_id: Optional[str] = None filename: Optional[str] = None - extension: Optional[str] = Field(default=None, description="File extension, should contains dot") + extension: Optional[str] = Field(default=None, description="File extension, should contain dot") mime_type: Optional[str] = None size: int = -1 diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 656c9d48ed..fac68beb0f 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -7,13 +7,6 @@ if TYPE_CHECKING: _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None -class ToolFileParser: - @staticmethod - def get_tool_file_manager() -> "ToolFileManager": - assert _tool_file_manager_factory is not None - return _tool_file_manager_factory() - - def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: global _tool_file_manager_factory _tool_file_manager_factory = factory diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py deleted file mode 100644 index 96b2884811..0000000000 --- a/api/core/file/upload_file_parser.py +++ /dev/null @@ -1,67 +0,0 @@ -import base64 -import logging -import time -from typing import Optional - -from configs import dify_config -from constants import IMAGE_EXTENSIONS -from core.helper.url_signer import UrlSigner -from extensions.ext_storage import storage - - -class UploadFileParser: - @classmethod - def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: - if not upload_file: - return None - - if upload_file.extension not in IMAGE_EXTENSIONS: - return None - - if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url: - return cls.get_signed_temp_image_url(upload_file.id) - else: - # get image file base64 - try: - data = storage.load(upload_file.key) - except FileNotFoundError: - logging.exception(f"File not found: {upload_file.key}") - return None - - encoded_string = base64.b64encode(data).decode("utf-8") - return f"data:{upload_file.mime_type};base64,{encoded_string}" - - @classmethod - def get_signed_temp_image_url(cls, upload_file_id) -> str: - """ - get signed url from upload file - - :param upload_file_id: the id of UploadFile object - :return: - """ - base_url = dify_config.FILES_URL - image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" - - return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview") - - @classmethod - def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - - :param upload_file_id: file id - :param timestamp: timestamp - :param nonce: nonce - :param sign: signature - :return: - """ - result = UrlSigner.verify( - sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview" - ) - - # verify signature - if not result: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index baa792b5bc..b416e48ce4 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,6 +5,8 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any +from core.variables.utils import SegmentJSONEncoder + class TemplateTransformer(ABC): _code_placeholder: str = "{{code}}" @@ 
-28,7 +30,7 @@ class TemplateTransformer(ABC): def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: - raise ValueError("Failed to parse result") + raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...") return result.group(1) @classmethod @@ -38,16 +40,49 @@ class TemplateTransformer(ABC): :param response: response :return: """ + try: - result = json.loads(cls.extract_result_str_from_response(response)) - except json.JSONDecodeError: - raise ValueError("failed to parse response") + result_str = cls.extract_result_str_from_response(response) + result = json.loads(result_str) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse JSON response: {str(e)}.") + except ValueError as e: + # Re-raise ValueError from extract_result_str_from_response + raise e + except Exception as e: + raise ValueError(f"Unexpected error during response transformation: {str(e)}") + if not isinstance(result, dict): - raise ValueError("result must be a dict") + raise ValueError(f"Result must be a dict, got {type(result).__name__}") if not all(isinstance(k, str) for k in result): - raise ValueError("result keys must be strings") + raise ValueError("Result keys must be strings") + + # Post-process the result to convert scientific notation strings back to numbers + result = cls._post_process_result(result) return result + @classmethod + def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]: + """ + Post-process the result to convert scientific notation strings back to numbers + """ + + def convert_scientific_notation(value): + if isinstance(value, str): + # Check if the string looks like scientific notation + if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): + try: + return float(value) + except ValueError: + pass + elif isinstance(value, dict): + return {k: convert_scientific_notation(v) for k, v in value.items()} + elif isinstance(value, list): + return [convert_scientific_notation(v) for v in value] + return value + + return convert_scientific_notation(result) # type: ignore[no-any-return] + @classmethod @abstractmethod def get_runner_script(cls) -> str: @@ -58,7 +93,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() + inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 744fce1cf9..f761d20374 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -15,13 +15,13 @@ def encrypt_token(tenant_id: str, token: str): from models.account import Tenant from models.engine import db - if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): + if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() -def decrypt_token(tenant_id: str, token: str): +def decrypt_token(tenant_id: str, token: str) -> str: return rsa.decrypt(base64.b64decode(token), tenant_id) diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py deleted file mode 100644 
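
The _post_process_result helper added to TemplateTransformer above walks the decoded result and turns strings that look like scientific notation back into floats. A standalone sketch of that recursive conversion, using the same regex as the diff (note it only matches values with an explicit exponent sign, e.g. "1.5e+10" but not "1e10"):

import re


def convert_scientific_notation(value):
    """Recursively convert scientific-notation strings into floats."""
    if isinstance(value, str):
        if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
            try:
                return float(value)
            except ValueError:
                pass
    elif isinstance(value, dict):
        return {k: convert_scientific_notation(v) for k, v in value.items()}
    elif isinstance(value, list):
        return [convert_scientific_notation(v) for v in value]
    return value


print(convert_scientific_notation({"a": "1.5e+10", "b": ["2e-5", "not a number"], "c": 3}))
# -> {'a': 15000000000.0, 'b': [2e-05, 'not a number'], 'c': 3}

Non-string, non-container values (and strings that fail the pattern) pass through unchanged, so ordinary text results are left alone.
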
index 81501d2e4e..0000000000 --- a/api/core/helper/lru_cache.py +++ /dev/null @@ -1,22 +0,0 @@ -from collections import OrderedDict -from typing import Any - - -class LRUCache: - def __init__(self, capacity: int): - self.cache: OrderedDict[Any, Any] = OrderedDict() - self.capacity = capacity - - def get(self, key: Any) -> Any: - if key not in self.cache: - return None - else: - self.cache.move_to_end(key) # move the key to the end of the OrderedDict - return self.cache[key] - - def put(self, key: Any, value: Any) -> None: - if key in self.cache: - self.cache.move_to_end(key) - self.cache[key] = value - if len(self.cache) > self.capacity: - self.cache.popitem(last=False) # pop the first item diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index 65bf4fc1db..fe3078923d 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -25,9 +25,29 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP url = str(marketplace_api_url / "api/v1/plugins/batch") response = requests.post(url, json={"plugin_ids": plugin_ids}) response.raise_for_status() + return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] +def batch_fetch_plugin_manifests_ignore_deserialization_error( + plugin_ids: list[str], +) -> Sequence[MarketplacePluginDeclaration]: + if len(plugin_ids) == 0: + return [] + + url = str(marketplace_api_url / "api/v1/plugins/batch") + response = requests.post(url, json={"plugin_ids": plugin_ids}) + response.raise_for_status() + result: list[MarketplacePluginDeclaration] = [] + for plugin in response.json()["data"]["plugins"]: + try: + result.append(MarketplacePluginDeclaration(**plugin)) + except Exception as e: + pass + + return result + + def record_install_plugin_event(plugin_unique_identifier: str): url = str(marketplace_api_url / "api/v1/stats/plugins/install_count") response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py new file mode 100644 index 0000000000..48ec3be5c8 --- /dev/null +++ b/api/core/helper/provider_cache.py @@ -0,0 +1,84 @@ +import json +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Optional + +from extensions.ext_redis import redis_client + + +class ProviderCredentialsCache(ABC): + """Base class for provider credentials cache""" + + def __init__(self, **kwargs): + self.cache_key = self._generate_cache_key(**kwargs) + + @abstractmethod + def _generate_cache_key(self, **kwargs) -> str: + """Generate cache key based on subclass implementation""" + pass + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + cached_credentials = redis_client.get(self.cache_key) + if cached_credentials: + try: + cached_credentials = cached_credentials.decode("utf-8") + return dict(json.loads(cached_credentials)) + except JSONDecodeError: + return None + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + redis_client.setex(self.cache_key, 86400, json.dumps(config)) + + def delete(self) -> None: + """Delete cached provider credentials""" + redis_client.delete(self.cache_key) + + +class SingletonProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool single provider credentials""" + + def __init__(self, tenant_id: str, provider_type: str, provider_identity: str): + super().__init__( + tenant_id=tenant_id, + provider_type=provider_type, + 
provider_identity=provider_identity, + ) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider_type = kwargs["provider_type"] + identity_name = kwargs["provider_identity"] + identity_id = f"{provider_type}.{identity_name}" + return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + + +class ToolProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool provider credentials""" + + def __init__(self, tenant_id: str, provider: str, credential_id: str): + super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider = kwargs["provider"] + credential_id = kwargs["credential_id"] + return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + + +class NoOpProviderCredentialCache: + """No-op provider credential cache""" + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + pass + + def delete(self) -> None: + """Delete cached provider credentials""" + pass diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index 2e4a04c579..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from enum import Enum -from json import JSONDecodeError -from typing import Optional - -from extensions.ext_redis import redis_client - - -class ToolProviderCredentialsCacheType(Enum): - PROVIDER = "tool_provider" - ENDPOINT = "endpoint" - - -class ToolProviderCredentialsCache: - def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - - def get(self) -> Optional[dict]: - """ - Get cached model provider credentials. - - :return: - """ - cached_provider_credentials = redis_client.get(self.cache_key) - if cached_provider_credentials: - try: - cached_provider_credentials = cached_provider_credentials.decode("utf-8") - cached_provider_credentials = json.loads(cached_provider_credentials) - except JSONDecodeError: - return None - - return dict(cached_provider_credentials) - else: - return None - - def set(self, credentials: dict) -> None: - """ - Cache model provider credentials. - - :param credentials: provider credentials - :return: - """ - redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - - def delete(self) -> None: - """ - Delete cached model provider credentials. - - :return: - """ - redis_client.delete(self.cache_key) diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py new file mode 100644 index 0000000000..e90c3194f2 --- /dev/null +++ b/api/core/helper/trace_id_helper.py @@ -0,0 +1,42 @@ +import re +from collections.abc import Mapping +from typing import Any, Optional + + +def is_valid_trace_id(trace_id: str) -> bool: + """ + Check if the trace_id is valid. + + Requirements: 1-128 characters, only letters, numbers, '-', and '_'. + """ + return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id)) + + +def get_external_trace_id(request: Any) -> Optional[str]: + """ + Retrieve the trace_id from the request. + + Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid. 
+ """ + trace_id = request.headers.get("X-Trace-Id") + if not trace_id: + trace_id = request.args.get("trace_id") + if not trace_id and getattr(request, "is_json", False): + json_data = getattr(request, "json", None) + if json_data: + trace_id = json_data.get("trace_id") + if isinstance(trace_id, str) and is_valid_trace_id(trace_id): + return trace_id + return None + + +def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: + """ + Extract 'external_trace_id' from args. + + Returns a dict suitable for use in extras. Returns an empty dict if not found. + """ + trace_id = args.get("external_trace_id") + if trace_id: + return {"external_trace_id": trace_id} + return {} diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py deleted file mode 100644 index dfb143f4c4..0000000000 --- a/api/core/helper/url_signer.py +++ /dev/null @@ -1,52 +0,0 @@ -import base64 -import hashlib -import hmac -import os -import time - -from pydantic import BaseModel, Field - -from configs import dify_config - - -class SignedUrlParams(BaseModel): - sign_key: str = Field(..., description="The sign key") - timestamp: str = Field(..., description="Timestamp") - nonce: str = Field(..., description="Nonce") - sign: str = Field(..., description="Signature") - - -class UrlSigner: - @classmethod - def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: - signed_url_params = cls.get_signed_url_params(sign_key, prefix) - return ( - f"{url}?timestamp={signed_url_params.timestamp}" - f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" - ) - - @classmethod - def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - sign = cls._sign(sign_key, timestamp, nonce, prefix) - - return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) - - @classmethod - def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: - recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) - - return sign == recalculated_sign - - @classmethod - def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: - if not dify_config.SECRET_KEY: - raise Exception("SECRET_KEY is not set") - - data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return encoded_sign diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f2fe306179..fc5d0547fc 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -59,7 +59,7 @@ class IndexingRunner: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -119,12 +119,12 @@ class IndexingRunner: db.session.delete(document_segment) if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: # delete child chunks - db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() + db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == 
dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -212,7 +212,7 @@ class IndexingRunner: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) @@ -316,10 +316,11 @@ class IndexingRunner: # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: - if image_file: - storage.delete(image_file.key) + storage.delete(image_file.key) except Exception: logging.exception( "Delete image_files failed while indexing_estimate, \ @@ -345,7 +346,7 @@ class IndexingRunner: raise ValueError("no upload file found") file_detail = ( - db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() ) if file_detail: @@ -598,7 +599,7 @@ class IndexingRunner: keyword.create(documents) if dataset.indexing_technique != "high_quality": document_ids = [document.metadata["doc_id"] for document in documents] - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), @@ -629,7 +630,7 @@ class IndexingRunner: index_processor.load(dataset, chunk_documents, with_keywords=False) document_ids = [document.metadata["doc_id"] for document in chunk_documents] - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), @@ -671,8 +672,7 @@ class IndexingRunner: if extra_update_params: update_params.update(extra_update_params) - - db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) + db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore db.session.commit() @staticmethod diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e01896a491..331ac933c8 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -114,7 +114,8 @@ class LLMGenerator: ), ) - questions = output_parser.parse(cast(str, response.message.content)) + text_content = response.message.get_text_content() + questions = output_parser.parse(text_content) if text_content else [] except InvokeError: questions = [] except Exception: @@ -148,9 +149,11 @@ class LLMGenerator: model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( + model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index c451bf514c..98cdc4c8b7 
100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -15,5 +15,4 @@ class SuggestedQuestionsAfterAnswerOutputParser: json_obj = json.loads(action_match.group(0).strip()) else: json_obj = [] - return json_obj diff --git a/api/core/mcp/__init__.py b/api/core/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py new file mode 100644 index 0000000000..bcb31a816f --- /dev/null +++ b/api/core/mcp/auth/auth_flow.py @@ -0,0 +1,342 @@ +import base64 +import hashlib +import json +import os +import secrets +import urllib.parse +from typing import Optional +from urllib.parse import urljoin + +import requests +from pydantic import BaseModel, ValidationError + +from core.mcp.auth.auth_provider import OAuthClientProvider +from core.mcp.types import ( + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthTokens, +) +from extensions.ext_redis import redis_client + +LATEST_PROTOCOL_VERSION = "1.0" +OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry +OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" + + +class OAuthCallbackState(BaseModel): + provider_id: str + tenant_id: str + server_url: str + metadata: OAuthMetadata | None = None + client_information: OAuthClientInformation + code_verifier: str + redirect_uri: str + + +def generate_pkce_challenge() -> tuple[str, str]: + """Generate PKCE challenge and verifier.""" + code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8") + code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_") + + code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest() + code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8") + code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_") + + return code_verifier, code_challenge + + +def _create_secure_redis_state(state_data: OAuthCallbackState) -> str: + """Create a secure state parameter by storing state data in Redis and returning a random state key.""" + # Generate a secure random state key + state_key = secrets.token_urlsafe(32) + + # Store the state data in Redis with expiration + redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" + redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json()) + + return state_key + + +def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: + """Retrieve and decode OAuth state data from Redis using the state key, then delete it.""" + redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" + + # Get state data from Redis + state_data = redis_client.get(redis_key) + + if not state_data: + raise ValueError("State parameter has expired or does not exist") + + # Delete the state data from Redis immediately after retrieval to prevent reuse + redis_client.delete(redis_key) + + try: + # Parse and validate the state data + oauth_state = OAuthCallbackState.model_validate_json(state_data) + + return oauth_state + except ValidationError as e: + raise ValueError(f"Invalid state parameter: {str(e)}") + + +def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState: + """Handle the callback from the OAuth provider.""" + # Retrieve state data from Redis (state is automatically deleted after retrieval) + full_state_data = _retrieve_redis_state(state_key) + + tokens = exchange_authorization( + 
full_state_data.server_url, + full_state_data.metadata, + full_state_data.client_information, + authorization_code, + full_state_data.code_verifier, + full_state_data.redirect_uri, + ) + provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True) + provider.save_tokens(tokens) + return full_state_data + + +def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: + """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" + url = urljoin(server_url, "/.well-known/oauth-authorization-server") + + try: + headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} + response = requests.get(url, headers=headers) + if response.status_code == 404: + return None + if not response.ok: + raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + return OAuthMetadata.model_validate(response.json()) + except requests.RequestException as e: + if isinstance(e, requests.ConnectionError): + response = requests.get(url) + if response.status_code == 404: + return None + if not response.ok: + raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + return OAuthMetadata.model_validate(response.json()) + raise + + +def start_authorization( + server_url: str, + metadata: Optional[OAuthMetadata], + client_information: OAuthClientInformation, + redirect_url: str, + provider_id: str, + tenant_id: str, +) -> tuple[str, str]: + """Begins the authorization flow with secure Redis state storage.""" + response_type = "code" + code_challenge_method = "S256" + + if metadata: + authorization_url = metadata.authorization_endpoint + if response_type not in metadata.response_types_supported: + raise ValueError(f"Incompatible auth server: does not support response type {response_type}") + if ( + not metadata.code_challenge_methods_supported + or code_challenge_method not in metadata.code_challenge_methods_supported + ): + raise ValueError( + f"Incompatible auth server: does not support code challenge method {code_challenge_method}" + ) + else: + authorization_url = urljoin(server_url, "/authorize") + + code_verifier, code_challenge = generate_pkce_challenge() + + # Prepare state data with all necessary information + state_data = OAuthCallbackState( + provider_id=provider_id, + tenant_id=tenant_id, + server_url=server_url, + metadata=metadata, + client_information=client_information, + code_verifier=code_verifier, + redirect_uri=redirect_url, + ) + + # Store state data in Redis and generate secure state key + state_key = _create_secure_redis_state(state_data) + + params = { + "response_type": response_type, + "client_id": client_information.client_id, + "code_challenge": code_challenge, + "code_challenge_method": code_challenge_method, + "redirect_uri": redirect_url, + "state": state_key, + } + + authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" + return authorization_url, code_verifier + + +def exchange_authorization( + server_url: str, + metadata: Optional[OAuthMetadata], + client_information: OAuthClientInformation, + authorization_code: str, + code_verifier: str, + redirect_uri: str, +) -> OAuthTokens: + """Exchanges an authorization code for an access token.""" + grant_type = "authorization_code" + + if metadata: + token_url = metadata.token_endpoint + if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: + raise ValueError(f"Incompatible auth server: does not support grant 
type {grant_type}") + else: + token_url = urljoin(server_url, "/token") + + params = { + "grant_type": grant_type, + "client_id": client_information.client_id, + "code": authorization_code, + "code_verifier": code_verifier, + "redirect_uri": redirect_uri, + } + + if client_information.client_secret: + params["client_secret"] = client_information.client_secret + + response = requests.post(token_url, data=params) + if not response.ok: + raise ValueError(f"Token exchange failed: HTTP {response.status_code}") + return OAuthTokens.model_validate(response.json()) + + +def refresh_authorization( + server_url: str, + metadata: Optional[OAuthMetadata], + client_information: OAuthClientInformation, + refresh_token: str, +) -> OAuthTokens: + """Exchange a refresh token for an updated access token.""" + grant_type = "refresh_token" + + if metadata: + token_url = metadata.token_endpoint + if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: + raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") + else: + token_url = urljoin(server_url, "/token") + + params = { + "grant_type": grant_type, + "client_id": client_information.client_id, + "refresh_token": refresh_token, + } + + if client_information.client_secret: + params["client_secret"] = client_information.client_secret + + response = requests.post(token_url, data=params) + if not response.ok: + raise ValueError(f"Token refresh failed: HTTP {response.status_code}") + return OAuthTokens.model_validate(response.json()) + + +def register_client( + server_url: str, + metadata: Optional[OAuthMetadata], + client_metadata: OAuthClientMetadata, +) -> OAuthClientInformationFull: + """Performs OAuth 2.0 Dynamic Client Registration.""" + if metadata: + if not metadata.registration_endpoint: + raise ValueError("Incompatible auth server: does not support dynamic client registration") + registration_url = metadata.registration_endpoint + else: + registration_url = urljoin(server_url, "/register") + + response = requests.post( + registration_url, + json=client_metadata.model_dump(), + headers={"Content-Type": "application/json"}, + ) + if not response.ok: + response.raise_for_status() + return OAuthClientInformationFull.model_validate(response.json()) + + +def auth( + provider: OAuthClientProvider, + server_url: str, + authorization_code: Optional[str] = None, + state_param: Optional[str] = None, + for_list: bool = False, +) -> dict[str, str]: + """Orchestrates the full auth flow with a server using secure Redis state storage.""" + metadata = discover_oauth_metadata(server_url) + + # Handle client registration if needed + client_information = provider.client_information() + if not client_information: + if authorization_code is not None: + raise ValueError("Existing OAuth client information is required when exchanging an authorization code") + try: + full_information = register_client(server_url, metadata, provider.client_metadata) + except requests.RequestException as e: + raise ValueError(f"Could not register OAuth client: {e}") + provider.save_client_information(full_information) + client_information = full_information + + # Exchange authorization code for tokens + if authorization_code is not None: + if not state_param: + raise ValueError("State parameter is required when exchanging authorization code") + + try: + # Retrieve state data from Redis using state key + full_state_data = _retrieve_redis_state(state_param) + + code_verifier = full_state_data.code_verifier + redirect_uri = 
full_state_data.redirect_uri + + if not code_verifier or not redirect_uri: + raise ValueError("Missing code_verifier or redirect_uri in state data") + + except (json.JSONDecodeError, ValueError) as e: + raise ValueError(f"Invalid state parameter: {e}") + + tokens = exchange_authorization( + server_url, + metadata, + client_information, + authorization_code, + code_verifier, + redirect_uri, + ) + provider.save_tokens(tokens) + return {"result": "success"} + + provider_tokens = provider.tokens() + + # Handle token refresh or new authorization + if provider_tokens and provider_tokens.refresh_token: + try: + new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token) + provider.save_tokens(new_tokens) + return {"result": "success"} + except Exception as e: + raise ValueError(f"Could not refresh OAuth tokens: {e}") + + # Start new authorization flow + authorization_url, code_verifier = start_authorization( + server_url, + metadata, + client_information, + provider.redirect_url, + provider.mcp_provider.id, + provider.mcp_provider.tenant_id, + ) + + provider.save_code_verifier(code_verifier) + return {"authorization_url": authorization_url} diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py new file mode 100644 index 0000000000..00d5a25956 --- /dev/null +++ b/api/core/mcp/auth/auth_provider.py @@ -0,0 +1,81 @@ +from typing import Optional + +from configs import dify_config +from core.mcp.types import ( + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthTokens, +) +from models.tools import MCPToolProvider +from services.tools.mcp_tools_manage_service import MCPToolManageService + +LATEST_PROTOCOL_VERSION = "1.0" + + +class OAuthClientProvider: + mcp_provider: MCPToolProvider + + def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False): + if for_list: + self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) + else: + self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id) + + @property + def redirect_url(self) -> str: + """The URL to redirect the user agent to after authorization.""" + return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback" + + @property + def client_metadata(self) -> OAuthClientMetadata: + """Metadata about this OAuth client.""" + return OAuthClientMetadata( + redirect_uris=[self.redirect_url], + token_endpoint_auth_method="none", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + client_name="Dify", + client_uri="https://github.com/langgenius/dify", + ) + + def client_information(self) -> Optional[OAuthClientInformation]: + """Loads information about this OAuth client.""" + client_information = self.mcp_provider.decrypted_credentials.get("client_information", {}) + if not client_information: + return None + return OAuthClientInformation.model_validate(client_information) + + def save_client_information(self, client_information: OAuthClientInformationFull) -> None: + """Saves client information after dynamic registration.""" + MCPToolManageService.update_mcp_provider_credentials( + self.mcp_provider, + {"client_information": client_information.model_dump()}, + ) + + def tokens(self) -> Optional[OAuthTokens]: + """Loads any existing OAuth tokens for the current session.""" + credentials = self.mcp_provider.decrypted_credentials + if not credentials: + return None + return OAuthTokens( + 
access_token=credentials.get("access_token", ""), + token_type=credentials.get("token_type", "Bearer"), + expires_in=int(credentials.get("expires_in", "3600") or 3600), + refresh_token=credentials.get("refresh_token", ""), + ) + + def save_tokens(self, tokens: OAuthTokens) -> None: + """Stores new OAuth tokens for the current session.""" + # update mcp provider credentials + token_dict = tokens.model_dump() + MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) + + def save_code_verifier(self, code_verifier: str) -> None: + """Saves a PKCE code verifier for the current session.""" + MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) + + def code_verifier(self) -> str: + """Loads the PKCE code verifier for the current session.""" + # get code verifier from mcp provider credentials + return str(self.mcp_provider.decrypted_credentials.get("code_verifier", "")) diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py new file mode 100644 index 0000000000..91debcc8f9 --- /dev/null +++ b/api/core/mcp/client/sse_client.py @@ -0,0 +1,361 @@ +import logging +import queue +from collections.abc import Generator +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from typing import Any, TypeAlias, final +from urllib.parse import urljoin, urlparse + +import httpx +from sseclient import SSEClient + +from core.mcp import types +from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.types import SessionMessage +from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect + +logger = logging.getLogger(__name__) + +DEFAULT_QUEUE_READ_TIMEOUT = 3 + + +@final +class _StatusReady: + def __init__(self, endpoint_url: str): + self._endpoint_url = endpoint_url + + +@final +class _StatusError: + def __init__(self, exc: Exception): + self._exc = exc + + +# Type aliases for better readability +ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] +WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] +StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError] + + +def remove_request_params(url: str) -> str: + """Remove request parameters from URL, keeping only the path.""" + return urljoin(url, urlparse(url).path) + + +class SSETransport: + """SSE client transport implementation.""" + + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: float = 5.0, + sse_read_timeout: float = 5 * 60, + ) -> None: + """Initialize the SSE transport. + + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + """ + self.url = url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.endpoint_url: str | None = None + + def _validate_endpoint_url(self, endpoint_url: str) -> bool: + """Validate that the endpoint URL matches the connection origin. + + Args: + endpoint_url: The endpoint URL to validate. + + Returns: + True if valid, False otherwise. + """ + url_parsed = urlparse(self.url) + endpoint_parsed = urlparse(endpoint_url) + + return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme + + def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: + """Handle an 'endpoint' SSE event. 
+ + Args: + sse_data: The SSE event data. + status_queue: Queue to put status updates. + """ + endpoint_url = urljoin(self.url, sse_data) + logger.info(f"Received endpoint URL: {endpoint_url}") + + if not self._validate_endpoint_url(endpoint_url): + error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}" + logger.error(error_msg) + status_queue.put(_StatusError(ValueError(error_msg))) + return + + status_queue.put(_StatusReady(endpoint_url)) + + def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: + """Handle a 'message' SSE event. + + Args: + sse_data: The SSE event data. + read_queue: Queue to put parsed messages. + """ + try: + message = types.JSONRPCMessage.model_validate_json(sse_data) + logger.debug(f"Received server message: {message}") + session_message = SessionMessage(message) + read_queue.put(session_message) + except Exception as exc: + logger.exception("Error parsing server message") + read_queue.put(exc) + + def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + """Handle a single SSE event. + + Args: + sse: The SSE event object. + read_queue: Queue for message events. + status_queue: Queue for status events. + """ + match sse.event: + case "endpoint": + self._handle_endpoint_event(sse.data, status_queue) + case "message": + self._handle_message_event(sse.data, read_queue) + case _: + logger.warning(f"Unknown SSE event: {sse.event}") + + def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + """Read and process SSE events. + + Args: + event_source: The SSE event source. + read_queue: Queue to put received messages. + status_queue: Queue to put status updates. + """ + try: + for sse in event_source.iter_sse(): + self._handle_sse_event(sse, read_queue, status_queue) + except httpx.ReadError as exc: + logger.debug(f"SSE reader shutting down normally: {exc}") + except Exception as exc: + read_queue.put(exc) + finally: + read_queue.put(None) + + def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: + """Send a single message to the server. + + Args: + client: HTTP client to use. + endpoint_url: The endpoint URL to send to. + message: The message to send. + """ + response = client.post( + endpoint_url, + json=message.message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + + def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: + """Handle writing messages to the server. + + Args: + client: HTTP client to use. + endpoint_url: The endpoint URL to send messages to. + write_queue: Queue to read messages from. + """ + try: + while True: + try: + message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) + if message is None: + break + if isinstance(message, Exception): + write_queue.put(message) + continue + + self._send_message(client, endpoint_url, message) + + except queue.Empty: + continue + except httpx.ReadError as exc: + logger.debug(f"Post writer shutting down normally: {exc}") + except Exception as exc: + logger.exception("Error writing messages") + write_queue.put(exc) + finally: + write_queue.put(None) + + def _wait_for_endpoint(self, status_queue: StatusQueue) -> str: + """Wait for the endpoint URL from the status queue. + + Args: + status_queue: Queue to read status from. + + Returns: + The endpoint URL. 
+ + Raises: + ValueError: If endpoint URL is not received or there's an error. + """ + try: + status = status_queue.get(timeout=1) + except queue.Empty: + raise ValueError("failed to get endpoint URL") + + if isinstance(status, _StatusReady): + return status._endpoint_url + elif isinstance(status, _StatusError): + raise status._exc + else: + raise ValueError("failed to get endpoint URL") + + def connect( + self, + executor: ThreadPoolExecutor, + client: httpx.Client, + event_source, + ) -> tuple[ReadQueue, WriteQueue]: + """Establish connection and start worker threads. + + Args: + executor: Thread pool executor. + client: HTTP client. + event_source: SSE event source. + + Returns: + Tuple of (read_queue, write_queue). + """ + read_queue: ReadQueue = queue.Queue() + write_queue: WriteQueue = queue.Queue() + status_queue: StatusQueue = queue.Queue() + + # Start SSE reader thread + executor.submit(self.sse_reader, event_source, read_queue, status_queue) + + # Wait for endpoint URL + endpoint_url = self._wait_for_endpoint(status_queue) + self.endpoint_url = endpoint_url + + # Start post writer thread + executor.submit(self.post_writer, client, endpoint_url, write_queue) + + return read_queue, write_queue + + +@contextmanager +def sse_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: float = 5.0, + sse_read_timeout: float = 5 * 60, +) -> Generator[tuple[ReadQueue, WriteQueue], None, None]: + """ + Client transport for SSE. + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + + Yields: + Tuple of (read_queue, write_queue) for message communication. + """ + transport = SSETransport(url, headers, timeout, sse_read_timeout) + + read_queue: ReadQueue | None = None + write_queue: WriteQueue | None = None + + with ThreadPoolExecutor() as executor: + try: + with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: + with ssrf_proxy_sse_connect( + url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client + ) as event_source: + event_source.response.raise_for_status() + + read_queue, write_queue = transport.connect(executor, client, event_source) + + yield read_queue, write_queue + + except httpx.HTTPStatusError as exc: + if exc.response.status_code == 401: + raise MCPAuthError() + raise MCPConnectionError() + except Exception: + logger.exception("Error connecting to SSE endpoint") + raise + finally: + # Clean up queues + if read_queue: + read_queue.put(None) + if write_queue: + write_queue.put(None) + + +def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: + """ + Send a message to the server using the provided HTTP client. 
+ + Args: + http_client: The HTTP client to use for sending + endpoint_url: The endpoint URL to send the message to + session_message: The message to send + """ + try: + response = http_client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + except Exception as exc: + logger.exception("Error sending message") + raise + + +def read_messages( + sse_client: SSEClient, +) -> Generator[SessionMessage | Exception, None, None]: + """ + Read messages from the SSE client. + + Args: + sse_client: The SSE client to read from + + Yields: + SessionMessage or Exception for each event received + """ + try: + for sse in sse_client.events(): + if sse.event == "message": + try: + message = types.JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"Received server message: {message}") + yield SessionMessage(message) + except Exception as exc: + logger.exception("Error parsing server message") + yield exc + else: + logger.warning(f"Unknown SSE event: {sse.event}") + except Exception as exc: + logger.exception("Error reading SSE messages") + yield exc diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py new file mode 100644 index 0000000000..fbd8d05f9e --- /dev/null +++ b/api/core/mcp/client/streamable_client.py @@ -0,0 +1,476 @@ +""" +StreamableHTTP Client Transport Module + +This module implements the StreamableHTTP transport for MCP clients, +providing support for HTTP POST requests with optional SSE streaming responses +and session management. +""" + +import logging +import queue +from collections.abc import Callable, Generator +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, cast + +import httpx +from httpx_sse import EventSource, ServerSentEvent + +from core.mcp.types import ( + ClientMessageMetadata, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, + SessionMessage, +) +from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect + +logger = logging.getLogger(__name__) + + +SessionMessageOrError = SessionMessage | Exception | None +# Queue types with clearer names for their roles +ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages +ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages +GetSessionIdCallback = Callable[[], str | None] + +MCP_SESSION_ID = "mcp-session-id" +LAST_EVENT_ID = "last-event-id" +CONTENT_TYPE = "content-type" +ACCEPT = "Accept" + + +JSON = "application/json" +SSE = "text/event-stream" + +DEFAULT_QUEUE_READ_TIMEOUT = 3 + + +class StreamableHTTPError(Exception): + """Base exception for StreamableHTTP transport errors.""" + + pass + + +class ResumptionError(StreamableHTTPError): + """Raised when resumption request is invalid.""" + + pass + + +@dataclass +class RequestContext: + """Context for a request operation.""" + + client: httpx.Client + headers: dict[str, str] + session_id: str | None + session_message: SessionMessage + metadata: ClientMessageMetadata | None + server_to_client_queue: ServerToClientQueue # Renamed for clarity + sse_read_timeout: timedelta + + +class StreamableHTTPTransport: + """StreamableHTTP client transport implementation.""" + + def 
__init__( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + ) -> None: + """Initialize the StreamableHTTP transport. + + Args: + url: The endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + """ + self.url = url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.session_id: str | None = None + self.request_headers = { + ACCEPT: f"{JSON}, {SSE}", + CONTENT_TYPE: JSON, + **self.headers, + } + + def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: + """Update headers with session ID if available.""" + headers = base_headers.copy() + if self.session_id: + headers[MCP_SESSION_ID] = self.session_id + return headers + + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialization request.""" + return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + + def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialized notification.""" + return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" + + def _maybe_extract_session_id_from_response( + self, + response: httpx.Response, + ) -> None: + """Extract and store session ID from response headers.""" + new_session_id = response.headers.get(MCP_SESSION_ID) + if new_session_id: + self.session_id = new_session_id + logger.info(f"Received session ID: {self.session_id}") + + def _handle_sse_event( + self, + sse: ServerSentEvent, + server_to_client_queue: ServerToClientQueue, + original_request_id: RequestId | None = None, + resumption_callback: Callable[[str], None] | None = None, + ) -> bool: + """Handle an SSE event, returning True if the response is complete.""" + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"SSE message: {message}") + + # If this is a response and we have original_request_id, replace it + if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): + message.root.id = original_request_id + + session_message = SessionMessage(message) + # Put message in queue that goes to client + server_to_client_queue.put(session_message) + + # Call resumption token callback if we have an ID + if sse.id and resumption_callback: + resumption_callback(sse.id) + + # If this is a response or error return True indicating completion + # Otherwise, return False to continue listening + return isinstance(message.root, JSONRPCResponse | JSONRPCError) + + except Exception as exc: + # Put exception in queue that goes to client + server_to_client_queue.put(exc) + return False + elif sse.event == "ping": + logger.debug("Received ping event") + return False + else: + logger.warning(f"Unknown SSE event: {sse.event}") + return False + + def handle_get_stream( + self, + client: httpx.Client, + server_to_client_queue: ServerToClientQueue, + ) -> None: + """Handle GET stream for server-initiated messages.""" + try: + if not self.session_id: + return + + headers = self._update_headers_with_session(self.request_headers) + + with ssrf_proxy_sse_connect( + self.url, + headers=headers, + timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), + 
client=client, + method="GET", + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + for sse in event_source.iter_sse(): + self._handle_sse_event(sse, server_to_client_queue) + + except Exception as exc: + logger.debug(f"GET stream error (non-fatal): {exc}") + + def _handle_resumption_request(self, ctx: RequestContext) -> None: + """Handle a resumption request using GET with SSE.""" + headers = self._update_headers_with_session(ctx.headers) + if ctx.metadata and ctx.metadata.resumption_token: + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + else: + raise ResumptionError("Resumption request requires a resumption token") + + # Extract original request ID to map responses + original_request_id = None + if isinstance(ctx.session_message.message.root, JSONRPCRequest): + original_request_id = ctx.session_message.message.root.id + + with ssrf_proxy_sse_connect( + self.url, + headers=headers, + timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), + client=ctx.client, + method="GET", + ) as event_source: + event_source.response.raise_for_status() + logger.debug("Resumption GET SSE connection established") + + for sse in event_source.iter_sse(): + is_complete = self._handle_sse_event( + sse, + ctx.server_to_client_queue, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break + + def _handle_post_request(self, ctx: RequestContext) -> None: + """Handle a POST request with response processing.""" + headers = self._update_headers_with_session(ctx.headers) + message = ctx.session_message.message + is_initialization = self._is_initialization_request(message) + + with ctx.client.stream( + "POST", + self.url, + json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + headers=headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + return + + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + self._send_session_terminated_error( + ctx.server_to_client_queue, + message.root.id, + ) + return + + response.raise_for_status() + if is_initialization: + self._maybe_extract_session_id_from_response(response) + + content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower()) + + if content_type.startswith(JSON): + self._handle_json_response(response, ctx.server_to_client_queue) + elif content_type.startswith(SSE): + self._handle_sse_response(response, ctx) + else: + self._handle_unexpected_content_type( + content_type, + ctx.server_to_client_queue, + ) + + def _handle_json_response( + self, + response: httpx.Response, + server_to_client_queue: ServerToClientQueue, + ) -> None: + """Handle JSON response from the server.""" + try: + content = response.read() + message = JSONRPCMessage.model_validate_json(content) + session_message = SessionMessage(message) + server_to_client_queue.put(session_message) + except Exception as exc: + server_to_client_queue.put(exc) + + def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: + """Handle SSE response from the server.""" + try: + event_source = EventSource(response) + for sse in event_source.iter_sse(): + is_complete = self._handle_sse_event( + sse, + ctx.server_to_client_queue, + resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), + ) + if is_complete: + break + except Exception as e: + ctx.server_to_client_queue.put(e) + + def 
_handle_unexpected_content_type( + self, + content_type: str, + server_to_client_queue: ServerToClientQueue, + ) -> None: + """Handle unexpected content type in response.""" + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + server_to_client_queue.put(ValueError(error_msg)) + + def _send_session_terminated_error( + self, + server_to_client_queue: ServerToClientQueue, + request_id: RequestId, + ) -> None: + """Send a session terminated error response.""" + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=32600, message="Session terminated by server"), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + server_to_client_queue.put(session_message) + + def post_writer( + self, + client: httpx.Client, + client_to_server_queue: ClientToServerQueue, + server_to_client_queue: ServerToClientQueue, + start_get_stream: Callable[[], None], + ) -> None: + """Handle writing requests to the server. + + This method processes messages from the client_to_server_queue and sends them to the server. + Responses are written to the server_to_client_queue. + """ + while True: + try: + # Read message from client queue with timeout to check stop_event periodically + session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) + if session_message is None: + break + + message = session_message.message + metadata = ( + session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None + ) + + # Check if this is a resumption request + is_resumption = bool(metadata and metadata.resumption_token) + + logger.debug(f"Sending client message: {message}") + + # Handle initialized notification + if self._is_initialized_notification(message): + start_get_stream() + + ctx = RequestContext( + client=client, + headers=self.request_headers, + session_id=self.session_id, + session_message=session_message, + metadata=metadata, + server_to_client_queue=server_to_client_queue, # Queue to write responses to client + sse_read_timeout=self.sse_read_timeout, + ) + + if is_resumption: + self._handle_resumption_request(ctx) + else: + self._handle_post_request(ctx) + except queue.Empty: + continue + except Exception as exc: + server_to_client_queue.put(exc) + + def terminate_session(self, client: httpx.Client) -> None: + """Terminate the session by sending a DELETE request.""" + if not self.session_id: + return + + try: + headers = self._update_headers_with_session(self.request_headers) + response = client.delete(self.url, headers=headers) + + if response.status_code == 405: + logger.debug("Server does not allow session termination") + elif response.status_code != 200: + logger.warning(f"Session termination failed: {response.status_code}") + except Exception as exc: + logger.warning(f"Session termination failed: {exc}") + + def get_session_id(self) -> str | None: + """Get the current session ID.""" + return self.session_id + + +@contextmanager +def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + terminate_on_close: bool = True, +) -> Generator[ + tuple[ + ServerToClientQueue, # Queue for receiving messages FROM server + ClientToServerQueue, # Queue for sending messages TO server + GetSessionIdCallback, + ], + None, + None, +]: + """ + Client transport for StreamableHTTP. 
+ + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Yields: + Tuple containing: + - server_to_client_queue: Queue for reading messages FROM the server + - client_to_server_queue: Queue for sending messages TO the server + - get_session_id_callback: Function to retrieve the current session ID + """ + transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + + # Create queues with clear directional meaning + server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client + client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server + + with ThreadPoolExecutor(max_workers=2) as executor: + try: + with create_ssrf_proxy_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), + ) as client: + # Define callbacks that need access to thread pool + def start_get_stream() -> None: + """Start a worker thread to handle server-initiated messages.""" + executor.submit(transport.handle_get_stream, client, server_to_client_queue) + + # Start the post_writer worker thread + executor.submit( + transport.post_writer, + client, + client_to_server_queue, # Queue for messages FROM client TO server + server_to_client_queue, # Queue for messages FROM server TO client + start_get_stream, + ) + + try: + yield ( + server_to_client_queue, # Queue for receiving messages FROM server + client_to_server_queue, # Queue for sending messages TO server + transport.get_session_id, + ) + finally: + if transport.session_id and terminate_on_close: + transport.terminate_session(client) + + # Signal threads to stop + client_to_server_queue.put(None) + finally: + # Clear any remaining items and add None sentinel to unblock any waiting threads + try: + while not client_to_server_queue.empty(): + client_to_server_queue.get_nowait() + except queue.Empty: + pass + + client_to_server_queue.put(None) + server_to_client_queue.put(None) diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py new file mode 100644 index 0000000000..7553c10a2e --- /dev/null +++ b/api/core/mcp/entities.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +from core.mcp.session.base_session import BaseSession +from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams + +SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION] + + +SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +LifespanContextT = TypeVar("LifespanContextT") + + +@dataclass +class RequestContext(Generic[SessionT, LifespanContextT]): + request_id: RequestId + meta: RequestParams.Meta | None + session: SessionT + lifespan_context: LifespanContextT diff --git a/api/core/mcp/error.py b/api/core/mcp/error.py new file mode 100644 index 0000000000..92ea7bde09 --- /dev/null +++ b/api/core/mcp/error.py @@ -0,0 +1,10 @@ +class MCPError(Exception): + pass + + +class MCPConnectionError(MCPError): + pass + + +class MCPAuthError(MCPConnectionError): + pass diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py new file mode 100644 index 0000000000..5fe52c008a --- /dev/null +++ b/api/core/mcp/mcp_client.py @@ -0,0 +1,153 @@ +import logging +from collections.abc import Callable +from contextlib import AbstractContextManager, ExitStack +from types import TracebackType +from 
typing import Any, Optional, cast +from urllib.parse import urlparse + +from core.mcp.client.sse_client import sse_client +from core.mcp.client.streamable_client import streamablehttp_client +from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.session.client_session import ClientSession +from core.mcp.types import Tool + +logger = logging.getLogger(__name__) + + +class MCPClient: + def __init__( + self, + server_url: str, + provider_id: str, + tenant_id: str, + authed: bool = True, + authorization_code: Optional[str] = None, + for_list: bool = False, + ): + # Initialize info + self.provider_id = provider_id + self.tenant_id = tenant_id + self.client_type = "streamable" + self.server_url = server_url + + # Authentication info + self.authed = authed + self.authorization_code = authorization_code + if authed: + from core.mcp.auth.auth_provider import OAuthClientProvider + + self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list) + self.token = self.provider.tokens() + + # Initialize session and client objects + self._session: Optional[ClientSession] = None + self._streams_context: Optional[AbstractContextManager[Any]] = None + self._session_context: Optional[ClientSession] = None + self.exit_stack = ExitStack() + + # Whether the client has been initialized + self._initialized = False + + def __enter__(self): + self._initialize() + self._initialized = True + return self + + def __exit__( + self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType] + ): + self.cleanup() + + def _initialize( + self, + ): + """Initialize the client with fallback to SSE if streamable connection fails""" + connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = { + "mcp": streamablehttp_client, + "sse": sse_client, + } + + parsed_url = urlparse(self.server_url) + path = parsed_url.path or "" + method_name = path.rstrip("/").split("/")[-1] if path else "" + if method_name in connection_methods: + client_factory = connection_methods[method_name] + self.connect_server(client_factory, method_name) + else: + try: + logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.") + self.connect_server(sse_client, "sse") + except MCPConnectionError: + logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") + self.connect_server(streamablehttp_client, "mcp") + + def connect_server( + self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True + ): + from core.mcp.auth.auth_flow import auth + + try: + headers = ( + {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} + if self.authed and self.token + else {} + ) + self._streams_context = client_factory(url=self.server_url, headers=headers) + if not self._streams_context: + raise MCPConnectionError("Failed to create connection context") + + # Use exit_stack to manage context managers properly + if method_name == "mcp": + read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) + streams = (read_stream, write_stream) + else: # sse_client + streams = self.exit_stack.enter_context(self._streams_context) + + self._session_context = ClientSession(*streams) + self._session = self.exit_stack.enter_context(self._session_context) + session = cast(ClientSession, self._session) + session.initialize() + return + + except MCPAuthError: + if not self.authed: + raise + try: + auth(self.provider, 
self.server_url, self.authorization_code) + except Exception as e: + raise ValueError(f"Failed to authenticate: {e}") + self.token = self.provider.tokens() + if first_try: + return self.connect_server(client_factory, method_name, first_try=False) + + except MCPConnectionError: + raise + + def list_tools(self) -> list[Tool]: + """Connect to an MCP server running with SSE transport""" + # List available tools to verify connection + if not self._initialized or not self._session: + raise ValueError("Session not initialized.") + response = self._session.list_tools() + tools = response.tools + return tools + + def invoke_tool(self, tool_name: str, tool_args: dict): + """Call a tool""" + if not self._initialized or not self._session: + raise ValueError("Session not initialized.") + return self._session.call_tool(tool_name, tool_args) + + def cleanup(self): + """Clean up resources""" + try: + # ExitStack will handle proper cleanup of all managed context managers + self.exit_stack.close() + except Exception as e: + logging.exception("Error during cleanup") + raise ValueError(f"Error during cleanup: {e}") + finally: + self._session = None + self._session_context = None + self._streams_context = None + self._initialized = False diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py new file mode 100644 index 0000000000..496b5432a0 --- /dev/null +++ b/api/core/mcp/server/streamable_http.py @@ -0,0 +1,226 @@ +import json +import logging +from collections.abc import Mapping +from typing import Any, cast + +from configs import dify_config +from controllers.web.passport import generate_session_id +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.features.rate_limiting.rate_limit import RateLimitGenerator +from core.mcp import types +from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND +from core.mcp.utils import create_mcp_error_response +from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db +from models.model import App, AppMCPServer, AppMode, EndUser +from services.app_generate_service import AppGenerateService + +""" +Apply to MCP HTTP streamable server with stateless http +""" +logger = logging.getLogger(__name__) + + +class MCPServerStreamableHTTPRequestHandler: + def __init__( + self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] + ): + self.app = app + self.request = request + mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first() + if not mcp_server: + raise ValueError("MCP server not found") + self.mcp_server: AppMCPServer = mcp_server + self.end_user = self.retrieve_end_user() + self.user_input_form = user_input_form + + @property + def request_type(self): + return type(self.request.root) + + @property + def parameter_schema(self): + parameters, required = self._convert_input_form_to_parameters(self.user_input_form) + if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + return { + "type": "object", + "properties": parameters, + "required": required, + } + return { + "type": "object", + "properties": { + "query": {"type": "string", "description": "User Input/Question content"}, + **parameters, + }, + "required": ["query", *required], + } + + @property + def capabilities(self): + return types.ServerCapabilities( + tools=types.ToolsCapability(listChanged=False), + ) + + 
def response(self, response: types.Result | str): + if isinstance(response, str): + sse_content = f"event: ping\ndata: {response}\n\n".encode() + yield sse_content + return + json_response = types.JSONRPCResponse( + jsonrpc="2.0", + id=(self.request.root.model_extra or {}).get("id", 1), + result=response.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + json_data = json.dumps(jsonable_encoder(json_response)) + + sse_content = f"event: message\ndata: {json_data}\n\n".encode() + + yield sse_content + + def error_response(self, code: int, message: str, data=None): + request_id = (self.request.root.model_extra or {}).get("id", 1) or 1 + return create_mcp_error_response(request_id, code, message, data) + + def handle(self): + handle_map = { + types.InitializeRequest: self.initialize, + types.ListToolsRequest: self.list_tools, + types.CallToolRequest: self.invoke_tool, + types.InitializedNotification: self.handle_notification, + types.PingRequest: self.handle_ping, + } + try: + if self.request_type in handle_map: + return self.response(handle_map[self.request_type]()) + else: + return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}") + except ValueError as e: + logger.exception("Invalid params") + return self.error_response(INVALID_PARAMS, str(e)) + except Exception as e: + logger.exception("Internal server error") + return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}") + + def handle_notification(self): + return "ping" + + def handle_ping(self): + return types.EmptyResult() + + def initialize(self): + request = cast(types.InitializeRequest, self.request.root) + client_info = request.params.clientInfo + client_name = f"{client_info.name}@{client_info.version}" + if not self.end_user: + end_user = EndUser( + tenant_id=self.app.tenant_id, + app_id=self.app.id, + type="mcp", + name=client_name, + session_id=generate_session_id(), + external_user_id=self.mcp_server.id, + ) + db.session.add(end_user) + db.session.commit() + return types.InitializeResult( + protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION, + capabilities=self.capabilities, + serverInfo=types.Implementation(name="Dify", version=dify_config.project.version), + instructions=self.mcp_server.description, + ) + + def list_tools(self): + if not self.end_user: + raise ValueError("User not found") + return types.ListToolsResult( + tools=[ + types.Tool( + name=self.app.name, + description=self.mcp_server.description, + inputSchema=self.parameter_schema, + ) + ], + ) + + def invoke_tool(self): + if not self.end_user: + raise ValueError("User not found") + request = cast(types.CallToolRequest, self.request.root) + args = request.params.arguments or {} + if self.app.mode in {AppMode.WORKFLOW.value}: + args = {"inputs": args} + elif self.app.mode in {AppMode.COMPLETION.value}: + args = {"query": "", "inputs": args} + else: + args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}} + response = AppGenerateService.generate( + self.app, + self.end_user, + args, + InvokeFrom.SERVICE_API, + streaming=self.app.mode == AppMode.AGENT_CHAT.value, + ) + answer = "" + if isinstance(response, RateLimitGenerator): + for item in response.generator: + data = item + if isinstance(data, str) and data.startswith("data: "): + try: + json_str = data[6:].strip() + parsed_data = json.loads(json_str) + if parsed_data.get("event") == "agent_thought": + answer += parsed_data.get("thought", "") + except json.JSONDecodeError: + continue + if isinstance(response, 
Mapping): + if self.app.mode in { + AppMode.ADVANCED_CHAT.value, + AppMode.COMPLETION.value, + AppMode.CHAT.value, + AppMode.AGENT_CHAT.value, + }: + answer = response["answer"] + elif self.app.mode in {AppMode.WORKFLOW.value}: + answer = json.dumps(response["data"]["outputs"], ensure_ascii=False) + else: + raise ValueError("Invalid app mode") + # Not support image yet + return types.CallToolResult(content=[types.TextContent(text=answer, type="text")]) + + def retrieve_end_user(self): + return ( + db.session.query(EndUser) + .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") + .first() + ) + + def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): + parameters: dict[str, dict[str, Any]] = {} + required = [] + for item in user_input_form: + parameters[item.variable] = {} + if item.type in ( + VariableEntityType.FILE, + VariableEntityType.FILE_LIST, + VariableEntityType.EXTERNAL_DATA_TOOL, + ): + continue + if item.required: + required.append(item.variable) + # if the workflow republished, the parameters not changed + # we should not raise error here + try: + description = self.mcp_server.parameters_dict[item.variable] + except KeyError: + description = "" + parameters[item.variable]["description"] = description + if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + parameters[item.variable]["type"] = "string" + elif item.type == VariableEntityType.SELECT: + parameters[item.variable]["type"] = "string" + parameters[item.variable]["enum"] = item.options + elif item.type == VariableEntityType.NUMBER: + parameters[item.variable]["type"] = "float" + return parameters, required diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py new file mode 100644 index 0000000000..7734b8fdd9 --- /dev/null +++ b/api/core/mcp/session/base_session.py @@ -0,0 +1,415 @@ +import logging +import queue +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError +from contextlib import ExitStack +from datetime import timedelta +from types import TracebackType +from typing import Any, Generic, Self, TypeVar + +from httpx import HTTPStatusError +from pydantic import BaseModel + +from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.types import ( + CancelledNotification, + ClientNotification, + ClientRequest, + ClientResult, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + MessageMetadata, + RequestId, + RequestParams, + ServerMessageMetadata, + ServerNotification, + ServerRequest, + ServerResult, + SessionMessage, +) + +SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) +SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) +SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) +ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) +ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) +DEFAULT_RESPONSE_READ_TIMEOUT = 1.0 + + +class RequestResponder(Generic[ReceiveRequestT, SendResultT]): + """Handles responding to MCP requests and manages request lifecycle. + + This class MUST be used as a context manager to ensure proper cleanup and + cancellation handling: + + Example: + with request_responder as resp: + resp.respond(result) + + The context manager ensures: + 1. 
Proper cancellation scope setup and cleanup + 2. Request completion tracking + 3. Cleanup of in-flight requests + """ + + request: ReceiveRequestT + _session: Any + _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] + + def __init__( + self, + request_id: RequestId, + request_meta: RequestParams.Meta | None, + request: ReceiveRequestT, + session: """BaseSession[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT + ]""", + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + ) -> None: + self.request_id = request_id + self.request_meta = request_meta + self.request = request + self._session = session + self._completed = False + self._on_complete = on_complete + self._entered = False # Track if we're in a context manager + + def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": + """Enter the context manager, enabling request cancellation tracking.""" + self._entered = True + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the context manager, performing cleanup and notifying completion.""" + try: + if self._completed: + self._on_complete(self) + finally: + self._entered = False + + def respond(self, response: SendResultT | ErrorData) -> None: + """Send a response for this request. + + Must be called within a context manager block. + Raises: + RuntimeError: If not used within a context manager + AssertionError: If request was already responded to + """ + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") + assert not self._completed, "Request already responded to" + + self._completed = True + + self._session._send_response(request_id=self.request_id, response=response) + + def cancel(self) -> None: + """Cancel this request and mark it as completed.""" + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") + + self._completed = True # Mark as completed so it's removed from in_flight + # Send an error response to indicate cancellation + self._session._send_response( + request_id=self.request_id, + response=ErrorData(code=0, message="Request cancelled", data=None), + ) + + +class BaseSession( + Generic[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ], +): + """ + Implements an MCP "session" on top of read/write streams, including features + like request/response linking, notifications, and progress. + + This class is a context manager that automatically starts processing + messages when entered. 
+ """ + + _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]] + _request_id: int + _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _receive_request_type: type[ReceiveRequestT] + _receive_notification_type: type[ReceiveNotificationT] + + def __init__( + self, + read_stream: queue.Queue, + write_stream: queue.Queue, + receive_request_type: type[ReceiveRequestT], + receive_notification_type: type[ReceiveNotificationT], + # If none, reading will never time out + read_timeout_seconds: timedelta | None = None, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + self._response_streams = {} + self._request_id = 0 + self._receive_request_type = receive_request_type + self._receive_notification_type = receive_notification_type + self._session_read_timeout_seconds = read_timeout_seconds + self._in_flight = {} + self._exit_stack = ExitStack() + # Initialize executor and future to None for proper cleanup checks + self._executor: ThreadPoolExecutor | None = None + self._receiver_future: Future | None = None + + def __enter__(self) -> Self: + # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1 + # ensures no unnecessary threads are created. + self._executor = ThreadPoolExecutor(max_workers=1) + self._receiver_future = self._executor.submit(self._receive_loop) + return self + + def check_receiver_status(self) -> None: + """`check_receiver_status` ensures that any exceptions raised during the + execution of `_receive_loop` are retrieved and propagated.""" + if self._receiver_future and self._receiver_future.done(): + self._receiver_future.result() + + def __exit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> None: + self._read_stream.put(None) + self._write_stream.put(None) + + # Wait for the receiver loop to finish + if self._receiver_future: + try: + self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds + except TimeoutError: + # If the receiver loop is still running after timeout, we'll force shutdown + pass + + # Shutdown the executor + if self._executor: + self._executor.shutdown(wait=True) + + def send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, + ) -> ReceiveResultT: + """ + Sends a request and wait for a response. Raises an McpError if the + response contains an error. If a request read timeout is provided, it + will take precedence over the session read timeout. + + Do not use this method to emit notifications! Use send_notification() + instead. 
+ """ + self.check_receiver_status() + + request_id = self._request_id + self._request_id = request_id + 1 + + response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue() + self._response_streams[request_id] = response_queue + + try: + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + + self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + timeout = DEFAULT_RESPONSE_READ_TIMEOUT + if request_read_timeout_seconds is not None: + timeout = float(request_read_timeout_seconds.total_seconds()) + elif self._session_read_timeout_seconds is not None: + timeout = float(self._session_read_timeout_seconds.total_seconds()) + while True: + try: + response_or_error = response_queue.get(timeout=timeout) + break + except queue.Empty: + self.check_receiver_status() + continue + + if response_or_error is None: + raise MCPConnectionError( + ErrorData( + code=500, + message="No response received", + ) + ) + elif isinstance(response_or_error, JSONRPCError): + if response_or_error.error.code == 401: + raise MCPAuthError( + ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) + ) + else: + raise MCPConnectionError( + ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) + ) + else: + return result_type.model_validate(response_or_error.result) + + finally: + self._response_streams.pop(request_id, None) + + def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: + """ + Emits a notification, which is a one-way message that does not expect + a response. + """ + self.check_receiver_status() + + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered them. + jsonrpc_notification = JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + session_message = SessionMessage( + message=JSONRPCMessage(jsonrpc_notification), + metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, + ) + self._write_stream.put(session_message) + + def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: + if isinstance(response, ErrorData): + jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + self._write_stream.put(session_message) + else: + jsonrpc_response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=response.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + self._write_stream.put(session_message) + + def _receive_loop(self) -> None: + """ + Main message processing loop. + In a real synchronous implementation, this would likely run in a separate thread. 
+ """ + while True: + try: + # Attempt to receive a message (this would be blocking in a synchronous context) + message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT) + if message is None: + break + if isinstance(message, HTTPStatusError): + response_queue = self._response_streams.get(self._request_id - 1) + if response_queue is not None: + response_queue.put( + JSONRPCError( + jsonrpc="2.0", + id=self._request_id - 1, + error=ErrorData(code=message.response.status_code, message=message.args[0]), + ) + ) + else: + self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) + elif isinstance(message, Exception): + self._handle_incoming(message) + elif isinstance(message.message.root, JSONRPCRequest): + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + + responder = RequestResponder( + request_id=message.message.root.id, + request_meta=validated_request.root.params.meta if validated_request.root.params else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), + ) + + self._in_flight[responder.request_id] = responder + self._received_request(responder) + + if not responder._completed: + self._handle_incoming(responder) + + elif isinstance(message.message.root, JSONRPCNotification): + try: + notification = self._receive_notification_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + self._in_flight[cancelled_id].cancel() + else: + self._received_notification(notification) + self._handle_incoming(notification) + except Exception as e: + # For other validation errors, log and continue + logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}") + else: # Response or error + response_queue = self._response_streams.get(message.message.root.id) + if response_queue is not None: + response_queue.put(message.message.root) + else: + self._handle_incoming(RuntimeError(f"Server Error: {message}")) + except queue.Empty: + continue + except Exception as e: + logging.exception("Error in message processing loop") + raise + + def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: + """ + Can be overridden by subclasses to handle a request without needing to + listen on the message stream. + + If the request is responded to within this method, it will not be + forwarded on to the message stream. + """ + pass + + def _received_notification(self, notification: ReceiveNotificationT) -> None: + """ + Can be overridden by subclasses to handle a notification without needing + to listen on the message stream. + """ + pass + + def send_progress_notification( + self, progress_token: str | int, progress: float, total: float | None = None + ) -> None: + """ + Sends a progress notification for a request that is currently being + processed. + """ + pass + + def _handle_incoming( + self, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, + ) -> None: + """A generic handler for incoming messages. 
Overwritten by subclasses.""" + pass diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py new file mode 100644 index 0000000000..ed2ad508ab --- /dev/null +++ b/api/core/mcp/session/client_session.py @@ -0,0 +1,365 @@ +from datetime import timedelta +from typing import Any, Protocol + +from pydantic import AnyUrl, TypeAdapter + +from configs import dify_config +from core.mcp import types +from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext +from core.mcp.session.base_session import BaseSession, RequestResponder + +DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version) + + +class SamplingFnT(Protocol): + def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: ... + + +class ListRootsFnT(Protocol): + def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ... + + +class LoggingFnT(Protocol): + def __call__( + self, + params: types.LoggingMessageNotificationParams, + ) -> None: ... + + +class MessageHandlerFnT(Protocol): + def __call__( + self, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: ... + + +def _default_message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, +) -> None: + if isinstance(message, Exception): + raise ValueError(str(message)) + elif isinstance(message, (types.ServerNotification | RequestResponder)): + pass + + +def _default_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, +) -> types.CreateMessageResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Sampling not supported", + ) + + +def _default_list_roots_callback( + context: RequestContext["ClientSession", Any], +) -> types.ListRootsResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="List roots not supported", + ) + + +def _default_logging_callback( + params: types.LoggingMessageNotificationParams, +) -> None: + pass + + +ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) + + +class ClientSession( + BaseSession[ + types.ClientRequest, + types.ClientNotification, + types.ClientResult, + types.ServerRequest, + types.ServerNotification, + ] +): + def __init__( + self, + read_stream, + write_stream, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, + ) -> None: + super().__init__( + read_stream, + write_stream, + types.ServerRequest, + types.ServerNotification, + read_timeout_seconds=read_timeout_seconds, + ) + self._client_info = client_info or DEFAULT_CLIENT_INFO + self._sampling_callback = sampling_callback or _default_sampling_callback + self._list_roots_callback = list_roots_callback or _default_list_roots_callback + self._logging_callback = logging_callback or _default_logging_callback + self._message_handler = message_handler or _default_message_handler + + def initialize(self) -> types.InitializeResult: + sampling = types.SamplingCapability() + roots = 
types.RootsCapability( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? + listChanged=True, + ) + + result = self.send_request( + types.ClientRequest( + types.InitializeRequest( + method="initialize", + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + sampling=sampling, + experimental=None, + roots=roots, + ), + clientInfo=self._client_info, + ), + ) + ), + types.InitializeResult, + ) + + if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: + raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + + self.send_notification( + types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) + ) + + return result + + def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + return self.send_request( + types.ClientRequest( + types.PingRequest( + method="ping", + ) + ), + types.EmptyResult, + ) + + def send_progress_notification( + self, progress_token: str | int, progress: float, total: float | None = None + ) -> None: + """Send a progress notification.""" + self.send_notification( + types.ClientNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progress_token, + progress=progress, + total=total, + ), + ), + ) + ) + + def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: + """Send a logging/setLevel request.""" + return self.send_request( + types.ClientRequest( + types.SetLevelRequest( + method="logging/setLevel", + params=types.SetLevelRequestParams(level=level), + ) + ), + types.EmptyResult, + ) + + def list_resources(self) -> types.ListResourcesResult: + """Send a resources/list request.""" + return self.send_request( + types.ClientRequest( + types.ListResourcesRequest( + method="resources/list", + ) + ), + types.ListResourcesResult, + ) + + def list_resource_templates(self) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request.""" + return self.send_request( + types.ClientRequest( + types.ListResourceTemplatesRequest( + method="resources/templates/list", + ) + ), + types.ListResourceTemplatesResult, + ) + + def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Send a resources/read request.""" + return self.send_request( + types.ClientRequest( + types.ReadResourceRequest( + method="resources/read", + params=types.ReadResourceRequestParams(uri=uri), + ) + ), + types.ReadResourceResult, + ) + + def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + return self.send_request( + types.ClientRequest( + types.SubscribeRequest( + method="resources/subscribe", + params=types.SubscribeRequestParams(uri=uri), + ) + ), + types.EmptyResult, + ) + + def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + return self.send_request( + types.ClientRequest( + types.UnsubscribeRequest( + method="resources/unsubscribe", + params=types.UnsubscribeRequestParams(uri=uri), + ) + ), + types.EmptyResult, + ) + + def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + ) -> types.CallToolResult: + """Send a tools/call request.""" + + return self.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + 
params=types.CallToolRequestParams(name=name, arguments=arguments), + ) + ), + types.CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, + ) + + def list_prompts(self) -> types.ListPromptsResult: + """Send a prompts/list request.""" + return self.send_request( + types.ClientRequest( + types.ListPromptsRequest( + method="prompts/list", + ) + ), + types.ListPromptsResult, + ) + + def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: + """Send a prompts/get request.""" + return self.send_request( + types.ClientRequest( + types.GetPromptRequest( + method="prompts/get", + params=types.GetPromptRequestParams(name=name, arguments=arguments), + ) + ), + types.GetPromptResult, + ) + + def complete( + self, + ref: types.ResourceReference | types.PromptReference, + argument: dict[str, str], + ) -> types.CompleteResult: + """Send a completion/complete request.""" + return self.send_request( + types.ClientRequest( + types.CompleteRequest( + method="completion/complete", + params=types.CompleteRequestParams( + ref=ref, + argument=types.CompletionArgument(**argument), + ), + ) + ), + types.CompleteResult, + ) + + def list_tools(self) -> types.ListToolsResult: + """Send a tools/list request.""" + return self.send_request( + types.ClientRequest( + types.ListToolsRequest( + method="tools/list", + ) + ), + types.ListToolsResult, + ) + + def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + self.send_notification( + types.ClientNotification( + types.RootsListChangedNotification( + method="notifications/roots/list_changed", + ) + ) + ) + + def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: + ctx = RequestContext[ClientSession, Any]( + request_id=responder.request_id, + meta=responder.request_meta, + session=self, + lifespan_context=None, + ) + + match responder.request.root: + case types.CreateMessageRequest(params=params): + with responder: + response = self._sampling_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + responder.respond(client_response) + + case types.ListRootsRequest(): + with responder: + list_roots_response = self._list_roots_callback(ctx) + client_response = ClientResponse.validate_python(list_roots_response) + responder.respond(client_response) + + case types.PingRequest(): + with responder: + return responder.respond(types.ClientResult(root=types.EmptyResult())) + + def _handle_incoming( + self, + req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle incoming messages by forwarding to the message handler.""" + self._message_handler(req) + + def _received_notification(self, notification: types.ServerNotification) -> None: + """Handle notifications from the server.""" + # Process specific notification types + match notification.root: + case types.LoggingMessageNotification(params=params): + self._logging_callback(params) + case _: + pass diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py new file mode 100644 index 0000000000..99d985a781 --- /dev/null +++ b/api/core/mcp/types.py @@ -0,0 +1,1217 @@ +from collections.abc import Callable +from dataclasses import dataclass +from typing import ( + Annotated, + Any, + Generic, + Literal, + Optional, + TypeAlias, + TypeVar, +) + +from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic.networks import AnyUrl, UrlConstraints + +""" +Model Context 
Protocol bindings for Python
+
+These bindings were generated from https://github.com/modelcontextprotocol/specification,
+using Claude, with a prompt something like the following:
+
+Generate idiomatic Python bindings for this schema for MCP, or the "Model Context
+Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version
+for reference.
+
+* For the bindings, let's use Pydantic V2 models.
+* Each model should allow extra fields everywhere, by specifying `model_config =
+  ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class.
+* Union types should be represented with a Pydantic `RootModel`.
+* Define additional model classes instead of using dictionaries. Do this even if they're
+  not separate types in the schema.
+"""
+# The client supports both protocol versions below; 2025-06-18 is not supported yet.
+LATEST_PROTOCOL_VERSION = "2025-03-26"
+# The server advertises 2024-11-05 so that Claude can connect.
+SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
+ProgressToken = str | int
+Cursor = str
+Role = Literal["user", "assistant"]
+RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
+AnyFunction: TypeAlias = Callable[..., Any]
+
+
+class RequestParams(BaseModel):
+    class Meta(BaseModel):
+        progressToken: ProgressToken | None = None
+        """
+        If specified, the caller requests out-of-band progress notifications for
+        this request (as represented by notifications/progress). The value of this
+        parameter is an opaque token that will be attached to any subsequent
+        notifications. The receiver is not obligated to provide these notifications.
+        """
+
+        model_config = ConfigDict(extra="allow")
+
+    meta: Meta | None = Field(alias="_meta", default=None)
+
+
+class NotificationParams(BaseModel):
+    class Meta(BaseModel):
+        model_config = ConfigDict(extra="allow")
+
+    meta: Meta | None = Field(alias="_meta", default=None)
+    """
+    This parameter name is reserved by MCP to allow clients and servers to attach
+    additional metadata to their notifications.
+    """
+
+
+RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
+NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
+MethodT = TypeVar("MethodT", bound=str)
+
+
+class Request(BaseModel, Generic[RequestParamsT, MethodT]):
+    """Base class for JSON-RPC requests."""
+
+    method: MethodT
+    params: RequestParamsT
+    model_config = ConfigDict(extra="allow")
+
+
+class PaginatedRequest(Request[RequestParamsT, MethodT]):
+    cursor: Cursor | None = None
+    """
+    An opaque token representing the current pagination position.
+    If provided, the server should return results starting after this cursor.
+    """
+
+
+class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
+    """Base class for JSON-RPC notifications."""
+
+    method: MethodT
+    params: NotificationParamsT
+    model_config = ConfigDict(extra="allow")
+
+
+class Result(BaseModel):
+    """Base class for JSON-RPC results."""
+
+    model_config = ConfigDict(extra="allow")
+
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    This result property is reserved by the protocol to allow clients and servers to
+    attach additional metadata to their responses.
+    """
+
+
+class PaginatedResult(Result):
+    nextCursor: Cursor | None = None
+    """
+    An opaque token representing the pagination position after the last returned result.
+    If present, there may be more results available.
+    """
+
+
+class JSONRPCRequest(Request[dict[str, Any] | None, str]):
+    """A request that expects a response."""
+
+    jsonrpc: Literal["2.0"]
+    id: RequestId
+    method: str
+    params: dict[str, Any] | None = None
+
+
+class JSONRPCNotification(Notification[dict[str, Any] | None, str]):
+    """A notification which does not expect a response."""
+
+    jsonrpc: Literal["2.0"]
+    params: dict[str, Any] | None = None
+
+
+class JSONRPCResponse(BaseModel):
+    """A successful (non-error) response to a request."""
+
+    jsonrpc: Literal["2.0"]
+    id: RequestId
+    result: dict[str, Any]
+    model_config = ConfigDict(extra="allow")
+
+
+# Standard JSON-RPC error codes
+PARSE_ERROR = -32700
+INVALID_REQUEST = -32600
+METHOD_NOT_FOUND = -32601
+INVALID_PARAMS = -32602
+INTERNAL_ERROR = -32603
+
+
+class ErrorData(BaseModel):
+    """Error information for JSON-RPC error responses."""
+
+    code: int
+    """The error type that occurred."""
+
+    message: str
+    """
+    A short description of the error. The message SHOULD be limited to a concise single
+    sentence.
+    """
+
+    data: Any | None = None
+    """
+    Additional information about the error. The value of this member is defined by the
+    sender (e.g. detailed error information, nested errors etc.).
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+
+class JSONRPCError(BaseModel):
+    """A response to a request that indicates an error occurred."""
+
+    jsonrpc: Literal["2.0"]
+    id: str | int
+    error: ErrorData
+    model_config = ConfigDict(extra="allow")
+
+
+class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
+    pass
+
+
+class EmptyResult(Result):
+    """A response that indicates success but carries no data."""
+
+
+class Implementation(BaseModel):
+    """Describes the name and version of an MCP implementation."""
+
+    name: str
+    version: str
+    model_config = ConfigDict(extra="allow")
+
+
+class RootsCapability(BaseModel):
+    """Capability for root operations."""
+
+    listChanged: bool | None = None
+    """Whether the client supports notifications for changes to the roots list."""
+    model_config = ConfigDict(extra="allow")
+
+
+class SamplingCapability(BaseModel):
+    """Capability for sampling operations."""
+
+    model_config = ConfigDict(extra="allow")
+
+
+class ClientCapabilities(BaseModel):
+    """Capabilities a client may support."""
+
+    experimental: dict[str, dict[str, Any]] | None = None
+    """Experimental, non-standard capabilities that the client supports."""
+    sampling: SamplingCapability | None = None
+    """Present if the client supports sampling from an LLM."""
+    roots: RootsCapability | None = None
+    """Present if the client supports listing roots."""
+    model_config = ConfigDict(extra="allow")
+
+
+class PromptsCapability(BaseModel):
+    """Capability for prompts operations."""
+
+    listChanged: bool | None = None
+    """Whether this server supports notifications for changes to the prompt list."""
+    model_config = ConfigDict(extra="allow")
+
+
+class ResourcesCapability(BaseModel):
+    """Capability for resources operations."""
+
+    subscribe: bool | None = None
+    """Whether this server supports subscribing to resource updates."""
+    listChanged: bool | None = None
+    """Whether this server supports notifications for changes to the resource list."""
+    model_config = ConfigDict(extra="allow")
+
+
+class ToolsCapability(BaseModel):
+    """Capability for tools operations."""
+
+    listChanged: bool | None = None
+    """Whether this server supports notifications for changes to the tool list."""
+    model_config = ConfigDict(extra="allow")
+
+
+class 
LoggingCapability(BaseModel): + """Capability for logging operations.""" + + model_config = ConfigDict(extra="allow") + + +class ServerCapabilities(BaseModel): + """Capabilities that a server may support.""" + + experimental: dict[str, dict[str, Any]] | None = None + """Experimental, non-standard capabilities that the server supports.""" + logging: LoggingCapability | None = None + """Present if the server supports sending log messages to the client.""" + prompts: PromptsCapability | None = None + """Present if the server offers any prompt templates.""" + resources: ResourcesCapability | None = None + """Present if the server offers any resources to read.""" + tools: ToolsCapability | None = None + """Present if the server offers any tools to call.""" + model_config = ConfigDict(extra="allow") + + +class InitializeRequestParams(RequestParams): + """Parameters for the initialize request.""" + + protocolVersion: str | int + """The latest version of the Model Context Protocol that the client supports.""" + capabilities: ClientCapabilities + clientInfo: Implementation + model_config = ConfigDict(extra="allow") + + +class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]): + """ + This request is sent from the client to the server when it first connects, asking it + to begin initialization. + """ + + method: Literal["initialize"] + params: InitializeRequestParams + + +class InitializeResult(Result): + """After receiving an initialize request from the client, the server sends this.""" + + protocolVersion: str | int + """The version of the Model Context Protocol that the server wants to use.""" + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + """Instructions describing how to use the server and its features.""" + + +class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]): + """ + This notification is sent from the client to the server after initialization has + finished. + """ + + method: Literal["notifications/initialized"] + params: NotificationParams | None = None + + +class PingRequest(Request[RequestParams | None, Literal["ping"]]): + """ + A ping, issued by either the server or the client, to check that the other party is + still alive. + """ + + method: Literal["ping"] + params: RequestParams | None = None + + +class ProgressNotificationParams(NotificationParams): + """Parameters for progress notifications.""" + + progressToken: ProgressToken + """ + The progress token which was given in the initial request, used to associate this + notification with the request that is proceeding. + """ + progress: float + """ + The progress thus far. This should increase every time progress is made, even if the + total is unknown. + """ + total: float | None = None + """Total number of items to process (or total progress required), if known.""" + model_config = ConfigDict(extra="allow") + + +class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]): + """ + An out-of-band notification used to inform the receiver of a progress update for a + long-running request. 
+ """ + + method: Literal["notifications/progress"] + params: ProgressNotificationParams + + +class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]): + """Sent from the client to request a list of resources the server has.""" + + method: Literal["resources/list"] + params: RequestParams | None = None + + +class Annotations(BaseModel): + audience: list[Role] | None = None + priority: Annotated[float, Field(ge=0.0, le=1.0)] | None = None + model_config = ConfigDict(extra="allow") + + +class Resource(BaseModel): + """A known resource that the server is capable of reading.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """The URI of this resource.""" + name: str + """A human-readable name for this resource.""" + description: str | None = None + """A description of what this resource represents.""" + mimeType: str | None = None + """The MIME type of this resource, if known.""" + size: int | None = None + """ + The size of the raw resource content, in bytes (i.e., before base64 encoding + or any tokenization), if known. + + This can be used by Hosts to display file sizes and estimate context window usage. + """ + annotations: Annotations | None = None + model_config = ConfigDict(extra="allow") + + +class ResourceTemplate(BaseModel): + """A template description for resources available on the server.""" + + uriTemplate: str + """ + A URI template (according to RFC 6570) that can be used to construct resource + URIs. + """ + name: str + """A human-readable name for the type of resource this template refers to.""" + description: str | None = None + """A human-readable description of what this template is for.""" + mimeType: str | None = None + """ + The MIME type for all resources that match this template. This should only be + included if all resources matching this template have the same type. + """ + annotations: Annotations | None = None + model_config = ConfigDict(extra="allow") + + +class ListResourcesResult(PaginatedResult): + """The server's response to a resources/list request from the client.""" + + resources: list[Resource] + + +class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]): + """Sent from the client to request a list of resource templates the server has.""" + + method: Literal["resources/templates/list"] + params: RequestParams | None = None + + +class ListResourceTemplatesResult(PaginatedResult): + """The server's response to a resources/templates/list request from the client.""" + + resourceTemplates: list[ResourceTemplate] + + +class ReadResourceRequestParams(RequestParams): + """Parameters for reading a resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """ + The URI of the resource to read. The URI can use any protocol; it is up to the + server how to interpret it. 
+ """ + model_config = ConfigDict(extra="allow") + + +class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): + """Sent from the client to the server, to read a specific resource URI.""" + + method: Literal["resources/read"] + params: ReadResourceRequestParams + + +class ResourceContents(BaseModel): + """The contents of a specific resource or sub-resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """The URI of this resource.""" + mimeType: str | None = None + """The MIME type of this resource, if known.""" + model_config = ConfigDict(extra="allow") + + +class TextResourceContents(ResourceContents): + """Text contents of a resource.""" + + text: str + """ + The text of the item. This must only be set if the item can actually be represented + as text (not binary data). + """ + + +class BlobResourceContents(ResourceContents): + """Binary contents of a resource.""" + + blob: str + """A base64-encoded string representing the binary data of the item.""" + + +class ReadResourceResult(Result): + """The server's response to a resources/read request from the client.""" + + contents: list[TextResourceContents | BlobResourceContents] + + +class ResourceListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]] +): + """ + An optional notification from the server to the client, informing it that the list + of resources it can read from has changed. + """ + + method: Literal["notifications/resources/list_changed"] + params: NotificationParams | None = None + + +class SubscribeRequestParams(RequestParams): + """Parameters for subscribing to a resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """ + The URI of the resource to subscribe to. The URI can use any protocol; it is up to + the server how to interpret it. + """ + model_config = ConfigDict(extra="allow") + + +class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]): + """ + Sent from the client to request resources/updated notifications from the server + whenever a particular resource changes. + """ + + method: Literal["resources/subscribe"] + params: SubscribeRequestParams + + +class UnsubscribeRequestParams(RequestParams): + """Parameters for unsubscribing from a resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """The URI of the resource to unsubscribe from.""" + model_config = ConfigDict(extra="allow") + + +class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]): + """ + Sent from the client to request cancellation of resources/updated notifications from + the server. + """ + + method: Literal["resources/unsubscribe"] + params: UnsubscribeRequestParams + + +class ResourceUpdatedNotificationParams(NotificationParams): + """Parameters for resource update notifications.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """ + The URI of the resource that has been updated. This might be a sub-resource of the + one that the client actually subscribed to. + """ + model_config = ConfigDict(extra="allow") + + +class ResourceUpdatedNotification( + Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]] +): + """ + A notification from the server to the client, informing it that a resource has + changed and may need to be read again. 
+ """ + + method: Literal["notifications/resources/updated"] + params: ResourceUpdatedNotificationParams + + +class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]): + """Sent from the client to request a list of prompts and prompt templates.""" + + method: Literal["prompts/list"] + params: RequestParams | None = None + + +class PromptArgument(BaseModel): + """An argument for a prompt template.""" + + name: str + """The name of the argument.""" + description: str | None = None + """A human-readable description of the argument.""" + required: bool | None = None + """Whether this argument must be provided.""" + model_config = ConfigDict(extra="allow") + + +class Prompt(BaseModel): + """A prompt or prompt template that the server offers.""" + + name: str + """The name of the prompt or prompt template.""" + description: str | None = None + """An optional description of what this prompt provides.""" + arguments: list[PromptArgument] | None = None + """A list of arguments to use for templating the prompt.""" + model_config = ConfigDict(extra="allow") + + +class ListPromptsResult(PaginatedResult): + """The server's response to a prompts/list request from the client.""" + + prompts: list[Prompt] + + +class GetPromptRequestParams(RequestParams): + """Parameters for getting a prompt.""" + + name: str + """The name of the prompt or prompt template.""" + arguments: dict[str, str] | None = None + """Arguments to use for templating the prompt.""" + model_config = ConfigDict(extra="allow") + + +class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): + """Used by the client to get a prompt provided by the server.""" + + method: Literal["prompts/get"] + params: GetPromptRequestParams + + +class TextContent(BaseModel): + """Text content for a message.""" + + type: Literal["text"] + text: str + """The text content of the message.""" + annotations: Annotations | None = None + model_config = ConfigDict(extra="allow") + + +class ImageContent(BaseModel): + """Image content for a message.""" + + type: Literal["image"] + data: str + """The base64-encoded image data.""" + mimeType: str + """ + The MIME type of the image. Different providers may support different + image types. + """ + annotations: Annotations | None = None + model_config = ConfigDict(extra="allow") + + +class SamplingMessage(BaseModel): + """Describes a message issued to or received from an LLM API.""" + + role: Role + content: TextContent | ImageContent + model_config = ConfigDict(extra="allow") + + +class EmbeddedResource(BaseModel): + """ + The contents of a resource, embedded into a prompt or tool call result. + + It is up to the client how best to render embedded resources for the benefit + of the LLM and/or the user. 
+    """
+
+    type: Literal["resource"]
+    resource: TextResourceContents | BlobResourceContents
+    annotations: Annotations | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class PromptMessage(BaseModel):
+    """Describes a message returned as part of a prompt."""
+
+    role: Role
+    content: TextContent | ImageContent | EmbeddedResource
+    model_config = ConfigDict(extra="allow")
+
+
+class GetPromptResult(Result):
+    """The server's response to a prompts/get request from the client."""
+
+    description: str | None = None
+    """An optional description for the prompt."""
+    messages: list[PromptMessage]
+
+
+class PromptListChangedNotification(
+    Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]]
+):
+    """
+    An optional notification from the server to the client, informing it that the list
+    of prompts it offers has changed.
+    """
+
+    method: Literal["notifications/prompts/list_changed"]
+    params: NotificationParams | None = None
+
+
+class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
+    """Sent from the client to request a list of tools the server has."""
+
+    method: Literal["tools/list"]
+    params: RequestParams | None = None
+
+
+class ToolAnnotations(BaseModel):
+    """
+    Additional properties describing a Tool to clients.
+
+    NOTE: all properties in ToolAnnotations are **hints**.
+    They are not guaranteed to provide a faithful description of
+    tool behavior (including descriptive properties like `title`).
+
+    Clients should never make tool use decisions based on ToolAnnotations
+    received from untrusted servers.
+    """
+
+    title: str | None = None
+    """A human-readable title for the tool."""
+
+    readOnlyHint: bool | None = None
+    """
+    If true, the tool does not modify its environment.
+    Default: false
+    """
+
+    destructiveHint: bool | None = None
+    """
+    If true, the tool may perform destructive updates to its environment.
+    If false, the tool performs only additive updates.
+    (This property is meaningful only when `readOnlyHint == false`)
+    Default: true
+    """
+
+    idempotentHint: bool | None = None
+    """
+    If true, calling the tool repeatedly with the same arguments
+    will have no additional effect on its environment.
+    (This property is meaningful only when `readOnlyHint == false`)
+    Default: false
+    """
+
+    openWorldHint: bool | None = None
+    """
+    If true, this tool may interact with an "open world" of external
+    entities. If false, the tool's domain of interaction is closed.
+    For example, the world of a web search tool is open, whereas that
+    of a memory tool is not.
+ Default: true + """ + model_config = ConfigDict(extra="allow") + + +class Tool(BaseModel): + """Definition for a tool the client can call.""" + + name: str + """The name of the tool.""" + description: str | None = None + """A human-readable description of the tool.""" + inputSchema: dict[str, Any] + """A JSON Schema object defining the expected parameters for the tool.""" + annotations: ToolAnnotations | None = None + """Optional additional tool information.""" + model_config = ConfigDict(extra="allow") + + +class ListToolsResult(PaginatedResult): + """The server's response to a tools/list request from the client.""" + + tools: list[Tool] + + +class CallToolRequestParams(RequestParams): + """Parameters for calling a tool.""" + + name: str + arguments: dict[str, Any] | None = None + model_config = ConfigDict(extra="allow") + + +class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): + """Used by the client to invoke a tool provided by the server.""" + + method: Literal["tools/call"] + params: CallToolRequestParams + + +class CallToolResult(Result): + """The server's response to a tool call.""" + + content: list[TextContent | ImageContent | EmbeddedResource] + isError: bool = False + + +class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]): + """ + An optional notification from the server to the client, informing it that the list + of tools it offers has changed. + """ + + method: Literal["notifications/tools/list_changed"] + params: NotificationParams | None = None + + +LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] + + +class SetLevelRequestParams(RequestParams): + """Parameters for setting the logging level.""" + + level: LoggingLevel + """The level of logging that the client wants to receive from the server.""" + model_config = ConfigDict(extra="allow") + + +class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): + """A request from the client to the server, to enable or adjust logging.""" + + method: Literal["logging/setLevel"] + params: SetLevelRequestParams + + +class LoggingMessageNotificationParams(NotificationParams): + """Parameters for logging message notifications.""" + + level: LoggingLevel + """The severity of this log message.""" + logger: str | None = None + """An optional name of the logger issuing this message.""" + data: Any + """ + The data to be logged, such as a string message or an object. Any JSON serializable + type is allowed here. + """ + model_config = ConfigDict(extra="allow") + + +class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): + """Notification of a log message passed from server to client.""" + + method: Literal["notifications/message"] + params: LoggingMessageNotificationParams + + +IncludeContext = Literal["none", "thisServer", "allServers"] + + +class ModelHint(BaseModel): + """Hints to use for model selection.""" + + name: str | None = None + """A hint for a model name.""" + + model_config = ConfigDict(extra="allow") + + +class ModelPreferences(BaseModel): + """ + The server's preferences for model selection, requested by the client during + sampling. + + Because LLMs can vary along multiple dimensions, choosing the "best" model is + rarely straightforward. Different models excel in different areas—some are + faster but less capable, others are more capable but more expensive, and so + on. 
This interface allows servers to express their priorities across multiple + dimensions to help clients make an appropriate selection for their use case. + + These preferences are always advisory. The client MAY ignore them. It is also + up to the client to decide how to interpret these preferences and how to + balance them against other considerations. + """ + + hints: list[ModelHint] | None = None + """ + Optional hints to use for model selection. + + If multiple hints are specified, the client MUST evaluate them in order + (such that the first match is taken). + + The client SHOULD prioritize these hints over the numeric priorities, but + MAY still use the priorities to select from ambiguous matches. + """ + + costPriority: float | None = None + """ + How much to prioritize cost when selecting a model. A value of 0 means cost + is not important, while a value of 1 means cost is the most important + factor. + """ + + speedPriority: float | None = None + """ + How much to prioritize sampling speed (latency) when selecting a model. A + value of 0 means speed is not important, while a value of 1 means speed is + the most important factor. + """ + + intelligencePriority: float | None = None + """ + How much to prioritize intelligence and capabilities when selecting a + model. A value of 0 means intelligence is not important, while a value of 1 + means intelligence is the most important factor. + """ + + model_config = ConfigDict(extra="allow") + + +class CreateMessageRequestParams(RequestParams): + """Parameters for creating a message.""" + + messages: list[SamplingMessage] + modelPreferences: ModelPreferences | None = None + """ + The server's preferences for which model to select. The client MAY ignore + these preferences. + """ + systemPrompt: str | None = None + """An optional system prompt the server wants to use for sampling.""" + includeContext: IncludeContext | None = None + """ + A request to include context from one or more MCP servers (including the caller), to + be attached to the prompt. 
+ """ + temperature: float | None = None + maxTokens: int + """The maximum number of tokens to sample, as requested by the server.""" + stopSequences: list[str] | None = None + metadata: dict[str, Any] | None = None + """Optional metadata to pass through to the LLM provider.""" + model_config = ConfigDict(extra="allow") + + +class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): + """A request from the server to sample an LLM via the client.""" + + method: Literal["sampling/createMessage"] + params: CreateMessageRequestParams + + +StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str + + +class CreateMessageResult(Result): + """The client's response to a sampling/create_message request from the server.""" + + role: Role + content: TextContent | ImageContent + model: str + """The name of the model that generated the message.""" + stopReason: StopReason | None = None + """The reason why sampling stopped, if known.""" + + +class ResourceReference(BaseModel): + """A reference to a resource or resource template definition.""" + + type: Literal["ref/resource"] + uri: str + """The URI or URI template of the resource.""" + model_config = ConfigDict(extra="allow") + + +class PromptReference(BaseModel): + """Identifies a prompt.""" + + type: Literal["ref/prompt"] + name: str + """The name of the prompt or prompt template""" + model_config = ConfigDict(extra="allow") + + +class CompletionArgument(BaseModel): + """The argument's information for completion requests.""" + + name: str + """The name of the argument""" + value: str + """The value of the argument to use for completion matching.""" + model_config = ConfigDict(extra="allow") + + +class CompleteRequestParams(RequestParams): + """Parameters for completion requests.""" + + ref: ResourceReference | PromptReference + argument: CompletionArgument + model_config = ConfigDict(extra="allow") + + +class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): + """A request from the client to the server, to ask for completion options.""" + + method: Literal["completion/complete"] + params: CompleteRequestParams + + +class Completion(BaseModel): + """Completion information.""" + + values: list[str] + """An array of completion values. Must not exceed 100 items.""" + total: int | None = None + """ + The total number of completion options available. This can exceed the number of + values actually sent in the response. + """ + hasMore: bool | None = None + """ + Indicates whether there are additional completion options beyond those provided in + the current response, even if the exact total is unknown. + """ + model_config = ConfigDict(extra="allow") + + +class CompleteResult(Result): + """The server's response to a completion/complete request""" + + completion: Completion + + +class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): + """ + Sent from the server to request a list of root URIs from the client. Roots allow + servers to ask for specific directories or files to operate on. A common example + for roots is providing a set of repositories or directories a server should operate + on. + + This request is typically used when the server needs to understand the file system + structure or access specific locations that the client has permission to read from. 
+ """ + + method: Literal["roots/list"] + params: RequestParams | None = None + + +class Root(BaseModel): + """Represents a root directory or file that the server can operate on.""" + + uri: FileUrl + """ + The URI identifying the root. This *must* start with file:// for now. + This restriction may be relaxed in future versions of the protocol to allow + other URI schemes. + """ + name: str | None = None + """ + An optional name for the root. This can be used to provide a human-readable + identifier for the root, which may be useful for display purposes or for + referencing the root in other parts of the application. + """ + model_config = ConfigDict(extra="allow") + + +class ListRootsResult(Result): + """ + The client's response to a roots/list request from the server. + This result contains an array of Root objects, each representing a root directory + or file that the server can operate on. + """ + + roots: list[Root] + + +class RootsListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] +): + """ + A notification from the client to the server, informing it that the list of + roots has changed. + + This notification should be sent whenever the client adds, removes, or + modifies any root. The server should then request an updated list of roots + using the ListRootsRequest. + """ + + method: Literal["notifications/roots/list_changed"] + params: NotificationParams | None = None + + +class CancelledNotificationParams(NotificationParams): + """Parameters for cancellation notifications.""" + + requestId: RequestId + """The ID of the request to cancel.""" + reason: str | None = None + """An optional string describing the reason for the cancellation.""" + model_config = ConfigDict(extra="allow") + + +class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]): + """ + This notification can be sent by either side to indicate that it is canceling a + previously-issued request. 
+ """ + + method: Literal["notifications/cancelled"] + params: CancelledNotificationParams + + +class ClientRequest( + RootModel[ + PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ListResourceTemplatesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest + ] +): + pass + + +class ClientNotification( + RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] +): + pass + + +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]): + pass + + +class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]): + pass + + +class ServerNotification( + RootModel[ + CancelledNotification + | ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + | PromptListChangedNotification + ] +): + pass + + +class ServerResult( + RootModel[ + EmptyResult + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ListResourceTemplatesResult + | ReadResourceResult + | CallToolResult + | ListToolsResult + ] +): + pass + + +ResumptionToken = str + +ResumptionTokenUpdateCallback = Callable[[ResumptionToken], None] + + +@dataclass +class ClientMessageMetadata: + """Metadata specific to client messages.""" + + resumption_token: ResumptionToken | None = None + on_resumption_token_update: Callable[[ResumptionToken], None] | None = None + + +@dataclass +class ServerMessageMetadata: + """Metadata specific to server messages.""" + + related_request_id: RequestId | None = None + request_context: object | None = None + + +MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None + + +@dataclass +class SessionMessage: + """A message with specific metadata for transport-specific features.""" + + message: JSONRPCMessage + metadata: MessageMetadata = None + + +class OAuthClientMetadata(BaseModel): + client_name: str + redirect_uris: list[str] + grant_types: Optional[list[str]] = None + response_types: Optional[list[str]] = None + token_endpoint_auth_method: Optional[str] = None + client_uri: Optional[str] = None + scope: Optional[str] = None + + +class OAuthClientInformation(BaseModel): + client_id: str + client_secret: Optional[str] = None + + +class OAuthClientInformationFull(OAuthClientInformation): + client_name: str | None = None + redirect_uris: list[str] + scope: Optional[str] = None + grant_types: Optional[list[str]] = None + response_types: Optional[list[str]] = None + token_endpoint_auth_method: Optional[str] = None + + +class OAuthTokens(BaseModel): + access_token: str + token_type: str + expires_in: Optional[int] = None + refresh_token: Optional[str] = None + scope: Optional[str] = None + + +class OAuthMetadata(BaseModel): + authorization_endpoint: str + token_endpoint: str + registration_endpoint: Optional[str] = None + response_types_supported: list[str] + grant_types_supported: Optional[list[str]] = None + code_challenge_methods_supported: Optional[list[str]] = None diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py new file mode 100644 index 0000000000..a54badcd4c --- /dev/null +++ b/api/core/mcp/utils.py @@ -0,0 +1,114 @@ +import json + +import httpx + +from configs import dify_config +from core.mcp.types import ErrorData, JSONRPCError +from core.model_runtime.utils.encoders import 
jsonable_encoder + +HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + +STATUS_FORCELIST = [429, 500, 502, 503, 504] + + +def create_ssrf_proxy_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.Client: + """Create an HTTPX client with SSRF proxy configuration for MCP connections. + + Args: + headers: Optional headers to include in the client + timeout: Optional timeout configuration + + Returns: + Configured httpx.Client with proxy settings + """ + if dify_config.SSRF_PROXY_ALL_URL: + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + proxy=dify_config.SSRF_PROXY_ALL_URL, + ) + elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: + proxy_mounts = { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), + "https://": httpx.HTTPTransport( + proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY + ), + } + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + mounts=proxy_mounts, + ) + else: + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + ) + + +def ssrf_proxy_sse_connect(url, **kwargs): + """Connect to SSE endpoint with SSRF proxy protection. + + This function creates an SSE connection using the configured proxy settings + to prevent SSRF attacks when connecting to external endpoints. + + Args: + url: The SSE endpoint URL + **kwargs: Additional arguments passed to the SSE connection + + Returns: + EventSource object for SSE streaming + """ + from httpx_sse import connect_sse + + # Extract client if provided, otherwise create one + client = kwargs.pop("client", None) + if client is None: + # Create client with SSRF proxy configuration + timeout = kwargs.pop( + "timeout", + httpx.Timeout( + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, + ), + ) + headers = kwargs.pop("headers", {}) + client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) + client_provided = False + else: + client_provided = True + + # Extract method if provided, default to GET + method = kwargs.pop("method", "GET") + + try: + return connect_sse(client, method, url, **kwargs) + except Exception: + # If we created the client, we need to clean it up on error + if not client_provided: + client.close() + raise + + +def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None): + """Create MCP error response""" + error_data = ErrorData(code=code, message=message, data=data) + json_response = JSONRPCError( + jsonrpc="2.0", + id=request_id or 1, + error=error_data, + ) + json_data = json.dumps(jsonable_encoder(json_response)) + sse_content = f"event: message\ndata: {json_data}\n\n".encode() + yield sse_content diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2254b3d4d5..7ce124594a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,6 +1,8 @@ from collections.abc import Sequence from typing import Optional +from sqlalchemy import select + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager 
from core.file import file_manager from core.model_manager import ModelInstance @@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile -from models.workflow import WorkflowRun +from models.workflow import Workflow, WorkflowRun class TokenBufferMemory: - def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: + def __init__( + self, + conversation: Conversation, + model_instance: ModelInstance, + ) -> None: self.conversation = conversation self.model_instance = model_instance @@ -36,20 +42,8 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = ( - db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id, - Message.parent_message_id, - Message.answer_tokens, - ) - .filter( - Message.conversation_id == self.conversation.id, - ) - .order_by(Message.created_at.desc()) + stmt = ( + select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc()) ) if message_limit and message_limit > 0: @@ -57,7 +51,9 @@ class TokenBufferMemory: else: message_limit = 500 - messages = query.limit(message_limit).all() + stmt = stmt.limit(message_limit) + + messages = db.session.scalars(stmt).all() # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message @@ -71,21 +67,23 @@ class TokenBufferMemory: prompt_messages: list[PromptMessage] = [] for message in messages: - files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_run = db.session.scalar( + select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id) + ) + if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) else: - if message.workflow_run_id: - workflow_run = ( - db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() - ) - - if workflow_run and workflow_run.workflow: - file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, is_vision=False - ) + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 9d010ae28d..83dc7f0525 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -156,6 +156,23 @@ class PromptMessage(ABC, BaseModel): """ return not 
self.content + def get_text_content(self) -> str: + """ + Get text content from prompt message. + + :return: Text content as string, empty string if no text content + """ + if isinstance(self.content, str): + return self.content + elif isinstance(self.content, list): + text_parts = [] + for item in self.content: + if isinstance(item, TextPromptMessageContent): + text_parts.append(item.data) + return "".join(text_parts) + else: + return "" + @field_validator("content", mode="before") @classmethod def validate_content(cls, v): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index d0f9ee13e5..c9aa8d1474 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -123,6 +123,8 @@ class ProviderEntity(BaseModel): description: Optional[I18nObject] = None icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None + icon_small_dark: Optional[I18nObject] = None + icon_large_dark: Optional[I18nObject] = None background: Optional[str] = None help: Optional[ProviderHelpEntity] = None supported_model_types: Sequence[ModelType] diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index c65a3885fd..332381555b 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -89,7 +89,7 @@ class ApiModeration(Moderation): def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/core/ops/aliyun_trace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py new file mode 100644 index 0000000000..cf367efdf0 --- /dev/null +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -0,0 +1,489 @@ +import json +import logging +from collections.abc import Sequence +from typing import Optional +from urllib.parse import urljoin + +from opentelemetry.trace import Status, StatusCode +from sqlalchemy.orm import Session, sessionmaker + +from core.ops.aliyun_trace.data_exporter.traceclient import ( + TraceClient, + convert_datetime_to_nanoseconds, + convert_to_span_id, + convert_to_trace_id, + generate_span_id, +) +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_COMPLETION, + GEN_AI_FRAMEWORK, + GEN_AI_MODEL_NAME, + GEN_AI_PROMPT, + GEN_AI_PROMPT_TEMPLATE_TEMPLATE, + GEN_AI_PROMPT_TEMPLATE_VARIABLE, + GEN_AI_RESPONSE_FINISH_REASON, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_SYSTEM, + GEN_AI_USAGE_INPUT_TOKENS, + GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_TOTAL_TOKENS, + GEN_AI_USER_ID, + INPUT_VALUE, + OUTPUT_VALUE, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.config_entity import AliyunConfig +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from 
core.rag.models.document import Document +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from core.workflow.nodes import NodeType +from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db + +logger = logging.getLogger(__name__) + + +class AliyunDataTrace(BaseTraceInstance): + def __init__( + self, + aliyun_config: AliyunConfig, + ): + super().__init__(aliyun_config) + base_url = aliyun_config.endpoint.rstrip("/") + endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces") + self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint) + + def trace(self, trace_info: BaseTraceInfo): + if isinstance(trace_info, WorkflowTraceInfo): + self.workflow_trace(trace_info) + if isinstance(trace_info, MessageTraceInfo): + self.message_trace(trace_info) + if isinstance(trace_info, ModerationTraceInfo): + pass + if isinstance(trace_info, SuggestedQuestionTraceInfo): + self.suggested_question_trace(trace_info) + if isinstance(trace_info, DatasetRetrievalTraceInfo): + self.dataset_retrieval_trace(trace_info) + if isinstance(trace_info, ToolTraceInfo): + self.tool_trace(trace_info) + if isinstance(trace_info, GenerateNameTraceInfo): + pass + + def api_check(self): + return self.trace_client.api_check() + + def get_project_url(self): + try: + return self.trace_client.get_project_url() + except Exception as e: + logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True) + raise ValueError(f"Aliyun get run url failed: {str(e)}") + + def workflow_trace(self, trace_info: WorkflowTraceInfo): + external_trace_id = trace_info.metadata.get("external_trace_id") + trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id) + workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") + self.add_workflow_span(trace_id, workflow_span_id, trace_info) + + workflow_node_executions = self.get_workflow_node_executions(trace_info) + for node_execution in workflow_node_executions: + node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id) + self.trace_client.add_span(node_span) + + def message_trace(self, trace_info: MessageTraceInfo): + message_data = trace_info.message_data + if message_data is None: + return + message_id = trace_info.message_id + + user_id = message_data.from_account_id + if message_data.from_end_user_id: + end_user_data: Optional[EndUser] = ( + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() + ) + if end_user_data is not None: + user_id = end_user_data.session_id + + status: Status = Status(StatusCode.OK) + if trace_info.error: + status = Status(StatusCode.ERROR, trace_info.error) + + trace_id = convert_to_trace_id(message_id) + message_span_id = convert_to_span_id(message_id, "message") + message_span = SpanData( + trace_id=trace_id, + parent_span_id=None, + span_id=message_span_id, + name="message", + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_USER_ID: str(user_id), + GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, + GEN_AI_FRAMEWORK: "dify", + INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + OUTPUT_VALUE: 
str(trace_info.outputs), + }, + status=status, + ) + self.trace_client.add_span(message_span) + + app_model_config = getattr(trace_info.message_data, "app_model_config", {}) + pre_prompt = getattr(app_model_config, "pre_prompt", "") + inputs_data = getattr(trace_info.message_data, "inputs", {}) + llm_span = SpanData( + trace_id=trace_id, + parent_span_id=message_span_id, + span_id=convert_to_span_id(message_id, "llm"), + name="llm", + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_USER_ID: str(user_id), + GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, + GEN_AI_FRAMEWORK: "dify", + GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""), + GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""), + GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens), + GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens), + GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens), + GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False), + GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt, + GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False), + GEN_AI_COMPLETION: str(trace_info.outputs), + INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + OUTPUT_VALUE: str(trace_info.outputs), + }, + status=status, + ) + self.trace_client.add_span(llm_span) + + def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return + message_id = trace_info.message_id + + documents_data = extract_retrieval_documents(trace_info.documents) + dataset_retrieval_span = SpanData( + trace_id=convert_to_trace_id(message_id), + parent_span_id=convert_to_span_id(message_id, "message"), + span_id=generate_span_id(), + name="dataset_retrieval", + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value, + GEN_AI_FRAMEWORK: "dify", + RETRIEVAL_QUERY: str(trace_info.inputs), + RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False), + INPUT_VALUE: str(trace_info.inputs), + OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False), + }, + ) + self.trace_client.add_span(dataset_retrieval_span) + + def tool_trace(self, trace_info: ToolTraceInfo): + if trace_info.message_data is None: + return + message_id = trace_info.message_id + + status: Status = Status(StatusCode.OK) + if trace_info.error: + status = Status(StatusCode.ERROR, trace_info.error) + + tool_span = SpanData( + trace_id=convert_to_trace_id(message_id), + parent_span_id=convert_to_span_id(message_id, "message"), + span_id=generate_span_id(), + name=trace_info.tool_name, + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value, + GEN_AI_FRAMEWORK: "dify", + TOOL_NAME: trace_info.tool_name, + TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False), + TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False), + INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + OUTPUT_VALUE: str(trace_info.tool_outputs), + }, + status=status, + ) + self.trace_client.add_span(tool_span) + + def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> 
Sequence[WorkflowNodeExecution]: + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).where(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).where(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + current_tenant = ( + session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() + ) + if not current_tenant: + raise ValueError(f"Current tenant not found for account {service_account.id}") + service_account.set_tenant_id(current_tenant.tenant_id) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + return workflow_node_executions + + def build_workflow_node_span( + self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int + ): + try: + if node_execution.node_type == NodeType.LLM: + node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution) + elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution) + elif node_execution.node_type == NodeType.TOOL: + node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution) + else: + node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) + return node_span + except Exception as e: + logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True) + return None + + def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: + span_status: Status = Status(StatusCode.UNSET) + if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED: + span_status = Status(StatusCode.OK) + elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]: + span_status = Status(StatusCode.ERROR, str(node_execution.error)) + return span_status + + def build_workflow_task_span( + self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + ) -> SpanData: + return SpanData( + trace_id=trace_id, + parent_span_id=workflow_span_id, + span_id=convert_to_span_id(node_execution.id, "node"), + name=node_execution.title, + start_time=convert_datetime_to_nanoseconds(node_execution.created_at), + end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), + attributes={ + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", + GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value, + GEN_AI_FRAMEWORK: "dify", 
+ INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False), + OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), + }, + status=self.get_workflow_node_status(node_execution), + ) + + def build_workflow_tool_span( + self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + ) -> SpanData: + tool_des = {} + if node_execution.metadata: + tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {}) + return SpanData( + trace_id=trace_id, + parent_span_id=workflow_span_id, + span_id=convert_to_span_id(node_execution.id, "node"), + name=node_execution.title, + start_time=convert_datetime_to_nanoseconds(node_execution.created_at), + end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), + attributes={ + GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value, + GEN_AI_FRAMEWORK: "dify", + TOOL_NAME: node_execution.title, + TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False), + TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), + INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), + OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), + }, + status=self.get_workflow_node_status(node_execution), + ) + + def build_workflow_retrieval_span( + self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + ) -> SpanData: + input_value = "" + if node_execution.inputs: + input_value = str(node_execution.inputs.get("query", "")) + output_value = "" + if node_execution.outputs: + output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False) + return SpanData( + trace_id=trace_id, + parent_span_id=workflow_span_id, + span_id=convert_to_span_id(node_execution.id, "node"), + name=node_execution.title, + start_time=convert_datetime_to_nanoseconds(node_execution.created_at), + end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), + attributes={ + GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value, + GEN_AI_FRAMEWORK: "dify", + RETRIEVAL_QUERY: input_value, + RETRIEVAL_DOCUMENT: output_value, + INPUT_VALUE: input_value, + OUTPUT_VALUE: output_value, + }, + status=self.get_workflow_node_status(node_execution), + ) + + def build_workflow_llm_span( + self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + ) -> SpanData: + process_data = node_execution.process_data or {} + outputs = node_execution.outputs or {} + usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + return SpanData( + trace_id=trace_id, + parent_span_id=workflow_span_id, + span_id=convert_to_span_id(node_execution.id, "node"), + name=node_execution.title, + start_time=convert_datetime_to_nanoseconds(node_execution.created_at), + end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), + attributes={ + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", + GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, + GEN_AI_FRAMEWORK: "dify", + GEN_AI_MODEL_NAME: process_data.get("model_name", ""), + GEN_AI_SYSTEM: process_data.get("model_provider", ""), + GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)), + GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)), + GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)), + GEN_AI_PROMPT: 
json.dumps(process_data.get("prompts", []), ensure_ascii=False), + GEN_AI_COMPLETION: str(outputs.get("text", "")), + GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""), + INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), + OUTPUT_VALUE: str(outputs.get("text", "")), + }, + status=self.get_workflow_node_status(node_execution), + ) + + def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo): + message_span_id = None + if trace_info.message_id: + message_span_id = convert_to_span_id(trace_info.message_id, "message") + user_id = trace_info.metadata.get("user_id") + status: Status = Status(StatusCode.OK) + if trace_info.error: + status = Status(StatusCode.ERROR, trace_info.error) + if message_span_id: # chatflow + message_span = SpanData( + trace_id=trace_id, + parent_span_id=None, + span_id=message_span_id, + name="message", + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", + GEN_AI_USER_ID: str(user_id), + GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, + GEN_AI_FRAMEWORK: "dify", + INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""), + OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), + }, + status=status, + ) + self.trace_client.add_span(message_span) + + workflow_span = SpanData( + trace_id=trace_id, + parent_span_id=message_span_id, + span_id=workflow_span_id, + name="workflow", + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_USER_ID: str(user_id), + GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, + GEN_AI_FRAMEWORK: "dify", + INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False), + OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), + }, + status=status, + ) + self.trace_client.add_span(workflow_span) + + def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): + message_id = trace_info.message_id + status: Status = Status(StatusCode.OK) + if trace_info.error: + status = Status(StatusCode.ERROR, trace_info.error) + suggested_question_span = SpanData( + trace_id=convert_to_trace_id(message_id), + parent_span_id=convert_to_span_id(message_id, "message"), + span_id=convert_to_span_id(message_id, "suggested_question"), + name="suggested_question", + start_time=convert_datetime_to_nanoseconds(trace_info.start_time), + end_time=convert_datetime_to_nanoseconds(trace_info.end_time), + attributes={ + GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, + GEN_AI_FRAMEWORK: "dify", + GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""), + GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""), + GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False), + GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False), + INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), + }, + status=status, + ) + self.trace_client.add_span(suggested_question_span) + + +def extract_retrieval_documents(documents: list[Document]): + documents_data = [] + for document in documents: + document_data = { + "content": document.page_content, + "metadata": { + "dataset_id": document.metadata.get("dataset_id"), + "doc_id": 
document.metadata.get("doc_id"), + "document_id": document.metadata.get("document_id"), + }, + "score": document.metadata.get("score"), + } + documents_data.append(document_data) + return documents_data diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/core/ops/aliyun_trace/data_exporter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py new file mode 100644 index 0000000000..ba5ac3f420 --- /dev/null +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -0,0 +1,200 @@ +import hashlib +import logging +import random +import socket +import threading +import uuid +from collections import deque +from collections.abc import Sequence +from datetime import datetime +from typing import Optional + +import requests +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.util.instrumentation import InstrumentationScope +from opentelemetry.semconv.resource import ResourceAttributes + +from configs import dify_config +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData + +INVALID_SPAN_ID = 0x0000000000000000 +INVALID_TRACE_ID = 0x00000000000000000000000000000000 + +logger = logging.getLogger(__name__) + + +class TraceClient: + def __init__( + self, + service_name: str, + endpoint: str, + max_queue_size: int = 1000, + schedule_delay_sec: int = 5, + max_export_batch_size: int = 50, + ): + self.endpoint = endpoint + self.resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}", + ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}", + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + self.span_builder = SpanBuilder(self.resource) + self.exporter = OTLPSpanExporter(endpoint=endpoint) + + self.max_queue_size = max_queue_size + self.schedule_delay_sec = schedule_delay_sec + self.max_export_batch_size = max_export_batch_size + + self.queue: deque = deque(maxlen=max_queue_size) + self.condition = threading.Condition(threading.Lock()) + self.done = False + + self.worker_thread = threading.Thread(target=self._worker, daemon=True) + self.worker_thread.start() + + self._spans_dropped = False + + def export(self, spans: Sequence[ReadableSpan]): + self.exporter.export(spans) + + def api_check(self): + try: + response = requests.head(self.endpoint, timeout=5) + if response.status_code == 405: + return True + else: + logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}") + return False + except requests.exceptions.RequestException as e: + logger.debug(f"AliyunTrace API check failed: {str(e)}") + raise ValueError(f"AliyunTrace API check failed: {str(e)}") + + def get_project_url(self): + return "https://arms.console.aliyun.com/#/llm" + + def add_span(self, span_data: SpanData): + if span_data is None: + return + span: ReadableSpan = self.span_builder.build_span(span_data) + with self.condition: + if len(self.queue) == self.max_queue_size: + if not self._spans_dropped: + logger.warning("Queue is full, likely spans will be dropped.") + self._spans_dropped = True + + self.queue.appendleft(span) + if len(self.queue) >= 
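The `api_check()` above relies on the fact that an OTLP/HTTP traces endpoint only accepts POST, so a HEAD request answered with 405 Method Not Allowed is taken as proof the collector is reachable. A standalone sketch of that probe, simplified to return `False` instead of raising on network errors:

```python
import requests

def otlp_endpoint_reachable(endpoint: str, timeout: float = 5.0) -> bool:
    # The OTLP traces path is POST-only, so HTTP 405 on a HEAD request means "alive".
    try:
        return requests.head(endpoint, timeout=timeout).status_code == 405
    except requests.exceptions.RequestException:
        # The real implementation raises ValueError here; simplified for this sketch.
        return False
```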
self.max_export_batch_size: + self.condition.notify() + + def _worker(self): + while not self.done: + with self.condition: + if len(self.queue) < self.max_export_batch_size and not self.done: + self.condition.wait(timeout=self.schedule_delay_sec) + self._export_batch() + + def _export_batch(self): + spans_to_export: list[ReadableSpan] = [] + with self.condition: + while len(spans_to_export) < self.max_export_batch_size and self.queue: + spans_to_export.append(self.queue.pop()) + + if spans_to_export: + try: + self.exporter.export(spans_to_export) + except Exception as e: + logger.debug(f"Error exporting spans: {e}") + + def shutdown(self): + with self.condition: + self.done = True + self.condition.notify_all() + self.worker_thread.join() + self._export_batch() + self.exporter.shutdown() + + +class SpanBuilder: + def __init__(self, resource): + self.resource = resource + self.instrumentation_scope = InstrumentationScope( + __name__, + "", + None, + None, + ) + + def build_span(self, span_data: SpanData) -> ReadableSpan: + span_context = trace_api.SpanContext( + trace_id=span_data.trace_id, + span_id=span_data.span_id, + is_remote=False, + trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED), + trace_state=None, + ) + + parent_span_context = None + if span_data.parent_span_id is not None: + parent_span_context = trace_api.SpanContext( + trace_id=span_data.trace_id, + span_id=span_data.parent_span_id, + is_remote=False, + trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED), + trace_state=None, + ) + + span = ReadableSpan( + name=span_data.name, + context=span_context, + parent=parent_span_context, + resource=self.resource, + attributes=span_data.attributes, + events=span_data.events, + links=span_data.links, + kind=trace_api.SpanKind.INTERNAL, + status=span_data.status, + start_time=span_data.start_time, + end_time=span_data.end_time, + instrumentation_scope=self.instrumentation_scope, + ) + return span + + +def generate_span_id() -> int: + span_id = random.getrandbits(64) + while span_id == INVALID_SPAN_ID: + span_id = random.getrandbits(64) + return span_id + + +def convert_to_trace_id(uuid_v4: Optional[str]) -> int: + try: + uuid_obj = uuid.UUID(uuid_v4) + return uuid_obj.int + except Exception as e: + raise ValueError(f"Invalid UUID input: {e}") + + +def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: + try: + uuid_obj = uuid.UUID(uuid_v4) + except Exception as e: + raise ValueError(f"Invalid UUID input: {e}") + combined_key = f"{uuid_obj.hex}-{span_type}" + hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest() + span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) + return span_id + + +def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]: + if start_time_a is None: + return None + timestamp_in_seconds = start_time_a.timestamp() + timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9) + return timestamp_in_nanoseconds diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/core/ops/aliyun_trace/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py new file mode 100644 index 0000000000..1caa822cd0 --- /dev/null +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -0,0 +1,21 @@ +from collections.abc import Sequence +from typing import Optional + +from opentelemetry import trace as trace_api +from 
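The helpers at the end of `traceclient.py` map Dify's UUID-based run and message ids onto OpenTelemetry's 128-bit trace ids and 64-bit span ids. A standard-library-only sketch of the same derivation is below; the example UUID is made up.

```python
import hashlib
import uuid

def trace_id_from_uuid(value: str) -> int:
    # A UUID is 128 bits wide, which matches an OTel trace id exactly.
    return uuid.UUID(value).int

def span_id_from_uuid(value: str, span_type: str) -> int:
    # Hash "<uuid hex>-<span type>" and keep the first 8 bytes (64 bits),
    # so the same run id plus span type always yields the same span id.
    digest = hashlib.sha256(f"{uuid.UUID(value).hex}-{span_type}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)

# Made-up workflow run id for illustration.
run_id = "0f8fad5b-d9cb-469f-a165-70867728950e"
print(hex(trace_id_from_uuid(run_id)))             # 128-bit trace id
print(hex(span_id_from_uuid(run_id, "workflow")))  # deterministic workflow span id
print(hex(span_id_from_uuid(run_id, "message")))   # different id for the message span
```

Deriving span ids this way lets workflow, message, and node spans be correlated later without storing any extra mapping state.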
opentelemetry.sdk.trace import Event, Status, StatusCode +from pydantic import BaseModel, Field + + +class SpanData(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + trace_id: int = Field(..., description="The unique identifier for the trace.") + parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.") + span_id: int = Field(..., description="The unique identifier for this span.") + name: str = Field(..., description="The name of the span.") + attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.") + events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.") + links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.") + status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") + start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.") + end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.") diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py new file mode 100644 index 0000000000..5d70264320 --- /dev/null +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -0,0 +1,64 @@ +from enum import Enum + +# public +GEN_AI_SESSION_ID = "gen_ai.session.id" + +GEN_AI_USER_ID = "gen_ai.user.id" + +GEN_AI_USER_NAME = "gen_ai.user.name" + +GEN_AI_SPAN_KIND = "gen_ai.span.kind" + +GEN_AI_FRAMEWORK = "gen_ai.framework" + + +# Chain +INPUT_VALUE = "input.value" + +OUTPUT_VALUE = "output.value" + + +# Retriever +RETRIEVAL_QUERY = "retrieval.query" + +RETRIEVAL_DOCUMENT = "retrieval.document" + + +# LLM +GEN_AI_MODEL_NAME = "gen_ai.model_name" + +GEN_AI_SYSTEM = "gen_ai.system" + +GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + +GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + +GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + +GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template" + +GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable" + +GEN_AI_PROMPT = "gen_ai.prompt" + +GEN_AI_COMPLETION = "gen_ai.completion" + +GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason" + +# Tool +TOOL_NAME = "tool.name" + +TOOL_DESCRIPTION = "tool.description" + +TOOL_PARAMETERS = "tool.parameters" + + +class GenAISpanKind(Enum): + CHAIN = "CHAIN" + RETRIEVER = "RETRIEVER" + RERANKER = "RERANKER" + LLM = "LLM" + EMBEDDING = "EMBEDDING" + TOOL = "TOOL" + AGENT = "AGENT" + TASK = "TASK" diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 0b6834acf3..1b72a4775a 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise def workflow_trace(self, trace_info: WorkflowTraceInfo): - if trace_info.message_data is None: - return - workflow_metadata = { - "workflow_id": trace_info.workflow_run_id or "", + "workflow_run_id": trace_info.workflow_run_id or "", "message_id": trace_info.message_id or "", 
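To make the new semantic-convention keys in `semconv.py` easier to read in context, here is a hedged sketch of an attribute dict for an LLM span built from them; the model name, token counts, and payloads are invented values.

```python
import json

from core.ops.aliyun_trace.entities.semconv import (
    GEN_AI_FRAMEWORK,
    GEN_AI_MODEL_NAME,
    GEN_AI_SPAN_KIND,
    GEN_AI_USAGE_INPUT_TOKENS,
    GEN_AI_USAGE_OUTPUT_TOKENS,
    GEN_AI_USAGE_TOTAL_TOKENS,
    INPUT_VALUE,
    OUTPUT_VALUE,
    GenAISpanKind,
)

# Illustrative values only; the exporter stores everything as strings,
# including token counts and JSON-encoded prompt payloads.
attributes = {
    GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
    GEN_AI_FRAMEWORK: "dify",
    GEN_AI_MODEL_NAME: "example-model",
    GEN_AI_USAGE_INPUT_TOKENS: str(12),
    GEN_AI_USAGE_OUTPUT_TOKENS: str(34),
    GEN_AI_USAGE_TOTAL_TOKENS: str(46),
    INPUT_VALUE: json.dumps([{"role": "user", "text": "hi"}], ensure_ascii=False),
    OUTPUT_VALUE: "hello",
}
```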
"workflow_app_log_id": trace_info.workflow_app_log_id or "", "status": trace_info.workflow_run_status or "", @@ -156,7 +153,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } workflow_metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + external_trace_id = trace_info.metadata.get("external_trace_id") + trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -213,11 +211,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): if model: node_metadata["ls_model_name"] = model - usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {} - if usage: - node_metadata["total_tokens"] = usage.get("total_tokens", 0) - node_metadata["prompt_tokens"] = usage.get("prompt_tokens", 0) - node_metadata["completion_tokens"] = usage.get("completion_tokens", 0) + outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} + usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + if usage_data: + node_metadata["total_tokens"] = usage_data.get("total_tokens", 0) + node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) + node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) elif node_execution.node_type == "dataset_retrieval": span_kind = OpenInferenceSpanKindValues.RETRIEVER.value elif node_execution.node_type == "tool": @@ -235,26 +234,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(created_at), + context=trace.set_span_in_context(trace.NonRecordingSpan(context)), ) try: if node_execution.node_type == "llm": + llm_attributes: dict[str, Any] = { + SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), + } provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: - node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider) + llm_attributes[SpanAttributes.LLM_PROVIDER] = provider if model: - node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model) - - usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {} - if usage: - node_span.set_attribute(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens", 0)) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage.get("prompt_tokens", 0) - ) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage.get("completion_tokens", 0) + llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model + outputs = ( + json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} + ) + usage_data = ( + process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + ) + if usage_data: + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get( + "completion_tokens", 0 ) + llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) + node_span.set_attributes(llm_attributes) finally: node_span.end(end_time=datetime_to_nanos(finished_at)) finally: @@ -290,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available 
if trace_info.message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: message_metadata["end_user_id"] = end_user_data.session_id @@ -346,25 +353,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, } - - if isinstance(trace_info.inputs, list): - for i, msg in enumerate(trace_info.inputs): - if isinstance(msg, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get( - "role", "user" - ) - # todo: handle assistant and tool role messages, as they don't always - # have a text field, but may have a tool_calls field instead - # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', - # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} - elif isinstance(trace_info.inputs, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs) - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - elif isinstance(trace_info.inputs, str): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - + llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens if trace_info.message_tokens is not None and trace_info.message_tokens > 0: @@ -714,7 +703,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance): WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.execution_metadata, ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .all() ) return workflow_nodes + + def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: + """Helper method to construct LLM attributes with passed prompts.""" + attributes = {} + if isinstance(prompts, list): + for i, msg in enumerate(prompts): + if isinstance(msg, dict): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user") + # todo: handle assistant and tool role messages, as they don't always + # have a text field, but may have a tool_calls field instead + # e.g. 
'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', + # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} + elif isinstance(prompts, dict): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts) + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + elif isinstance(prompts, str): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + + return attributes diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index 8593198bc2..f8e428daf1 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -44,14 +44,14 @@ class BaseTraceInstance(ABC): """ with Session(db.engine, expire_on_commit=False) as session: # Get the app to find its creator - app = session.query(App).filter(App.id == app_id).first() + app = session.query(App).where(App.id == app_id).first() if not app: raise ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - service_account = session.query(Account).filter(Account.id == app.created_by).first() + service_account = session.query(Account).where(Account.id == app.created_by).first() if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 8a2ce58539..89ff0cfded 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -2,6 +2,8 @@ from enum import StrEnum from pydantic import BaseModel, ValidationInfo, field_validator +from core.ops.utils import validate_project_name, validate_url, validate_url_with_path + class TracingProviderEnum(StrEnum): ARIZE = "arize" @@ -10,14 +12,41 @@ class TracingProviderEnum(StrEnum): LANGSMITH = "langsmith" OPIK = "opik" WEAVE = "weave" + ALIYUN = "aliyun" class BaseTracingConfig(BaseModel): """ - Base model class for tracing + Base model class for tracing configurations """ - ... 
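The refactor above moves the prompt-flattening logic into a reusable `_construct_llm_attributes` helper. A standalone sketch of the same flattening follows; the literal `"llm.input_messages"` prefix stands in for `SpanAttributes.LLM_INPUT_MESSAGES` and is an assumption of this sketch.

```python
import json

# "llm.input_messages" stands in for SpanAttributes.LLM_INPUT_MESSAGES here.
PREFIX = "llm.input_messages"

def flatten_prompts(prompts) -> dict[str, str]:
    attributes: dict[str, str] = {}
    if isinstance(prompts, list):
        for i, msg in enumerate(prompts):
            if isinstance(msg, dict):
                attributes[f"{PREFIX}.{i}.message.content"] = msg.get("text", "")
                attributes[f"{PREFIX}.{i}.message.role"] = msg.get("role", "user")
    elif isinstance(prompts, dict):
        attributes[f"{PREFIX}.0.message.content"] = json.dumps(prompts)
        attributes[f"{PREFIX}.0.message.role"] = "user"
    elif isinstance(prompts, str):
        attributes[f"{PREFIX}.0.message.content"] = prompts
        attributes[f"{PREFIX}.0.message.role"] = "user"
    return attributes

# Example: a two-message chat prompt becomes four flat attributes.
print(flatten_prompts([
    {"role": "system", "text": "You are helpful."},
    {"role": "user", "text": "Hi"},
]))
```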
+ @classmethod + def validate_endpoint_url(cls, v: str, default_url: str) -> str: + """ + Common endpoint URL validation logic + + Args: + v: URL value to validate + default_url: Default URL to use if input is None or empty + + Returns: + Validated and normalized URL + """ + return validate_url(v, default_url) + + @classmethod + def validate_project_field(cls, v: str, default_name: str) -> str: + """ + Common project name validation logic + + Args: + v: Project name to validate + default_name: Default name to use if input is None or empty + + Returns: + Validated project name + """ + return validate_project_name(v, default_name) class ArizeConfig(BaseTracingConfig): @@ -33,23 +62,12 @@ class ArizeConfig(BaseTracingConfig): @field_validator("project") @classmethod def project_validator(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "default" - - return v + return cls.validate_project_field(v, "default") @field_validator("endpoint") @classmethod def endpoint_validator(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "https://otlp.arize.com" - if not v.startswith(("https://", "http://")): - raise ValueError("endpoint must start with https:// or http://") - if "/" in v[8:]: - parts = v.split("/") - v = parts[0] + "//" + parts[2] - - return v + return cls.validate_endpoint_url(v, "https://otlp.arize.com") class PhoenixConfig(BaseTracingConfig): @@ -64,23 +82,12 @@ class PhoenixConfig(BaseTracingConfig): @field_validator("project") @classmethod def project_validator(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "default" - - return v + return cls.validate_project_field(v, "default") @field_validator("endpoint") @classmethod def endpoint_validator(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "https://app.phoenix.arize.com" - if not v.startswith(("https://", "http://")): - raise ValueError("endpoint must start with https:// or http://") - if "/" in v[8:]: - parts = v.split("/") - v = parts[0] + "//" + parts[2] - - return v + return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com") class LangfuseConfig(BaseTracingConfig): @@ -94,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig): @field_validator("host") @classmethod - def set_value(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "https://api.langfuse.com" - if not v.startswith("https://") and not v.startswith("http://"): - raise ValueError("host must start with https:// or http://") - - return v + def host_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://api.langfuse.com") class LangSmithConfig(BaseTracingConfig): @@ -114,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod - def set_value(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "https://api.smith.langchain.com" - if not v.startswith("https://"): - raise ValueError("endpoint must start with https://") - - return v + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) class OpikConfig(BaseTracingConfig): @@ -136,22 +134,12 @@ class OpikConfig(BaseTracingConfig): @field_validator("project") @classmethod def project_validator(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "Default Project" - - return v + return cls.validate_project_field(v, "Default Project") @field_validator("url") @classmethod def url_validator(cls, v, info: ValidationInfo): - if v is None or 
v == "": - v = "https://www.comet.com/opik/api/" - if not v.startswith(("https://", "http://")): - raise ValueError("url must start with https:// or http://") - if not v.endswith("/api/"): - raise ValueError("url should ends with /api/") - - return v + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") class WeaveConfig(BaseTracingConfig): @@ -167,22 +155,44 @@ class WeaveConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod - def set_value(cls, v, info: ValidationInfo): - if v is None or v == "": - v = "https://trace.wandb.ai" - if not v.startswith("https://"): - raise ValueError("endpoint must start with https://") - - return v + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) @field_validator("host") @classmethod - def validate_host(cls, v, info: ValidationInfo): - if v is not None and v != "": - if not v.startswith(("https://", "http://")): - raise ValueError("host must start with https:// or http://") + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) return v +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com") + + OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 1d4ae49fc7..f4a59ef3a7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -67,13 +67,14 @@ class LangFuseDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.workflow_run_id + external_trace_id = trace_info.metadata.get("external_trace_id") + trace_id = external_trace_id or trace_info.workflow_run_id user_id = trace_info.metadata.get("user_id") metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id if trace_info.message_id: - trace_id = trace_info.message_id + trace_id = external_trace_id or trace_info.message_id name = TraceTaskName.MESSAGE_TRACE.value trace_data = LangfuseTrace( id=trace_id, @@ -123,10 +124,10 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = 
SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -181,12 +182,9 @@ class LangFuseDataTrace(BaseTraceInstance): prompt_tokens = 0 completion_tokens = 0 try: - if outputs.get("usage"): - prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0) - else: - prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0) + usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + prompt_tokens = usage_data.get("prompt_tokens", 0) + completion_tokens = usage_data.get("completion_tokens", 0) except Exception: logger.error("Failed to extract usage", exc_info=True) @@ -246,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: user_id = end_user_data.session_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 8a392940db..c97846dc9b 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -65,7 +65,8 @@ class LangSmithDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.message_id or trace_info.workflow_run_id + external_trace_id = trace_info.metadata.get("external_trace_id") + trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() message_dotted_order = ( @@ -145,10 +146,10 @@ class LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -206,12 +207,9 @@ class LangSmithDataTrace(BaseTraceInstance): prompt_tokens = 0 completion_tokens = 0 try: - if outputs.get("usage"): - prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0) - else: - prompt_tokens = process_data.get("usage", 
{}).get("prompt_tokens", 0) - completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0) + usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + prompt_tokens = usage_data.get("prompt_tokens", 0) + completion_tokens = usage_data.get("completion_tokens", 0) except Exception: logger.error("Failed to extract usage", exc_info=True) @@ -264,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index f4d2760ba5..6079b2faef 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -96,7 +96,8 @@ class OpikDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - dify_trace_id = trace_info.workflow_run_id + external_trace_id = trace_info.metadata.get("external_trace_id") + dify_trace_id = external_trace_id or trace_info.workflow_run_id opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) workflow_metadata = wrap_metadata( trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id @@ -104,7 +105,7 @@ class OpikDataTrace(BaseTraceInstance): root_span_id = None if trace_info.message_id: - dify_trace_id = trace_info.message_id + dify_trace_id = external_trace_id or trace_info.message_id opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) trace_data = { @@ -160,10 +161,10 @@ class OpikDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -222,10 +223,10 @@ class OpikDataTrace(BaseTraceInstance): ) try: - if outputs.get("usage"): - total_tokens = outputs["usage"].get("total_tokens", 0) - prompt_tokens = outputs["usage"].get("prompt_tokens", 0) - completion_tokens = outputs["usage"].get("completion_tokens", 0) + usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + total_tokens = usage_data.get("total_tokens", 0) + prompt_tokens = usage_data.get("prompt_tokens", 0) + completion_tokens = usage_data.get("completion_tokens", 0) except Exception: logger.error("Failed to extract usage", exc_info=True) @@ -241,7 +242,7 @@ class OpikDataTrace(BaseTraceInstance): "trace_id": opik_trace_id, "id": prepare_opik_uuid(created_at, node_execution_id), "parent_span_id": 
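The usage-extraction change repeated across the Langfuse, LangSmith, and Opik tracers above reads token usage from `process_data` first and only falls back to the node `outputs`. A small sketch of that precedence, with made-up payloads:

```python
def extract_usage(process_data: dict, outputs: dict) -> dict:
    # Prefer usage reported in process_data; fall back to the node outputs.
    return process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})

# Made-up payloads for illustration.
process_data = {"usage": {"prompt_tokens": 10, "completion_tokens": 5}}
outputs = {"usage": {"prompt_tokens": 99, "completion_tokens": 99}}

assert extract_usage(process_data, outputs)["prompt_tokens"] == 10   # process_data wins
assert extract_usage({}, outputs)["completion_tokens"] == 99         # fallback to outputs
assert extract_usage({}, {}) == {}                                   # neither present
```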
prepare_opik_uuid(trace_info.start_time, parent_span_id), - "name": node_type, + "name": node_name, "type": run_type, "start_time": created_at, "end_time": finished_at, @@ -283,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index d6d6b4a1d4..2b546b47cc 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -41,28 +41,6 @@ from tasks.ops_trace_task import process_trace_tasks class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): def __getitem__(self, provider: str) -> dict[str, Any]: match provider: - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig - - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig - - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -126,6 +104,17 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): "other_keys": ["project", "endpoint"], "trace_instance": ArizePhoenixDataTrace, } + case TracingProviderEnum.ALIYUN: + from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace + from core.ops.entities.config_entity import AliyunConfig + + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case _: raise KeyError(f"Unsupported tracing provider: {provider}") @@ -229,7 +218,7 @@ class OpsTraceManager: """ trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -237,7 +226,7 @@ class OpsTraceManager: return None # decrypt_token - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") @@ -264,7 +253,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).where(App.id == app_id).first() if app is None: return None @@ -304,18 +293,18 @@ class OpsTraceManager: @classmethod def get_app_config_through_message_id(cls, message_id: str): app_model_config = None - message_data = db.session.query(Message).filter(Message.id == message_id).first() + message_data = db.session.query(Message).where(Message.id == message_id).first() if not message_data: return None conversation_id 
= message_data.conversation_id - conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() if not conversation_data: return None if conversation_data.app_model_config_id: app_model_config = ( db.session.query(AppModelConfig) - .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .where(AppModelConfig.id == conversation_data.app_model_config_id) .first() ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: @@ -342,7 +331,7 @@ class OpsTraceManager: if tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -360,7 +349,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -531,6 +520,10 @@ class TraceTask: "app_id": workflow_run.app_id, } + external_trace_id = self.kwargs.get("external_trace_id") + if external_trace_id: + metadata["external_trace_id"] = external_trace_id + workflow_trace_info = WorkflowTraceInfo( workflow_data=workflow_run.to_dict(), conversation_id=conversation_id, diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 8b06df1930..573e8cac88 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -1,6 +1,9 @@ from contextlib import contextmanager from datetime import datetime from typing import Optional, Union +from urllib.parse import urlparse + +from sqlalchemy import select from extensions.ext_database import db from models.model import Message @@ -19,7 +22,7 @@ def filter_none_values(data: dict): def get_message_data(message_id: str): - return db.session.query(Message).filter(Message.id == message_id).first() + return db.session.scalar(select(Message).where(Message.id == message_id)) @contextmanager @@ -60,3 +63,83 @@ def generate_dotted_order( return current_segment return f"{parent_dotted_order}.{current_segment}" + + +def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str: + """ + Validate and normalize URL with proper error handling + + Args: + url: The URL to validate + default_url: Default URL to use if input is None or empty + allowed_schemes: Tuple of allowed URL schemes (default: https, http) + + Returns: + Normalized URL string + + Raises: + ValueError: If URL format is invalid or scheme not allowed + """ + if not url or url.strip() == "": + return default_url + + # Parse URL to validate format + parsed = urlparse(url) + + # Check if scheme is allowed + if parsed.scheme not in allowed_schemes: + raise ValueError(f"URL scheme must be one of: {', '.join(allowed_schemes)}") + + # Reconstruct URL with only scheme, netloc (removing path, query, fragment) + normalized_url = f"{parsed.scheme}://{parsed.netloc}" + + return normalized_url + + +def validate_url_with_path(url: str, default_url: str, required_suffix: str | None = None) -> str: + """ + Validate URL that may include path components + + Args: + url: The URL to validate + default_url: Default URL to use if input is None or 
empty + required_suffix: Optional suffix that URL must end with + + Returns: + Validated URL string + + Raises: + ValueError: If URL format is invalid or doesn't match required suffix + """ + if not url or url.strip() == "": + return default_url + + # Parse URL to validate format + parsed = urlparse(url) + + # Check if scheme is allowed + if parsed.scheme not in ("https", "http"): + raise ValueError("URL must start with https:// or http://") + + # Check required suffix if specified + if required_suffix and not url.endswith(required_suffix): + raise ValueError(f"URL should end with {required_suffix}") + + return url + + +def validate_project_name(project: str, default_name: str) -> str: + """ + Validate and normalize project name + + Args: + project: Project name to validate + default_name: Default name to use if input is None or empty + + Returns: + Normalized project name + """ + if not project or project.strip() == "": + return default_name + + return project.strip() diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 3917348a91..a34b3b780c 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -87,7 +87,8 @@ class WeaveDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.message_id or trace_info.workflow_run_id + external_trace_id = trace_info.metadata.get("external_trace_id") + trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() @@ -144,10 +145,10 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -234,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 4e43561a15..e8c9bed099 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get the user by user id """ - user = db.session.query(EndUser).filter(EndUser.id == user_id).first() + user = db.session.query(EndUser).where(EndUser.id == user_id).first() if not user: - user = db.session.query(Account).filter(Account.id 
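The new helpers in `api/core/ops/utils.py` centralize the endpoint and project-name checks that each config class previously re-implemented. A sketch of their expected behavior, drawn from the function bodies above, with invented inputs:

```python
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path

# Empty input falls back to the provider default.
assert validate_url("", "https://otlp.arize.com") == "https://otlp.arize.com"

# Path, query, and fragment are stripped down to scheme://netloc.
assert validate_url("https://otlp.arize.com/v1/traces", "https://otlp.arize.com") == "https://otlp.arize.com"

# Schemes outside the allow-list raise ValueError (e.g. LangSmith and Weave are HTTPS-only).
try:
    validate_url("http://api.smith.langchain.com", "https://api.smith.langchain.com", allowed_schemes=("https",))
except ValueError as e:
    print(e)

# Opik-style URLs keep their path but must end with the required suffix.
assert validate_url_with_path(
    "https://www.comet.com/opik/api/", "https://www.comet.com/opik/api/", required_suffix="/api/"
) == "https://www.comet.com/opik/api/"

# Project names are trimmed, and blank input falls back to the default.
assert validate_project_name("  my-project  ", "default") == "my-project"
assert validate_project_name("", "default") == "default"
```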
== user_id).first() + user = db.session.query(Account).where(Account.id == user_id).first() if not user: raise ValueError("user not found") @@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first() + app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() except Exception: raise ValueError("app not found") diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 81a5d033a0..213f5c726a 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,16 +1,20 @@ +from core.helper.provider_cache import SingletonProviderCredentialsCache from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter = ProviderConfigEncrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant.id, config=payload.config, - provider_type=payload.namespace, - provider_identity=payload.identity, + cache=SingletonProviderCredentialsCache( + tenant_id=tenant.id, + provider_type=payload.namespace, + provider_identity=payload.identity, + ), ) if payload.opt == "encrypt": @@ -22,7 +26,7 @@ class PluginEncrypter: "data": encrypter.decrypt(payload.data), } elif payload.opt == "clear": - encrypter.delete_tool_credentials_cache() + cache.delete() return { "data": {}, } diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 1d62743f13..06773504d9 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, Optional from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool @@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters + tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index a19a44aa3c..1c13a621d4 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -32,6 +32,13 @@ class MarketplacePluginDeclaration(BaseModel): latest_package_identifier: str = Field( ..., description="Unique identifier for the latest package release of the plugin" ) + status: str = Field(..., description="Indicate the status of marketplace plugin, enum from `active` `deleted`") + deprecated_reason: str = Field( + ..., description="Not empty when 
status='deleted'; indicates why this plugin was deleted (deprecated)" + ) + alternative_plugin_id: str = Field( + ..., description="Optional; indicates an alternative plugin for the user to switch to" + ) @model_validator(mode="before") @classmethod diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 2b438a3c33..47290ee613 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.entities.parameter_entities import CommonParameterType from core.tools.entities.common_entities import I18nObject +from core.workflow.nodes.base.entities import NumberType class PluginParameterOption(BaseModel): @@ -38,11 +39,25 @@ class PluginParameterType(enum.StrEnum): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value # deprecated, should not use. SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + # MCP object and array type parameters + ARRAY = CommonParameterType.ARRAY.value + OBJECT = CommonParameterType.OBJECT.value + + +class MCPServerParameterType(enum.StrEnum): + """ + Complex parameter types (array and object) used by MCP servers + """ + + ARRAY = "array" + OBJECT = "object" + class PluginParameterAutoGenerate(BaseModel): class Type(enum.StrEnum): @@ -138,6 +153,38 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): if value and not isinstance(value, list): raise ValueError("The tools selector must be a list.") return value + case PluginParameterType.ANY: + if value and not isinstance(value, str | dict | list | NumberType): + raise ValueError("The var selector must be a string, dictionary, list or number.") + return value + case PluginParameterType.ARRAY: + if not isinstance(value, list): + # Try to parse JSON string for arrays + if isinstance(value, str): + try: + import json + + parsed_value = json.loads(value) + if isinstance(parsed_value, list): + return parsed_value + except (json.JSONDecodeError, ValueError): + pass + return [value] + return value + case PluginParameterType.OBJECT: + if not isinstance(value, dict): + # Try to parse JSON string for objects + if isinstance(value, str): + try: + import json + + parsed_value = json.loads(value) + if isinstance(parsed_value, dict): + return parsed_value + except (json.JSONDecodeError, ValueError): + pass + return {} + return value case _: return str(value) except ValueError: diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index bdf7d5ce1f..a07b58d9ea 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -72,12 +72,14 @@ class PluginDeclaration(BaseModel): class Meta(BaseModel): minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") + version: Optional[str] = Field(default=None) version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") description: I18nObject icon: str + icon_dark: Optional[str] = Field(default=None) label: I18nObject category: PluginCategory created_at: datetime.datetime @@ -133,17 +135,6 @@ class PluginEntity(PluginInstallation): return self -class GithubPackage(BaseModel): - repo: str - version: 
str - package: str - - -class GithubVersion(BaseModel): - repo: str - version: str - - class GenericProviderID: organization: str plugin_name: str diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 592b42c0da..16ab661092 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -53,6 +53,7 @@ class PluginAgentProviderEntity(BaseModel): plugin_unique_identifier: str plugin_id: str declaration: AgentProviderEntityWithPlugin + meta: PluginDeclaration.Meta class PluginBasicBooleanResponse(BaseModel): @@ -181,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel): + metadata: Mapping[str, Any] = Field( + default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc." + ) + expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index f9c81ed4d5..3a783dad3e 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -27,15 +27,30 @@ from core.workflow.nodes.question_classifier.entities import ( ) +class InvokeCredentials(BaseModel): + tool_credentials: dict[str, str] = Field( + default_factory=dict, + description="Map of tool provider to credential id, used to store the credential id for the tool provider.", + ) + + +class PluginInvokeContext(BaseModel): + credentials: Optional[InvokeCredentials] = Field( + default_factory=InvokeCredentials, + description="Credentials context for the plugin invocation or backward invocation.", + ) + + class RequestInvokeTool(BaseModel): """ Request to invoke a tool """ - tool_type: Literal["builtin", "workflow", "api"] + tool_type: Literal["builtin", "workflow", "api", "mcp"] provider: str tool: str tool_parameters: dict + credential_id: Optional[str] = None class BaseRequestInvokeModel(BaseModel): diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 66b77c7489..9575c57ac8 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginAgentProviderEntity, ) +from core.plugin.entities.request import PluginInvokeContext from core.plugin.impl.base import BasePluginClient @@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + context: Optional[PluginInvokeContext] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. 
@@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient): "conversation_id": conversation_id, "app_id": app_id, "message_id": message_id, + "context": context.model_dump() if context else {}, "data": { "agent_strategy_provider": agent_provider_id.provider_name, "agent_strategy": agent_strategy, diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index b006bf1d4b..7f022992ff 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", - PluginOAuthAuthorizationUrlResponse, - data={ - "user_id": user_id, - "data": { - "provider": provider, - "system_credentials": system_credentials, + try: + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", + PluginOAuthAuthorizationUrlResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + }, }, - }, - headers={ - "X-Plugin-ID": plugin_id, - "Content-Type": "application/json", - }, - ) - for resp in response: - return resp - raise ValueError("No response received from plugin daemon for authorization URL request.") + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + except Exception as e: + raise ValueError(f"Error getting authorization URL: {e}") def get_credentials( self, @@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], request: Request, ) -> PluginOAuthCredentialsResponse: @@ -50,30 +56,68 @@ class OAuthHandler(BasePluginClient): Get credentials from the given request. 
""" - # encode request to raw http request - raw_request_bytes = self._convert_request_to_raw_data(request) - - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/oauth/get_credentials", - PluginOAuthCredentialsResponse, - data={ - "user_id": user_id, - "data": { - "provider": provider, - "system_credentials": system_credentials, - # for json serialization - "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), + try: + # encode request to raw http request + raw_request_bytes = self._convert_request_to_raw_data(request) + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/get_credentials", + PluginOAuthCredentialsResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), + }, }, - }, - headers={ - "X-Plugin-ID": plugin_id, - "Content-Type": "application/json", - }, - ) - for resp in response: - return resp - raise ValueError("No response received from plugin daemon for authorization URL request.") + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + except Exception as e: + raise ValueError(f"Error getting credentials: {e}") + + def refresh_credentials( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + redirect_uri: str, + system_credentials: Mapping[str, Any], + credentials: Mapping[str, Any], + ) -> PluginOAuthCredentialsResponse: + try: + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials", + PluginOAuthCredentialsResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for refresh credentials request.") + except Exception as e: + raise ValueError(f"Error refreshing credentials: {e}") def _convert_request_to_raw_data(self, request: Request) -> bytes: """ diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index b7f7b31655..04ac8c9649 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": 1, "page_size": 256}, + params={"page": 1, "page_size": 256, "response_type": "paged"}, ) return result.list @@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": page, "page_size": page_size}, + params={"page": page, "page_size": page_size, "response_type": "paged"}, ) def upload_pkg( diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..04225f95ee 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import 
PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter class PluginToolManager(BasePluginClient): @@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], + credential_type: CredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, @@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient): "provider": tool_provider_id.provider_name, "tool": tool_name, "credentials": credentials, + "credential_type": credential_type, "tool_parameters": tool_parameters, }, }, diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 25964ae063..0f0fe65f27 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform): if prompt_item.edition_type == "basic" or not prompt_item.edition_type: if self.with_variable_tmpl: - vp = VariablePool() + vp = VariablePool.empty() for k, v in inputs.items(): if k.startswith("#"): vp.add(k[1:-1].split("."), v) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 47808928f7..e19c6419ca 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" - @classmethod - def value_of(cls, value: str) -> "ModelMode": - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - prompt_file_contents: dict[str, Any] = {} @@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index f7aef76c87..4b883622a7 100644 --- a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,10 +1,11 @@ -from typing import Any +from collections.abc import Sequence from constants import UUID_NIL +from models import Message -def extract_thread_messages(messages: list[Any]): - thread_messages = [] +def extract_thread_messages(messages: Sequence[Message]): + thread_messages: list[Message] = [] next_message = None for message in messages: diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py index f49466db6d..de64c27a73 100644 --- a/api/core/prompt/utils/get_thread_messages_length.py +++ b/api/core/prompt/utils/get_thread_messages_length.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from models.model import Message @@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int: Get the number of thread messages based on the parent message id. 
""" # Fetch all messages related to the conversation - query = ( - db.session.query( - Message.id, - Message.parent_message_id, - Message.answer, - ) - .filter( - Message.conversation_id == conversation_id, - ) - .order_by(Message.created_at.desc()) - ) + stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc()) - messages = query.all() + messages = db.session.scalars(stmt).all() # Extract thread messages thread_messages = extract_thread_messages(messages) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 488a394679..6de4f3a303 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -275,7 +275,7 @@ class ProviderManager: # Get the corresponding TenantDefaultModel record default_model = ( db.session.query(TenantDefaultModel) - .filter( + .where( TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) @@ -367,7 +367,7 @@ class ProviderManager: # Get the list of available models from get_configurations and check if it is LLM default_model = ( db.session.query(TenantDefaultModel) - .filter( + .where( TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) @@ -541,7 +541,7 @@ class ProviderManager: db.session.rollback() existed_provider_record = ( db.session.query(Provider) - .filter( + .where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_type == ProviderType.SYSTEM.value, diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py deleted file mode 100644 index 167a919e69..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.core import clean_extra_whitespace - - # Returns "ITEM 1A: RISK FACTORS" - return clean_extra_whitespace(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py deleted file mode 100644 index 9c682d29db..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - import re - - from unstructured.cleaners.core import group_broken_paragraphs - - para_split_re = re.compile(r"(\s*\n\s*){3}") - - return group_broken_paragraphs(content, paragraph_split=para_split_re) diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py deleted file mode 100644 index 0cdbb171e1..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import 
BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.core import clean_non_ascii_chars - - # Returns "This text contains non-ascii characters!" - return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py deleted file mode 100644 index 9f42044a2d..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """Replaces unicode quote characters, such as the \x91 character in a string.""" - - from unstructured.cleaners.core import replace_unicode_quotes - - return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py deleted file mode 100644 index 32ae7217e8..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.translate import translate_text - - return translate_text(content) diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index d6d0bd88b2..ec3a23bd96 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -93,11 +93,11 @@ class Jieba(BaseKeyword): documents = [] for chunk_index in sorted_chunk_indices: - segment_query = db.session.query(DocumentSegment).filter( + segment_query = db.session.query(DocumentSegment).where( DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index ) if document_ids_filter: - segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) + segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter)) segment = segment_query.first() if segment: @@ -214,7 +214,7 @@ class Jieba(BaseKeyword): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): document_segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) .first() ) if document_segment: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2c5178241c..e872a4e375 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app -from sqlalchemy.orm import load_only +from sqlalchemy.orm import Session, load_only from configs import dify_config from core.rag.data_post_processor.data_post_processor import DataPostProcessor @@ -127,7 +127,7 @@ class 
RetrievalService: external_retrieval_model: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None, ): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: return [] metadata_condition = ( @@ -144,7 +144,8 @@ class RetrievalService: @classmethod def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: - return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + return session.query(Dataset).where(Dataset.id == dataset_id).first() @classmethod def keyword_search( @@ -293,7 +294,7 @@ class RetrievalService: dataset_documents = { doc.id: doc for doc in db.session.query(DatasetDocument) - .filter(DatasetDocument.id.in_(document_ids)) + .where(DatasetDocument.id.in_(document_ids)) .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) .all() } @@ -317,7 +318,7 @@ class RetrievalService: child_index_node_id = document.metadata.get("doc_id") child_chunk = ( - db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first() + db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() ) if not child_chunk: @@ -325,7 +326,7 @@ class RetrievalService: segment = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.enabled == True, DocumentSegment.status == "completed", @@ -380,7 +381,7 @@ class RetrievalService: segment = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.enabled == True, DocumentSegment.status == "completed", diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 095752ea8e..6f3e15d166 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI: vector=query_vector, content=None, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] @@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( 
dbinstance_id=self.config.instance_id, @@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI: vector=None, content=query, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 44cc5d3e98..ad39717183 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = {"match": {Field.CONTENT_KEY.value: query}} + query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}} document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: - query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore + query_str = { + "bool": { + "must": {"match": {Field.CONTENT_KEY.value: query}}, + "filter": {"terms": {"metadata.document_id": document_ids_filter}}, + } + } + results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 46aefef11d..b0f0eeca38 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -6,7 +6,7 @@ from uuid import UUID, uuid4 from numpy import ndarray from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator -from sqlalchemy import Float, String, create_engine, insert, select, text +from sqlalchemy import Float, create_engine, insert, select, text from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column @@ -67,7 +67,7 @@ class PGVectoRS(BaseVector): postgresql.UUID(as_uuid=True), primary_key=True, ) - text: Mapped[str] = mapped_column(String) + text: Mapped[str] meta: Mapped[dict] = mapped_column(postgresql.JSONB) vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 8ce194c683..dfb95a1839 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -47,6 +47,7 @@ class QdrantConfig(BaseModel): grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 + write_consistency_factor: int = 1 def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): @@ -127,6 +128,7 @@ class QdrantVector(BaseVector): hnsw_config=hnsw_config, timeout=int(self._client_config.timeout), replication_factor=self._client_config.replication_factor, + write_consistency_factor=self._client_config.write_consistency_factor, ) # create group_id payload index @@ -441,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory): if dataset.collection_binding_id: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .where(DatasetCollectionBinding.id == dataset.collection_binding_id) .one_or_none() ) if dataset_collection_binding: diff --git 
a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index a124faa503..9ed6e7369b 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -4,6 +4,7 @@ from typing import Any, Optional import tablestore # type: ignore from pydantic import BaseModel, model_validator +from tablestore import BatchGetRowRequest, TableInBatchGetRowItem from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -50,6 +51,29 @@ class TableStoreVector(BaseVector): self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" + def create_collection(self, embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + + def get_by_ids(self, ids: list[str]) -> list[Document]: + docs = [] + request = BatchGetRowRequest() + columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] + rows_to_get = [[("id", _id)] for _id in ids] + request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1)) + + result = self._tablestore_client.batch_get_row(request) + table_result = result.get_result_by_table(self._table_name) + for item in table_result: + if item.is_ok and item.row: + kv = {k: v for k, v, t in item.row.attribute_columns} + docs.append( + Document( + page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) + ) + ) + return docs + def get_type(self) -> str: return VectorType.TABLESTORE @@ -94,10 +118,21 @@ class TableStoreVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) - return self._search_by_vector(query_vector, top_k) + document_ids_filter = kwargs.get("document_ids_filter") + filtered_list = None + if document_ids_filter: + filtered_list = ["document_id=" + item for item in document_ids_filter] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - return self._search_by_full_text(query) + top_k = kwargs.get("top_k", 4) + document_ids_filter = kwargs.get("document_ids_filter") + filtered_list = None + if document_ids_filter: + filtered_list = ["document_id=" + item for item in document_ids_filter] + + return self._search_by_full_text(query, filtered_list, top_k) def delete(self) -> None: self._delete_table_if_exist() @@ -206,32 +241,51 @@ class TableStoreVector(BaseVector): primary_key = [("id", id)] row = tablestore.Row(primary_key) self._tablestore_client.delete_row(self._table_name, row, None) - logging.info("Tablestore delete row successfully. 
id:%s", id) def _search_by_metadata(self, key: str, value: str) -> list[str]: query = tablestore.SearchQuery( tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)), - limit=100, + limit=1000, get_total_count=False, ) + rows: list[str] = [] + next_token = None + while True: + if next_token is not None: + query.next_token = next_token - search_response = self._tablestore_client.search( - table_name=self._table_name, - index_name=self._index_name, - search_query=query, - columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), - ) + search_response = self._tablestore_client.search( + table_name=self._table_name, + index_name=self._index_name, + search_query=query, + columns_to_get=tablestore.ColumnsToGet( + column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED + ), + ) - return [row[0][0][1] for row in search_response.rows] + if search_response is not None: + rows.extend([row[0][0][1] for row in search_response.rows]) - def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]: - ots_query = tablestore.KnnVectorQuery( + if search_response is None or search_response.next_token == b"": + break + else: + next_token = search_response.next_token + + return rows + + def _search_by_vector( + self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float + ) -> list[Document]: + knn_vector_query = tablestore.KnnVectorQuery( field_name=Field.VECTOR.value, top_k=top_k, float32_query_vector=query_vector, ) + if document_ids_filter: + knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter) + sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]) - search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort) + search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort) search_response = self._tablestore_client.search( table_name=self._table_name, @@ -239,30 +293,42 @@ class TableStoreVector(BaseVector): search_query=search_query, columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), ) - logging.info( - "Tablestore search successfully. 
request_id:%s", - search_response.request_id, - ) - return self._to_query_result(search_response) - - def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]: documents = [] - for row in search_response.rows: - documents.append( - Document( - page_content=row[1][2][1], - vector=json.loads(row[1][3][1]), - metadata=json.loads(row[1][0][1]), - ) - ) + for search_hit in search_response.search_hits: + if search_hit.score > score_threshold: + ots_column_map = {} + for col in search_hit.row[1]: + ots_column_map[col[0]] = col[1] + vector_str = ots_column_map.get(Field.VECTOR.value) + metadata_str = ots_column_map.get(Field.METADATA_KEY.value) + + vector = json.loads(vector_str) if vector_str else None + metadata = json.loads(metadata_str) if metadata_str else {} + + metadata["score"] = search_hit.score + + documents.append( + Document( + page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", + vector=vector, + metadata=metadata, + ) + ) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def _search_by_full_text(self, query: str) -> list[Document]: + def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: + bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) + bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) + + if document_ids_filter: + bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter)) + search_query = tablestore.SearchQuery( - query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value), + query=bool_query, sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]), - limit=100, + limit=top_k, ) search_response = self._tablestore_client.search( table_name=self._table_name, @@ -271,7 +337,25 @@ class TableStoreVector(BaseVector): columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), ) - return self._to_query_result(search_response) + documents = [] + for search_hit in search_response.search_hits: + ots_column_map = {} + for col in search_hit.row[1]: + ots_column_map[col[0]] = col[1] + + vector_str = ots_column_map.get(Field.VECTOR.value) + metadata_str = ots_column_map.get(Field.METADATA_KEY.value) + vector = json.loads(vector_str) if vector_str else None + metadata = json.loads(metadata_str) if metadata_str else {} + + documents.append( + Document( + page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", + vector=vector, + metadata=metadata, + ) + ) + return documents class TableStoreVectorFactory(AbstractVectorFactory): diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index d2bf3eb92a..23ed8a3344 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -122,7 +122,6 @@ class TencentVector(BaseVector): metric_type, params, ) - index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER) index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER) index_sparse_vector = vdb_index.SparseIndex( name="sparse_vector", @@ -130,7 +129,7 @@ class TencentVector(BaseVector): index_type=enum.IndexType.SPARSE_INVERTED, metric_type=enum.MetricType.IP, ) - indexes = 
[index_id, index_vector, index_text, index_metadate] + indexes = [index_id, index_vector, index_metadate] if self._enable_hybrid_search: indexes.append(index_sparse_vector) try: @@ -149,7 +148,7 @@ class TencentVector(BaseVector): index_metadate = vdb_index.FilterIndex( self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER ) - indexes = [index_id, index_vector, index_text, index_metadate] + indexes = [index_id, index_vector, index_metadate] if self._enable_hybrid_search: indexes.append(index_sparse_vector) self._client.create_collection( @@ -207,9 +206,19 @@ class TencentVector(BaseVector): def delete_by_ids(self, ids: list[str]) -> None: if not ids: return - self._client.delete( - database_name=self._client_config.database, collection_name=self.collection_name, document_ids=ids - ) + + total_count = len(ids) + batch_size = self._client_config.max_upsert_batch_size + batch = math.ceil(total_count / batch_size) + + for j in range(batch): + start_idx = j * batch_size + end_idx = min(total_count, (j + 1) * batch_size) + batch_ids = ids[start_idx:end_idx] + + self._client.delete( + database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids + ) def delete_by_metadata_field(self, key: str, value: str) -> None: self._client.delete( @@ -275,7 +284,8 @@ class TencentVector(BaseVector): # Compatible with version 1.1.3 and below. meta = json.loads(meta) score = 1 - result.get("score", 0.0) - score = result.get("score", 0.0) + else: + score = result.get("score", 0.0) if score > score_threshold: meta["score"] = score doc = Document(page_content=result.get(self.field_text), metadata=meta) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py deleted file mode 100644 index 1e62b3c589..0000000000 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class ClusterEntity(BaseModel): - """ - Model Config Entity. 
- """ - - name: str - cluster_id: str - displayName: str - region: str - spendingLimit: Optional[int] = 1000 - version: str - createdBy: str diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 6f895b12af..ba6a9654f0 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: tidb_auth_binding = ( - db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() + db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() ) if not tidb_auth_binding: with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): tidb_auth_binding = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .where(TidbAuthBinding.tenant_id == dataset.tenant_id) .one_or_none() ) if tidb_auth_binding: @@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): else: idle_tidb_auth_binding = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") + .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .limit(1) .one_or_none() ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 67a4a515b1..e018f7d3d4 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,3 +1,5 @@ +import logging +import time from abc import ABC, abstractmethod from typing import Any, Optional @@ -13,6 +15,8 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, Whitelist +logger = logging.getLogger(__name__) + class AbstractVectorFactory(ABC): @abstractmethod @@ -43,7 +47,7 @@ class Vector: if dify_config.VECTOR_STORE_WHITELIST_ENABLE: whitelist = ( db.session.query(Whitelist) - .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") + .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") .one_or_none() ) if whitelist: @@ -173,8 +177,20 @@ class Vector: def create(self, texts: Optional[list] = None, **kwargs): if texts: - embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) - self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs) + start = time.time() + logger.info(f"start embedding {len(texts)} texts {start}") + batch_size = 1000 + total_batches = len(texts) + batch_size - 1 + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + batch_start = time.time() + logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)") + batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch]) + logger.info( + f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s" + ) + self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) + logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s") def add_texts(self, documents: 
list[Document], **kwargs): if kwargs.get("duplicate_check", False): diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 398b0daad9..f844770a20 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -42,7 +42,7 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: document_segments = ( - db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() ) output = {} @@ -63,7 +63,7 @@ class DatasetDocumentStore: def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == self._document_id) + .where(DocumentSegment.document_id == self._document_id) .scalar() ) @@ -147,7 +147,7 @@ class DatasetDocumentStore: segment_document.tokens = tokens if save_child and doc.children: # delete the existing child chunks - db.session.query(ChildChunk).filter( + db.session.query(ChildChunk).where( ChildChunk.tenant_id == self._dataset.tenant_id, ChildChunk.dataset_id == self._dataset.id, ChildChunk.document_id == self._document_id, @@ -230,7 +230,7 @@ class DatasetDocumentStore: def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) .first() ) diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index e46ab8b7fd..01003a13b6 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -9,8 +9,7 @@ from __future__ import annotations import contextlib import mimetypes -from abc import ABC, abstractmethod -from collections.abc import Generator, Iterable, Mapping +from collections.abc import Generator, Mapping from io import BufferedReader, BytesIO from pathlib import Path, PurePath from typing import Any, Optional, Union @@ -143,21 +142,3 @@ class Blob(BaseModel): if self.source: str_repr += f" {self.source}" return str_repr - - -class BlobLoader(ABC): - """Abstract interface for blob loaders implementation. - - Implementer should be able to load raw content from a datasource system according - to some criteria and return the raw content lazily as a stream of blobs. - """ - - @abstractmethod - def yield_blobs( - self, - ) -> Iterable[Blob]: - """A lazy loader for raw data represented by Blob object. 
- - Returns: - A generator over blobs - """ diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index eca955ddd1..875626eb34 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -331,9 +331,10 @@ class NotionExtractor(BaseExtractor): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict data_source_info["last_edited_time"] = last_edited_time - update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} - db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params) + db.session.query(DocumentModel).filter_by(id=document_model.id).update( + {DocumentModel.data_source_info: json.dumps(data_source_info)} + ) # type: ignore db.session.commit() def get_notion_last_edited_time(self) -> str: @@ -365,7 +366,7 @@ class NotionExtractor(BaseExtractor): def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, DataSourceOauthBinding.provider == "notion", diff --git a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py deleted file mode 100644 index dd8a979e70..0000000000 --- a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging - -from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.models.document import Document - -logger = logging.getLogger(__name__) - - -class UnstructuredPDFExtractor(BaseExtractor): - """Load pdf files. - - - Args: - file_path: Path to the file to load. - - api_url: Unstructured API URL - - api_key: Unstructured API Key - """ - - def __init__(self, file_path: str, api_url: str, api_key: str): - """Initialize with file path.""" - self._file_path = file_path - self._api_url = api_url - self._api_key = api_key - - def extract(self) -> list[Document]: - if self._api_url: - from unstructured.partition.api import partition_via_api - - elements = partition_via_api( - filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto" - ) - else: - from unstructured.partition.pdf import partition_pdf - - elements = partition_pdf(filename=self._file_path, strategy="auto") - - from unstructured.chunking.title import chunk_by_title - - chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) - documents = [] - for chunk in chunks: - text = chunk.text.strip() - documents.append(Document(page_content=text)) - - return documents diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py deleted file mode 100644 index 22dfdd2075..0000000000 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.models.document import Document - -logger = logging.getLogger(__name__) - - -class UnstructuredTextExtractor(BaseExtractor): - """Load msg files. - - - Args: - file_path: Path to the file to load. 
- """ - - def __init__(self, file_path: str, api_url: str): - """Initialize with file path.""" - self._file_path = file_path - self._api_url = api_url - - def extract(self) -> list[Document]: - from unstructured.partition.text import partition_text - - elements = partition_text(filename=self._file_path) - from unstructured.chunking.title import chunk_by_title - - chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) - documents = [] - for chunk in chunks: - text = chunk.text.strip() - documents.append(Document(page_content=text)) - - return documents diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index bff0acc48f..14363de7d4 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor): paragraph_content = [] for run in paragraph.runs: if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): + # Process drawing type images drawing_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" ) + has_drawing = False for drawing in drawing_elements: blip_elements = drawing.findall( ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" @@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor): if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: + has_drawing = True + paragraph_content.append(image_map[image_part]) + # Process pict type images + shape_elements = run.element.findall( + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" + ) + for shape in shape_elements: + # Find image data in VML + shape_image = shape.find( + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData" + ) + if shape_image is not None and shape_image.text: + image_id = shape_image.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" + ) + if image_id and image_id in doc.part.rels: + image_part = doc.part.rels[image_id].target_part + if image_part in image_map and not has_drawing: + paragraph_content.append(image_map[image_part]) + # Find imagedata element in VML + image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") + if image_data is not None: + image_id = image_data.get("id") or image_data.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" + ) + if image_id and image_id in doc.part.rels: + image_part = doc.part.rels[image_id].target_part + if image_part in image_map and not has_drawing: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 1cde5e1c8f..52756fbacd 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_node_ids = ( db.session.query(ChildChunk.index_node_id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) - .filter( + .where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ChildChunk.dataset_id == dataset.id, @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_node_ids = [child_node_id[0] for 
child_node_id in child_node_ids] vector.delete_by_ids(child_node_ids) if delete_child_chunks: - db.session.query(ChildChunk).filter( + db.session.query(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) ).delete() db.session.commit() @@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): vector.delete() if delete_child_chunks: - db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() + db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete() db.session.commit() def retrieve( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3fca48be22..a25bc65646 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast from flask import Flask, current_app from sqlalchemy import Float, and_, or_, text from sqlalchemy import cast as sqlalchemy_cast +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -134,7 +135,7 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: @@ -241,7 +242,7 @@ class DatasetRetrieval: dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, DatasetDocument.archived == False, @@ -326,7 +327,7 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset: results = [] if dataset.provider == "external": @@ -515,14 +516,14 @@ class DatasetRetrieval: if document.metadata is not None: dataset_document = ( db.session.query(DatasetDocument) - .filter(DatasetDocument.id == document.metadata["document_id"]) + .where(DatasetDocument.id == document.metadata["document_id"]) .first() ) if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunk = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.document_id == dataset_document.id, @@ -532,7 +533,7 @@ class DatasetRetrieval: if child_chunk: segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == child_chunk.segment_id) + .where(DocumentSegment.id == child_chunk.segment_id) .update( {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False, @@ -540,13 +541,13 @@ class DatasetRetrieval: ) db.session.commit() else: - query = db.session.query(DocumentSegment).filter( + query = db.session.query(DocumentSegment).where( DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment 
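Note on the .filter() to .where() renames in the hunks above (and throughout the rest of this patch): Query.where() has been available as a synonym for Query.filter() since SQLAlchemy 1.4, mirroring the 2.0-style Select.where() API, so the change is purely a spelling migration. A minimal sketch of the equivalence, assuming session is an active SQLAlchemy Session and Dataset is the mapped model used above:

```python
# Identical queries; .where() is the 2.0-style spelling the patch migrates to.
legacy = session.query(Dataset).filter(Dataset.id == dataset_id).first()
modern = session.query(Dataset).where(Dataset.id == dataset_id).first()
```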
query.update( @@ -598,7 +599,8 @@ class DatasetRetrieval: metadata_condition: Optional[MetadataCondition] = None, ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: return [] @@ -683,7 +685,7 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: @@ -860,7 +862,7 @@ class DatasetRetrieval: metadata_filtering_conditions: Optional[MetadataFilteringCondition], inputs: dict, ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: - document_query = db.session.query(DatasetDocument).filter( + document_query = db.session.query(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -928,9 +930,9 @@ class DatasetRetrieval: raise ValueError("Invalid metadata filtering mode") if filters: if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore - document_query = document_query.filter(and_(*filters)) + document_query = document_query.where(and_(*filters)) else: - document_query = document_query.filter(or_(*filters)) + document_query = document_query.where(or_(*filters)) documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore @@ -956,7 +958,7 @@ class DatasetRetrieval: self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig ) -> Optional[list[dict[str, Any]]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] # get metadata model config if metadata_model_config is None: @@ -1135,7 +1137,7 @@ class DatasetRetrieval: def _get_prompt_template( self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str ): - model_mode = ModelMode.value_of(mode) + model_mode = ModelMode(mode) input_text = query prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 0fb1bcb2e0..bcaf299892 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = text.split() else: splits = text.split(separator) + splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] else: splits = list(text) splits = [s for s in splits if (s not in {"", "\n"})] diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index b711e8434a..529d8ccd27 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -10,7 +10,6 @@ from typing import ( 
Any, Literal, Optional, - TypedDict, TypeVar, Union, ) @@ -168,167 +167,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): raise NotImplementedError -class CharacterTextSplitter(TextSplitter): - """Splitting text that looks at characters.""" - - def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: - """Create a new TextSplitter.""" - super().__init__(**kwargs) - self._separator = separator - - def split_text(self, text: str) -> list[str]: - """Split incoming text and return chunks.""" - # First we naively split the large input into a bunch of smaller ones. - splits = _split_text_with_regex(text, self._separator, self._keep_separator) - _separator = "" if self._keep_separator else self._separator - _good_splits_lengths = [] # cache the lengths of the splits - if splits: - _good_splits_lengths.extend(self._length_function(splits)) - return self._merge_splits(splits, _separator, _good_splits_lengths) - - -class LineType(TypedDict): - """Line type as typed dict.""" - - metadata: dict[str, str] - content: str - - -class HeaderType(TypedDict): - """Header type as typed dict.""" - - level: int - name: str - data: str - - -class MarkdownHeaderTextSplitter: - """Splitting markdown files based on specified headers.""" - - def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): - """Create a new MarkdownHeaderTextSplitter. - - Args: - headers_to_split_on: Headers we want to track - return_each_line: Return each line w/ associated headers - """ - # Output line-by-line or aggregated into chunks w/ common headers - self.return_each_line = return_each_line - # Given the headers we want to split on, - # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) - - def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: - """Combine lines with common metadata into chunks - Args: - lines: Line of text / associated header metadata - """ - aggregated_chunks: list[LineType] = [] - - for line in lines: - if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: - # If the last line in the aggregated list - # has the same metadata as the current line, - # append the current content to the last lines's content - aggregated_chunks[-1]["content"] += " \n" + line["content"] - else: - # Otherwise, append the current line to the aggregated list - aggregated_chunks.append(line) - - return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] - - def split_text(self, text: str) -> list[Document]: - """Split markdown file - Args: - text: Markdown file""" - - # Split the input text by newline character ("\n"). 
- lines = text.split("\n") - # Final output - lines_with_metadata: list[LineType] = [] - # Content and metadata of the chunk currently being processed - current_content: list[str] = [] - current_metadata: dict[str, str] = {} - # Keep track of the nested header structure - # header_stack: List[Dict[str, Union[int, str]]] = [] - header_stack: list[HeaderType] = [] - initial_metadata: dict[str, str] = {} - - for line in lines: - stripped_line = line.strip() - # Check each line against each of the header types (e.g., #, ##) - for sep, name in self.headers_to_split_on: - # Check if line starts with a header that we intend to split on - if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " - ): - # Ensure we are tracking the header as metadata - if name is not None: - # Get the current header level - current_header_level = sep.count("#") - - # Pop out headers of lower or same level from the stack - while header_stack and header_stack[-1]["level"] >= current_header_level: - # We have encountered a new header - # at the same or higher level - popped_header = header_stack.pop() - # Clear the metadata for the - # popped header in initial_metadata - if popped_header["name"] in initial_metadata: - initial_metadata.pop(popped_header["name"]) - - # Push the current header to the stack - header: HeaderType = { - "level": current_header_level, - "name": name, - "data": stripped_line[len(sep) :].strip(), - } - header_stack.append(header) - # Update initial_metadata with the current header - initial_metadata[name] = header["data"] - - # Add the previous line to the lines_with_metadata - # only if current_content is not empty - if current_content: - lines_with_metadata.append( - { - "content": "\n".join(current_content), - "metadata": current_metadata.copy(), - } - ) - current_content.clear() - - break - else: - if stripped_line: - current_content.append(stripped_line) - elif current_content: - lines_with_metadata.append( - { - "content": "\n".join(current_content), - "metadata": current_metadata.copy(), - } - ) - current_content.clear() - - current_metadata = initial_metadata.copy() - - if current_content: - lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) - - # lines_with_metadata has each line with associated header metadata - # aggregate these into chunks based on common metadata - if not self.return_each_line: - return self.aggregate_lines_to_chunks(lines_with_metadata) - else: - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata - ] - - -# should be in newer Python versions (3.10+) # @dataclass(frozen=True, kw_only=True, slots=True) @dataclass(frozen=True) class Tokenizer: diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6452317120..052ba1c2cb 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. 
""" +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "DifyCoreRepositoryFactory", + "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py new file mode 100644 index 0000000000..4118aa61c7 --- /dev/null +++ b/api/core/repositories/factory.py @@ -0,0 +1,224 @@ +""" +Repository factory for dynamically creating repository instances based on configuration. + +This module provides a Django-like settings system for repository implementations, +allowing users to configure different repository backends through string paths. +""" + +import importlib +import inspect +import logging +from typing import Protocol, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class RepositoryImportError(Exception): + """Raised when a repository implementation cannot be imported or instantiated.""" + + pass + + +class DifyCoreRepositoryFactory: + """ + Factory for creating repository instances based on configuration. + + This factory supports Django-like settings where repository implementations + are specified as module paths (e.g., 'module.submodule.ClassName'). + """ + + @staticmethod + def _import_class(class_path: str) -> type: + """ + Import a class from a module path string. + + Args: + class_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class + + Raises: + RepositoryImportError: If the class cannot be imported + """ + try: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + repo_class = getattr(module, class_name) + assert isinstance(repo_class, type) + return repo_class + except (ValueError, ImportError, AttributeError) as e: + raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e + + @staticmethod + def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore + """ + Validate that a class implements the expected repository interface. 
+ + Args: + repository_class: The class to validate + expected_interface: The expected interface/protocol + + Raises: + RepositoryImportError: If the class doesn't implement the interface + """ + # Check if the class has all required methods from the protocol + required_methods = [ + method + for method in dir(expected_interface) + if not method.startswith("_") and callable(getattr(expected_interface, method, None)) + ] + + missing_methods = [] + for method_name in required_methods: + if not hasattr(repository_class, method_name): + missing_methods.append(method_name) + + if missing_methods: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' does not implement required methods " + f"{missing_methods} from interface '{expected_interface.__name__}'" + ) + + @staticmethod + def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: + """ + Validate that a repository class constructor accepts required parameters. + + Args: + repository_class: The class to validate + required_params: List of required parameter names + + Raises: + RepositoryImportError: If the constructor doesn't accept required parameters + """ + + try: + # MyPy may flag the line below with the following error: + # + # > Accessing "__init__" on an instance is unsound, since + # > instance.__init__ could be from an incompatible subclass. + # + # Despite this, we need to ensure that the constructor of `repository_class` + # has a compatible signature. + signature = inspect.signature(repository_class.__init__) # type: ignore[misc] + param_names = list(signature.parameters.keys()) + + # Remove 'self' parameter + if "self" in param_names: + param_names.remove("self") + + missing_params = [param for param in required_params if param not in param_names] + if missing_params: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " + f"{missing_params}. Expected parameters: {required_params}" + ) + except Exception as e: + raise RepositoryImportError( + f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" + ) from e + + @classmethod + def create_workflow_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + ) -> WorkflowExecutionRepository: + """ + Create a WorkflowExecutionRepository instance based on configuration. 
+ + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowExecutionRepository") + raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e + + @classmethod + def create_workflow_node_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + ) -> WorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index cdec92aee7..c579ff4028 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,7 +6,6 @@ import json import logging from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -17,6 +16,7 @@ from core.workflow.entities.workflow_execution import ( ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.helper import extract_tenant_id from models import ( Account, 
CreatorUserRole, @@ -67,7 +67,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): ) # Extract tenant_id from user - tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id @@ -205,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # Update the in-memory cache for faster subsequent lookups logger.debug(f"Updating cache for execution_id: {db_model.id}") self._execution_cache[db_model.id] = db_model - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - # First check the cache - if execution_id in self._execution_cache: - logger.debug(f"Cache hit for execution_id: {execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._execution_cache[execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowRun).where( - WorkflowRun.id == execution_id, - WorkflowRun.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowRun.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._execution_cache[execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 797cce9354..d4a31390f8 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ import logging from collections.abc import Sequence from typing import Optional, Union -from sqlalchemy import UnaryExpression, asc, delete, desc, select +from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -20,6 +20,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.helper import extract_tenant_id from models import ( Account, CreatorUserRole, @@ -70,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) ) # Extract tenant_id from user - tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id @@ -217,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") 
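The DifyCoreRepositoryFactory introduced above resolves repository implementations from dotted class paths held in configuration (CORE_WORKFLOW_EXECUTION_REPOSITORY / CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY), validates the protocol methods and constructor signature, and only then instantiates the class. A hedged usage sketch; current_user and app_id are placeholders, and the enum member is only an example value:

```python
from sqlalchemy.orm import sessionmaker

from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom

session_factory = sessionmaker(bind=db.engine)

# The concrete class comes from dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
# e.g. the bundled SQLAlchemyWorkflowExecutionRepository or a custom backend.
repo = DifyCoreRepositoryFactory.create_workflow_execution_repository(
    session_factory=session_factory,
    user=current_user,  # Account or EndUser (placeholder here)
    app_id=app_id,      # placeholder application id
    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,  # assumed example value
)
```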
self._node_execution_cache[db_model.node_execution_id] = db_model - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - # First check the cache - if node_execution_id in self._node_execution_cache: - logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._node_execution_cache[node_execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.node_execution_id == node_execution_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._node_execution_cache[node_execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None - def get_db_models_by_workflow_run( self, workflow_run_id: str, @@ -343,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) domain_models.append(domain_model) return domain_models - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - This method queries the database directly and updates the cache with any - retrieved executions that have a node_execution_id. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING, - WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_models = session.scalars(stmt).all() - domain_models = [] - - for model in db_models: - # Update cache if node_execution_id is present - if model.node_execution_id: - self._node_execution_cache[model.node_execution_id] = model - - # Convert to domain model - domain_model = self._to_domain_model(model) - domain_models.append(domain_model) - - return domain_models - - def clear(self) -> None: - """ - Clear all WorkflowNodeExecution records for the current tenant_id and app_id. - - This method deletes all WorkflowNodeExecution records that match the tenant_id - and app_id (if provided) associated with this repository instance. - It also clears the in-memory cache. 
- """ - with self._session_factory() as session: - stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - result = session.execute(stmt) - session.commit() - - deleted_count = result.rowcount - logger.info( - f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" - + (f" and app {self._app_id}" if self._app_id else "") - ) - - # Clear the in-memory cache - self._node_execution_cache.clear() - logger.info("Cleared in-memory node execution cache") diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index c9e157cb77..ddec7b1329 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom +from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom class ToolRuntime(BaseModel): @@ -17,6 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) + credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index cf75bd3d7e..a70ded9efd 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType +from core.tools.entities.tool_entities import ( + CredentialType, + OAuthSchema, + ToolEntity, + ToolProviderEntity, + ToolProviderType, +) from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, @@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController): credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credentials_schema.append(credential_dict) + oauth_schema = None + if provider_yaml.get("oauth_schema", None) is not None: + oauth_schema = OAuthSchema( + client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []), + credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []), + ) + super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], credentials_schema=credentials_schema, + oauth_schema=oauth_schema, ), ) @@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - if not self.entity.credentials_schema: - return [] + return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) - return self.entity.credentials_schema.copy() + def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + """ + returns the credentials schema of the provider + + :param credential_type: the type of the credential + :return: the credentials schema of the provider + """ + 
if credential_type == CredentialType.OAUTH2.value: + return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] + if credential_type == CredentialType.API_KEY.value: + return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + raise ValueError(f"Invalid credential type: {credential_type}") + + def get_oauth_client_schema(self) -> list[ProviderConfig]: + """ + returns the oauth client schema of the provider + + :return: the oauth client schema + """ + return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] + + def get_supported_credential_types(self) -> list[str]: + """ + returns the credential support type of the provider + """ + types = [] + if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0: + types.append(CredentialType.API_KEY.value) + if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0: + types.append(CredentialType.OAUTH2.value) + return types def get_tools(self) -> list[BuiltinTool]: """ @@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + return ( + self.entity.credentials_schema is not None + and len(self.entity.credentials_schema) != 0 + or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0) + ) @property def provider_type(self) -> ToolProviderType: diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 3137d32013..95fab6151a 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -39,19 +39,22 @@ class ApiToolProviderController(ToolProviderController): type=ProviderConfig.Type.SELECT, options=[ ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")), - ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")), + ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")), + ProviderConfig.Option( + value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数") + ), ], default="none", help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) ] - if auth_type == ApiProviderAuthType.API_KEY: + if auth_type == ApiProviderAuthType.API_KEY_HEADER: credentials_schema = [ *credentials_schema, ProviderConfig( name="api_key_header", required=False, - default="api_key", + default="Authorization", type=ProviderConfig.Type.TEXT_INPUT, help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), @@ -74,6 +77,25 @@ class ApiToolProviderController(ToolProviderController): ], ), ] + elif auth_type == ApiProviderAuthType.API_KEY_QUERY: + credentials_schema = [ + *credentials_schema, + ProviderConfig( + name="api_key_query_param", + required=False, + default="key", + type=ProviderConfig.Type.TEXT_INPUT, + help=I18nObject( + en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称" + ), + ), + ProviderConfig( + name="api_key_value", + required=True, + type=ProviderConfig.Type.SECRET_INPUT, + help=I18nObject(en_US="The api key", zh_Hans="api key 的值"), + ), + ] elif auth_type == ApiProviderAuthType.NONE: pass @@ -156,7 +178,7 @@ class ApiToolProviderController(ToolProviderController): # get tenant api providers 
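The new api_key_query auth type declared in the schema above carries the API key as a query-string parameter rather than a header; the ApiTool hunks below simply fold it into the request params at request time. A hedged sketch of the credential shape and the resulting parameter handling, with a placeholder secret:

```python
credentials = {
    "auth_type": "api_key_query",
    "api_key_query_param": "key",  # default per the ProviderConfig above
    "api_key_value": "sk-...",     # placeholder secret value
}

params: dict[str, str] = {}
if credentials.get("auth_type") == "api_key_query":
    # Mirrors the do_http_request handling shown below:
    # the key travels in the URL, e.g. GET /v1/search?key=sk-...
    params[credentials.get("api_key_query_param", "key")] = credentials["api_key_value"]
```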
db_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) + .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) .all() ) diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 2f5cc6d4c0..10653b9948 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -78,8 +78,8 @@ class ApiTool(Tool): if "auth_type" not in credentials: raise ToolProviderCredentialValidationError("Missing auth_type") - if credentials["auth_type"] == "api_key": - api_key_header = "api_key" + if credentials["auth_type"] in ("api_key_header", "api_key"): # backward compatibility: + api_key_header = "Authorization" if "api_key_header" in credentials: api_key_header = credentials["api_key_header"] @@ -100,6 +100,11 @@ class ApiTool(Tool): headers[api_key_header] = credentials["api_key_value"] + elif credentials["auth_type"] == "api_key_query": + # For query parameter authentication, we don't add anything to headers + # The query parameter will be added in do_http_request method + pass + needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required] for parameter in needed_parameters: if parameter.required and parameter.name not in parameters: @@ -154,6 +159,15 @@ class ApiTool(Tool): cookies = {} files = [] + # Add API key to query parameters if auth_type is api_key_query + if self.runtime and self.runtime.credentials: + credentials = self.runtime.credentials + if credentials.get("auth_type") == "api_key_query": + api_key_query_param = credentials.get("api_key_query_param", "key") + api_key_value = credentials.get("api_key_value") + if api_key_value: + params[api_key_query_param] = api_key_value + # check parameters for parameter in self.api_bundle.openapi.get("parameters", []): value = self.get_parameter_value(parameter, parameters) @@ -213,7 +227,8 @@ class ApiTool(Tool): elif "default" in property: body[name] = property["default"] else: - body[name] = None + # omit optional parameters that weren't provided, instead of setting them to None + pass break # replace path parameters diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b96c994cff..27ce96b90e 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,11 +1,12 @@ -from typing import Literal, Optional +from datetime import datetime +from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import CredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -18,7 +19,7 @@ class ToolApiEntity(BaseModel): output_schema: Optional[dict] = None -ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] +ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]] class ToolProviderApiEntity(BaseModel): @@ -27,6 +28,7 @@ class ToolProviderApiEntity(BaseModel): name: str # identifier description: I18nObject icon: str | dict + icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool") label: I18nObject # label type: 
ToolProviderType masked_credentials: Optional[dict] = None @@ -37,6 +39,10 @@ class ToolProviderApiEntity(BaseModel): plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") tools: list[ToolApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) + # MCP + server_url: Optional[str] = Field(default="", description="The server url of the tool") + updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) + server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool") @field_validator("tools", mode="before") @classmethod @@ -52,8 +58,13 @@ class ToolProviderApiEntity(BaseModel): for parameter in tool.get("parameters"): if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" + if parameter.get("input_schema") is None: + parameter.pop("input_schema", None) # ------------- - + optional_fields = self.optional_field("server_url", self.server_url) + if self.type == ToolProviderType.MCP.value: + optional_fields.update(self.optional_field("updated_at", self.updated_at)) + optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) return { "id": self.id, "author": self.author, @@ -62,6 +73,7 @@ class ToolProviderApiEntity(BaseModel): "plugin_unique_identifier": self.plugin_unique_identifier, "description": self.description.to_dict(), "icon": self.icon, + "icon_dark": self.icon_dark, "label": self.label.to_dict(), "type": self.type.value, "team_credentials": self.masked_credentials, @@ -69,4 +81,28 @@ class ToolProviderApiEntity(BaseModel): "allow_delete": self.allow_delete, "tools": tools, "labels": self.labels, + **optional_fields, } + + def optional_field(self, key: str, value: Any) -> dict: + """Return dict with key-value if value is truthy, empty dict otherwise.""" + return {key: value} if value else {} + + +class ToolProviderCredentialApiEntity(BaseModel): + id: str = Field(description="The unique id of the credential") + name: str = Field(description="The name of the credential") + provider: str = Field(description="The provider of the credential") + credential_type: CredentialType = Field(description="The type of the credential") + is_default: bool = Field( + default=False, description="Whether the credential is the default credential for the provider in the workspace" + ) + credentials: dict = Field(description="The credentials of the provider") + + +class ToolProviderCredentialInfoApiEntity(BaseModel): + supported_credential_types: list[str] = Field(description="The supported credential types of the provider") + is_oauth_custom_client_enabled: bool = Field( + default=False, description="Whether the OAuth custom client is enabled for the provider" + ) + credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index d2c28076ae..5377cbbb69 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_seriali from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( + MCPServerParameterType, PluginParameter, PluginParameterOption, PluginParameterType, @@ -15,6 +16,7 @@ from core.plugin.entities.parameters import ( cast_parameter_value, 
init_frontend_parameter, ) +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY @@ -49,6 +51,7 @@ class ToolProviderType(enum.StrEnum): API = "api" APP = "app" DATASET_RETRIEVAL = "dataset-retrieval" + MCP = "mcp" @classmethod def value_of(cls, value: str) -> "ToolProviderType": @@ -94,7 +97,8 @@ class ApiProviderAuthType(Enum): """ NONE = "none" - API_KEY = "api_key" + API_KEY_HEADER = "api_key_header" + API_KEY_QUERY = "api_key_query" @classmethod def value_of(cls, value: str) -> "ApiProviderAuthType": @@ -176,6 +180,10 @@ class ToolInvokeMessage(BaseModel): data: Mapping[str, Any] = Field(..., description="Detailed log data") metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + class RetrieverResourceMessage(BaseModel): + retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + class MessageType(Enum): TEXT = "text" IMAGE = "image" @@ -188,13 +196,22 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" + RETRIEVER_RESOURCES = "retriever_resources" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ message: ( - JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage ) meta: dict[str, Any] | None = None @@ -240,8 +257,13 @@ class ToolParameter(PluginParameter): FILES = PluginParameterType.FILES.value APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + ANY = PluginParameterType.ANY.value DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value + # MCP object and array type parameters + ARRAY = MCPServerParameterType.ARRAY.value + OBJECT = MCPServerParameterType.OBJECT.value + # deprecated, should not use. 
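The new RETRIEVER_RESOURCES message type above lets a tool surface citation metadata together with the retrieved context. A minimal construction sketch based on the pydantic models in this hunk; RetrievalSourceMetadata is defined elsewhere, so the resource list is left empty here:

```python
from core.tools.entities.tool_entities import ToolInvokeMessage

message = ToolInvokeMessage(
    type=ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES,
    message=ToolInvokeMessage.RetrieverResourceMessage(
        retriever_resources=[],  # list[RetrievalSourceMetadata]
        context="concatenated retrieved passages",
    ),
)
```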
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value @@ -260,6 +282,8 @@ class ToolParameter(PluginParameter): human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None + # MCP object and array type parameters use this field to store the schema + input_schema: Optional[dict] = None @classmethod def get_simple_instance( @@ -309,6 +333,7 @@ class ToolProviderIdentity(BaseModel): name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") + icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool") label: I18nObject = Field(..., description="The label of the tool") tags: Optional[list[ToolLabelEnum]] = Field( default=[], @@ -345,10 +370,18 @@ class ToolEntity(BaseModel): return v or [] +class OAuthSchema(BaseModel): + client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") + credentials_schema: list[ProviderConfig] = Field( + default_factory=list, description="The schema of the OAuth credentials" + ) + + class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -428,6 +461,7 @@ class ToolSelector(BaseModel): options: Optional[list[PluginParameterOption]] = None provider_id: str = Field(..., description="The id of the provider") + credential_id: Optional[str] = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") @@ -435,3 +469,36 @@ class ToolSelector(BaseModel): def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() + + +class CredentialType(enum.StrEnum): + API_KEY = "api-key" + OAUTH2 = "oauth2" + + def get_name(self): + if self == CredentialType.API_KEY: + return "API KEY" + elif self == CredentialType.OAUTH2: + return "AUTH" + else: + return self.value.replace("-", " ").upper() + + def is_editable(self): + return self == CredentialType.API_KEY + + def is_validate_allowed(self): + return self == CredentialType.API_KEY + + @classmethod + def values(cls): + return [item.value for item in cls] + + @classmethod + def of(cls, credential_type: str) -> "CredentialType": + type_name = credential_type.lower() + if type_name == "api-key": + return cls.API_KEY + elif type_name == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py new file mode 100644 index 0000000000..93f003effe --- /dev/null +++ b/api/core/tools/mcp_tool/provider.py @@ -0,0 +1,130 @@ +import json +from typing import Any + +from core.mcp.types import Tool as RemoteMCPTool +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + 
ToolEntity, + ToolIdentity, + ToolProviderEntityWithPlugin, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.mcp_tool.tool import MCPTool +from models.tools import MCPToolProvider +from services.tools.tools_transform_service import ToolTransformService + + +class MCPToolProviderController(ToolProviderController): + provider_id: str + entity: ToolProviderEntityWithPlugin + + def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None: + super().__init__(entity) + self.entity = entity + self.tenant_id = tenant_id + self.provider_id = provider_id + self.server_url = server_url + + @property + def provider_type(self) -> ToolProviderType: + """ + returns the type of the provider + + :return: type of the provider + """ + return ToolProviderType.MCP + + @classmethod + def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": + """ + from db provider + """ + tools = [] + tools_data = json.loads(db_provider.tools) + remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data] + user = db_provider.load_user() + tools = [ + ToolEntity( + identity=ToolIdentity( + author=user.name if user else "Anonymous", + name=remote_mcp_tool.name, + label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name), + provider=db_provider.server_identifier, + icon=db_provider.icon, + ), + parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema), + description=ToolDescription( + human=I18nObject( + en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or "" + ), + llm=remote_mcp_tool.description or "", + ), + output_schema=None, + has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0, + ) + for remote_mcp_tool in remote_mcp_tools + ] + + return cls( + entity=ToolProviderEntityWithPlugin( + identity=ToolProviderIdentity( + author=user.name if user else "Anonymous", + name=db_provider.name, + label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), + description=I18nObject(en_US="", zh_Hans=""), + icon=db_provider.icon, + ), + plugin_id=None, + credentials_schema=[], + tools=tools, + ), + provider_id=db_provider.server_identifier or "", + tenant_id=db_provider.tenant_id or "", + server_url=db_provider.decrypted_server_url, + ) + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + """ + pass + + def get_tool(self, tool_name: str) -> MCPTool: # type: ignore + """ + return tool with given name + """ + tool_entity = next( + (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None + ) + + if not tool_entity: + raise ValueError(f"Tool with name {tool_name} not found") + + return MCPTool( + entity=tool_entity, + runtime=ToolRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + server_url=self.server_url, + provider_id=self.provider_id, + ) + + def get_tools(self) -> list[MCPTool]: # type: ignore + """ + get all tools + """ + return [ + MCPTool( + entity=tool_entity, + runtime=ToolRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + server_url=self.server_url, + provider_id=self.provider_id, + ) + for tool_entity in self.entity.tools + ] diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py new file mode 100644 index 0000000000..d1bacbc735 --- /dev/null +++ b/api/core/tools/mcp_tool/tool.py @@ -0,0 +1,92 @@ 
+import base64 +import json +from collections.abc import Generator +from typing import Any, Optional + +from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.mcp_client import MCPClient +from core.mcp.types import ImageContent, TextContent +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType + + +class MCPTool(Tool): + tenant_id: str + icon: str + runtime_parameters: Optional[list[ToolParameter]] + server_url: str + provider_id: str + + def __init__( + self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.runtime_parameters = None + self.server_url = server_url + self.provider_id = provider_id + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.MCP + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: Optional[str] = None, + app_id: Optional[str] = None, + message_id: Optional[str] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + from core.tools.errors import ToolInvokeError + + try: + with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: + tool_parameters = self._handle_none_parameter(tool_parameters) + result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPAuthError as e: + raise ToolInvokeError("Please auth the tool first") from e + except MCPConnectionError as e: + raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e + except Exception as e: + raise ToolInvokeError(f"Failed to invoke tool: {e}") from e + + for content in result.content: + if isinstance(content, TextContent): + try: + content_json = json.loads(content.text) + if isinstance(content_json, dict): + yield self.create_json_message(content_json) + elif isinstance(content_json, list): + for item in content_json: + yield self.create_json_message(item) + else: + yield self.create_text_message(content.text) + except json.JSONDecodeError: + yield self.create_text_message(content.text) + + elif isinstance(content, ImageContent): + yield self.create_blob_message( + blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} + ) + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": + return MCPTool( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + server_url=self.server_url, + provider_id=self.provider_id, + ) + + def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: + """ + in mcp tool invoke, if the parameter is empty, it will be set to None + """ + return { + key: value + for key, value in parameter.items() + if value is not None and not (isinstance(value, str) and value.strip() == "") + } diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1c..aef2677c36 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -44,6 +44,7 @@ class PluginTool(Tool): tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, + credential_type=self.runtime.credential_type, tool_parameters=tool_parameters, conversation_id=conversation_id, app_id=app_id, diff --git a/api/core/tools/signature.py 
b/api/core/tools/signature.py index e80005d7bf..5cdf473542 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -9,9 +9,10 @@ from configs import dify_config def sign_tool_file(tool_file_id: str, extension: str) -> str: """ - sign file to get a temporary url + sign file to get a temporary url for plugin access """ - base_url = dify_config.FILES_URL + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index b849f51064..ff054041cf 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -35,9 +35,10 @@ class ToolFileManager: @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ - sign file to get a temporary url + sign file to get a temporary url for plugin access """ - base_url = dify_config.FILES_URL + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) @@ -159,7 +160,7 @@ class ToolFileManager: with Session(self._engine, expire_on_commit=False) as session: tool_file: ToolFile | None = ( session.query(ToolFile) - .filter( + .where( ToolFile.id == id, ) .first() @@ -183,7 +184,7 @@ class ToolFileManager: with Session(self._engine, expire_on_commit=False) as session: message_file: MessageFile | None = ( session.query(MessageFile) - .filter( + .where( MessageFile.id == id, ) .first() @@ -203,7 +204,7 @@ class ToolFileManager: tool_file: ToolFile | None = ( session.query(ToolFile) - .filter( + .where( ToolFile.id == tool_file_id, ) .first() @@ -227,7 +228,7 @@ class ToolFileManager: with Session(self._engine, expire_on_commit=False) as session: tool_file: ToolFile | None = ( session.query(ToolFile) - .filter( + .where( ToolFile.id == tool_file_id, ) .first() diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 4787d7d79c..cdfefbadb3 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -29,7 +29,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() + db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: @@ -57,7 +57,7 @@ class ToolLabelManager: labels = ( db.session.query(ToolLabelBinding.label_name) - .filter( + .where( ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type.value, ) @@ -90,7 +90,7 @@ class ToolLabelManager: provider_ids.append(controller.provider_id) labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0bfe6329b1..f286466de0 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1,26 +1,34 @@ import json import logging import mimetypes -from 
collections.abc import Generator +import time +from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from pydantic import TypeAdapter from yarl import URL import contexts +from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID +from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.mcp_tool.tool import MCPTool from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool +from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from core.workflow.entities.variable_pool import VariablePool +from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity - from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -37,19 +45,20 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolInvokeFrom, ToolParameter, ToolProviderType, ) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( - ProviderConfigEncrypter, ToolParameterConfigurationManager, ) +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -64,8 +73,11 @@ class ToolManager: @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() @@ -109,7 +121,12 @@ class ToolManager: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(Lock()) + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + with contexts.plugin_tool_providers_lock.get(): + # double check plugin_tool_providers = contexts.plugin_tool_providers.get() if provider in plugin_tool_providers: return plugin_tool_providers[provider] @@ -127,25 +144,7 @@ class ToolManager: ) plugin_tool_providers[provider] = controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the 
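The lock handling added to get_plugin_provider above is the classic check / lock / re-check pattern. A minimal sketch, assuming a plain module-level dict and Lock rather than the contextvars used by the real code:

from threading import Lock

_providers: dict[str, object] = {}
_providers_lock = Lock()

def get_provider(name: str) -> object:
    cached = _providers.get(name)
    if cached is not None:              # fast path: no lock taken
        return cached
    with _providers_lock:
        cached = _providers.get(name)   # double check: another thread may have built it
        if cached is not None:
            return cached
        controller = object()           # stand-in for constructing the real controller
        _providers[name] = controller
        return controller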
name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool + return controller @classmethod def get_tool_runtime( @@ -156,7 +155,8 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: + credential_id: Optional[str] = None, + ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -166,6 +166,7 @@ class ToolManager: :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from + :param credential_id: the credential id :return: the tool """ @@ -189,49 +190,105 @@ class ToolManager: ) ), ) - + builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) - # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .first() - ) + # get specific credentials + if is_valid_uuid(credential_id): + try: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + except Exception as e: + builtin_provider = None + logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True) + # if the provider has been deleted, raise an error + if builtin_provider is None: + raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") + # fallback to the default provider if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + # use the default provider + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: builtin_provider = ( db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials - credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + 
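The branching above for builtin-tool credentials can be summarized as: an explicitly supplied credential_id (already validated as a UUID) must resolve to an existing record, otherwise the provider is reported as deleted; only when no usable credential_id is given does the lookup fall back to the tenant's default record, preferring is_default and then the oldest created_at. A hedged sketch with stand-in types (CredentialRecord is illustrative, not BuiltinToolProvider):

from dataclasses import dataclass
from typing import Optional

@dataclass
class CredentialRecord:
    id: str
    is_default: bool
    created_at: float

def select_credential(records: list[CredentialRecord], credential_id: Optional[str]) -> CredentialRecord:
    if credential_id:                   # assume the id was already checked with is_valid_uuid
        match = next((r for r in records if r.id == credential_id), None)
        if match is None:
            raise LookupError(f"provider has been deleted: {credential_id}")
        return match
    ordered = sorted(records, key=lambda r: (not r.is_default, r.created_at))
    if not ordered:
        raise LookupError("no default provider")
    return ordered[0]

records = [CredentialRecord("a", False, 1.0), CredentialRecord("b", True, 2.0)]
assert select_credential(records, None).id == "b"   # the default credential wins
assert select_credential(records, "a").id == "a"    # an explicit id overrides it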
config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), ) - decrypted_credentials = tool_configuration.decrypt(credentials) + # decrypt the credentials + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) + + # check if the credentials is expired + if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): + # TODO: circular import + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + # refresh the credentials + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + oauth_handler = OAuthHandler() + # refresh the credentials + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=builtin_provider.user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + # update the credentials + builtin_provider.encrypted_credentials = ( + TypeAdapter(dict[str, Any]) + .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials))) + .decode("utf-8") + ) + builtin_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials.credentials return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=dict(decrypted_credentials), + credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -241,22 +298,16 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, + controller=api_provider, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(credentials), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, ) @@ -265,7 +316,7 @@ class ToolManager: elif provider_type == ToolProviderType.WORKFLOW: workflow_provider = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() ) @@ -292,6 +343,8 @@ class ToolManager: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + elif provider_type == ToolProviderType.MCP: + return 
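The refresh path above only runs when the stored credentials are about to lapse. A minimal sketch of that expiry test, assuming expires_at is a unix timestamp and -1 means the credential never expires:

import time

def needs_refresh(expires_at: int, margin: int = 60, now: int | None = None) -> bool:
    if expires_at == -1:
        return False                       # non-expiring credential
    current = int(time.time()) if now is None else now
    return (expires_at - margin) < current

assert needs_refresh(-1) is False
assert needs_refresh(int(time.time()) + 30) is True      # inside the 60-second margin
assert needs_refresh(int(time.time()) + 3600) is False   # still comfortably valid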
cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -302,6 +355,7 @@ class ToolManager: app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + variable_pool: Optional[VariablePool] = None, ) -> Tool: """ get the agent tool runtime @@ -313,27 +367,13 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=agent_tool.credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() - for parameter in parameters: - # check file types - if ( - parameter.type - in { - ToolParameter.ToolParameterType.SYSTEM_FILES, - ToolParameter.ToolParameterType.FILE, - ToolParameter.ToolParameterType.FILES, - } - and parameter.required - ): - raise ValueError(f"file type parameter {parameter.name} not supported in agent") - - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # save tool parameter to tool entity memory - value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name)) - runtime_parameters[parameter.name] = value - + runtime_parameters = cls._convert_tool_parameters_type( + parameters, variable_pool, agent_tool.tool_parameters, typ="agent" + ) # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( tenant_id=tenant_id, @@ -357,10 +397,12 @@ class ToolManager: node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + variable_pool: Optional[VariablePool] = None, ) -> Tool: """ get the workflow tool runtime """ + tool_runtime = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, provider_id=workflow_tool.provider_id, @@ -368,16 +410,13 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=workflow_tool.credential_id, ) - runtime_parameters = {} + parameters = tool_runtime.get_merged_runtime_parameters() - - for parameter in parameters: - # save tool parameter to tool entity memory - if parameter.form == ToolParameter.ToolParameterForm.FORM: - value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) - runtime_parameters[parameter.name] = value - + runtime_parameters = cls._convert_tool_parameters_type( + parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow" + ) # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( tenant_id=tenant_id, @@ -401,6 +440,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Tool: """ get tool runtime from plugin @@ -412,6 +452,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -561,6 +602,22 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] + @classmethod + def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]: + """ + list all the builtin providers + """ + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, 
is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() + @classmethod def list_providers_from_api( cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral @@ -569,27 +626,19 @@ class ToolManager: filters = [] if not typ: - filters.extend(["builtin", "api", "workflow"]) + filters.extend(["builtin", "api", "workflow", "mcp"]) else: filters.append(typ) with db.session.no_autoflush: if "builtin" in filters: - # get builtin providers builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) + # key: provider name, value: provider + db_builtin_providers = { + str(ToolProviderID(provider.provider)): provider + for provider in cls.list_default_builtin_providers(tenant_id) + } # append builtin providers for provider in builtin_providers: @@ -601,10 +650,9 @@ class ToolManager: name_func=lambda x: x.identity.name, ): continue - user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), + db_provider=db_builtin_providers.get(provider.entity.identity.name), decrypt_credentials=False, ) @@ -614,10 +662,9 @@ class ToolManager: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider # get db api providers - if "api" in filters: db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() ) api_provider_controllers: list[dict[str, Any]] = [ @@ -640,7 +687,7 @@ class ToolManager: if "workflow" in filters: # get workflow providers workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() ) workflow_provider_controllers: list[WorkflowToolProviderController] = [] @@ -663,6 +710,10 @@ class ToolManager: labels=labels.get(provider_controller.provider_id, []), ) result_providers[f"workflow_provider.{user_provider.name}"] = user_provider + if "mcp" in filters: + mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True) + for mcp_provider in mcp_providers: + result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @@ -680,7 +731,7 @@ class ToolManager: """ provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) @@ -690,14 +741,47 @@ class ToolManager: if provider is None: raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + auth_type = ApiProviderAuthType.NONE + provider_auth_type = provider.credentials.get("auth_type") + if provider_auth_type in 
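The DISTINCT ON query above selects exactly one credential row per provider for the tenant. An illustrative in-Python equivalent (ProviderRow is a stand-in for a tool_builtin_providers row), ordering candidates so that is_default rows come first with created_at DESC as the tie-breaker, mirroring the SQL:

from dataclasses import dataclass

@dataclass
class ProviderRow:
    id: str
    provider: str
    is_default: bool
    created_at: float

def pick_one_per_provider(rows: list[ProviderRow]) -> dict[str, ProviderRow]:
    ordered = sorted(rows, key=lambda r: (r.provider, not r.is_default, -r.created_at))
    selected: dict[str, ProviderRow] = {}
    for row in ordered:
        selected.setdefault(row.provider, row)   # the first row per provider wins
    return selected

rows = [
    ProviderRow("1", "time", False, 10.0),
    ProviderRow("2", "time", True, 5.0),
    ProviderRow("3", "search", False, 7.0),
]
assert {r.id for r in pick_one_per_provider(rows).values()} == {"2", "3"}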
("api_key_header", "api_key"): # backward compatibility + auth_type = ApiProviderAuthType.API_KEY_HEADER + elif provider_auth_type == "api_key_query": + auth_type = ApiProviderAuthType.API_KEY_QUERY + controller = ApiToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, + auth_type, ) controller.load_bundled_tools(provider.tools) return controller, provider.credentials + @classmethod + def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController: + """ + get the api provider + + :param tenant_id: the id of the tenant + :param provider_id: the id of the provider + + :return: the provider controller, the credentials + """ + provider: MCPToolProvider | None = ( + db.session.query(MCPToolProvider) + .where( + MCPToolProvider.server_identifier == provider_id, + MCPToolProvider.tenant_id == tenant_id, + ) + .first() + ) + + if provider is None: + raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") + + controller = MCPToolProviderController._from_db(provider) + + return controller + @classmethod def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ @@ -709,7 +793,7 @@ class ToolManager: provider_name = provider provider_obj: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) @@ -725,20 +809,24 @@ class ToolManager: credentials = {} # package tool provider controller + auth_type = ApiProviderAuthType.NONE + credentials_auth_type = credentials.get("auth_type") + if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility + auth_type = ApiProviderAuthType.API_KEY_HEADER + elif credentials_auth_type == "api_key_query": + auth_type = ApiProviderAuthType.API_KEY_QUERY + controller = ApiToolProviderController.from_db( provider_obj, - ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, + auth_type, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + controller=controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) @@ -797,7 +885,7 @@ class ToolManager: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() ) @@ -814,7 +902,7 @@ class ToolManager: try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .first() ) @@ -826,6 +914,22 @@ class ToolManager: except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} + @classmethod + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> 
dict[str, str] | str: + try: + mcp_provider: MCPToolProvider | None = ( + db.session.query(MCPToolProvider) + .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) + .first() + ) + + if mcp_provider is None: + raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") + + return mcp_provider.provider_icon + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + @classmethod def get_tool_icon( cls, @@ -863,8 +967,61 @@ class ToolManager: except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} raise ValueError(f"plugin provider {provider_id} not found") + elif provider_type == ToolProviderType.MCP: + return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) else: raise ValueError(f"provider type {provider_type} not found") + @classmethod + def _convert_tool_parameters_type( + cls, + parameters: list[ToolParameter], + variable_pool: Optional[VariablePool], + tool_configurations: dict[str, Any], + typ: Literal["agent", "workflow", "tool"] = "workflow", + ) -> dict[str, Any]: + """ + Convert tool parameters type + """ + from core.workflow.nodes.tool.entities import ToolNodeData + from core.workflow.nodes.tool.exc import ToolParameterError + + runtime_parameters = {} + for parameter in parameters: + if ( + parameter.type + in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + } + and parameter.required + and typ == "agent" + ): + raise ValueError(f"file type parameter {parameter.name} not supported in agent") + # save tool parameter to tool entity memory + if parameter.form == ToolParameter.ToolParameterForm.FORM: + if variable_pool: + config = tool_configurations.get(parameter.name, {}) + if not (config and isinstance(config, dict) and config.get("value") is not None): + continue + tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {})) + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + raise ToolParameterError(f"Variable {tool_input.value} does not exist") + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.text + else: + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + runtime_parameters[parameter.name] = parameter_value + + else: + value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name)) + runtime_parameters[parameter.name] = value + return runtime_parameters + ToolManager.load_hardcoded_providers_cache() diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 1f23e90351..aceba6e69f 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,12 +1,8 @@ from copy import deepcopy from typing import Any -from pydantic import BaseModel - -from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType -from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( ToolParameter, @@ -14,109 +10,6 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigEncrypter(BaseModel): - tenant_id: str 
- config: list[BasicProviderConfig] - provider_type: str - provider_identity: str - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * len(data[field_name]) - - return data - - def decrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - decrypt tool credentials with tenant id - - return a deep copy of credentials with decrypted values - """ - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cached_credentials = cache.get() - if cached_credentials: - return cached_credentials - - data = self._deep_copy(data) - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - try: - # if the value is None or empty string, skip decrypt - if not data[field_name]: - continue - - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass - - cache.set(data) - return data - - def delete_tool_credentials_cache(self): - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cache.delete() - - class ToolParameterConfigurationManager: """ Tool parameter configuration manager diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 2cbc4b9821..7eb4bc017a 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), 
DocumentSegment.status == "completed", @@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = ( db.session.query(Document) - .filter( + .where( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, @@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): ): with flask_app.app_context(): dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() ) if not dataset: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index a4d2de3b1c..567275531e 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Optional +from typing import Optional from msal_extensions.persistence import ABC # type: ignore from pydantic import BaseModel, ConfigDict @@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod - def _run( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def _run(self, query: str) -> str: """Use the tool. Add run_manager: Optional[CallbackManagerForToolRun] = None diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index ff1d9021ce..f7689d7707 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): def _run(self, query: str) -> str: dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() ) if not dataset: @@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = ( db.session.query(DatasetDocument) # type: ignore - .filter( + .where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, DatasetDocument.archived == False, diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py new file mode 100644 index 0000000000..5fdfd3b9d1 --- /dev/null +++ b/api/core/tools/utils/encryption.py @@ -0,0 +1,142 @@ +from copy import deepcopy +from typing import Any, Optional, Protocol + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.provider_cache import SingletonProviderCredentialsCache +from core.tools.__base.tool_provider import ToolProviderController + + +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... 
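Because ProviderConfigCache is a typing.Protocol, any object with matching get/set/delete methods satisfies it structurally, with no inheritance required. A minimal in-memory implementation, useful for tests or local experiments (the production caches such as ToolProviderCredentialsCache are separate classes):

from typing import Any, Optional

class InMemoryProviderConfigCache:
    def __init__(self) -> None:
        self._config: Optional[dict[str, Any]] = None

    def get(self) -> Optional[dict]:
        return self._config

    def set(self, config: dict[str, Any]) -> None:
        self._config = dict(config)     # copy so later mutation of the source doesn't leak in

    def delete(self) -> None:
        self._config = None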
+ + +class ProviderConfigEncrypter: + tenant_id: str + config: list[BasicProviderConfig] + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cached_credentials = self.provider_config_cache.get() + if cached_credentials: + return cached_credentials + + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + self.provider_config_cache.set(data) + return data + + +def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): + return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache + + +def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): + cache = SingletonProviderCredentialsCache( + tenant_id=tenant_id, + provider_type=controller.provider_type.value, + provider_identity=controller.entity.identity.name, + ) + encrypt = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], + provider_config_cache=cache, + ) + return encrypt, cache diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3f844e8234..a3c84615ca 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,5 
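The masking rule in mask_tool_credentials above keeps only the first and last two characters of longer secrets. As a stand-alone helper:

def mask_secret(value: str) -> str:
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)

assert mask_secret("sk-abcdef123456") == "sk***********56"
assert mask_secret("abc") == "***"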
+1,4 @@ import re -import uuid from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: - path = str(uuid.uuid4()) + path = "" interface["operation"]["operationId"] = f"{path}_{interface['method']}" diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py new file mode 100644 index 0000000000..f3c946b95f --- /dev/null +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -0,0 +1,187 @@ +import base64 +import hashlib +import logging +from collections.abc import Mapping +from typing import Any, Optional + +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad, unpad +from pydantic import TypeAdapter + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class OAuthEncryptionError(Exception): + """OAuth encryption/decryption specific error""" + + pass + + +class SystemOAuthEncrypter: + """ + A simple OAuth parameters encrypter using AES-CBC encryption. + + This class provides methods to encrypt and decrypt OAuth parameters + using AES-CBC mode with a key derived from the application's SECRET_KEY. + """ + + def __init__(self, secret_key: Optional[str] = None): + """ + Initialize the OAuth encrypter. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Raises: + ValueError: If SECRET_KEY is not configured or empty + """ + secret_key = secret_key or dify_config.SECRET_KEY or "" + + # Generate a fixed 256-bit key using SHA-256 + self.key = hashlib.sha256(secret_key.encode()).digest() + + def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + """ + Encrypt OAuth parameters. + + Args: + oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + + Returns: + Base64-encoded encrypted string + + Raises: + OAuthEncryptionError: If encryption fails + ValueError: If oauth_params is invalid + """ + + try: + # Generate random IV (16 bytes) + iv = get_random_bytes(16) + + # Create AES cipher (CBC mode) + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Encrypt data + padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + + # Combine IV and encrypted data + combined = iv + encrypted_data + + # Return base64 encoded string + return base64.b64encode(combined).decode() + + except Exception as e: + raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + + def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt OAuth parameters. 
+ + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + + Raises: + OAuthEncryptionError: If decryption fails + ValueError: If encrypted_data is invalid + """ + if not isinstance(encrypted_data, str): + raise ValueError("encrypted_data must be a string") + + if not encrypted_data: + raise ValueError("encrypted_data cannot be empty") + + try: + # Base64 decode + combined = base64.b64decode(encrypted_data) + + # Check minimum length (IV + at least one AES block) + if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data + raise ValueError("Invalid encrypted data format") + + # Separate IV and encrypted data + iv = combined[:16] + encrypted_data_bytes = combined[16:] + + # Create AES cipher + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Decrypt data + decrypted_data = cipher.decrypt(encrypted_data_bytes) + unpadded_data = unpad(decrypted_data, AES.block_size) + + # Parse JSON + oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + + if not isinstance(oauth_params, dict): + raise ValueError("Decrypted data is not a valid dictionary") + + return oauth_params + + except Exception as e: + raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + + +# Factory function for creating encrypter instances +def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: + """ + Create an OAuth encrypter instance. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Returns: + SystemOAuthEncrypter instance + """ + return SystemOAuthEncrypter(secret_key=secret_key) + + +# Global encrypter instance (for backward compatibility) +_oauth_encrypter: Optional[SystemOAuthEncrypter] = None + + +def get_system_oauth_encrypter() -> SystemOAuthEncrypter: + """ + Get the global OAuth encrypter instance. + + Returns: + SystemOAuthEncrypter instance + """ + global _oauth_encrypter + if _oauth_encrypter is None: + _oauth_encrypter = SystemOAuthEncrypter() + return _oauth_encrypter + + +# Convenience functions for backward compatibility +def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: + """ + Encrypt OAuth parameters using the global encrypter. + + Args: + oauth_params: OAuth parameters dictionary + + Returns: + Base64-encoded encrypted string + """ + return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + + +def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt OAuth parameters using the global encrypter. 
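The SystemOAuthEncrypter above composes a few standard pieces: an AES-256 key derived from SECRET_KEY via SHA-256, CBC mode with a fresh 16-byte IV prepended to the ciphertext, and base64 for transport. A self-contained round-trip of that scheme (requires pycryptodome; plain json here stands in for the TypeAdapter serialization used above):

import base64
import hashlib
import json

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad

def encrypt(params: dict, secret_key: str) -> str:
    key = hashlib.sha256(secret_key.encode()).digest()
    iv = get_random_bytes(16)
    cipher = AES.new(key, AES.MODE_CBC, iv)
    ciphertext = cipher.encrypt(pad(json.dumps(params).encode(), AES.block_size))
    return base64.b64encode(iv + ciphertext).decode()

def decrypt(token: str, secret_key: str) -> dict:
    key = hashlib.sha256(secret_key.encode()).digest()
    raw = base64.b64decode(token)
    cipher = AES.new(key, AES.MODE_CBC, raw[:16])
    return json.loads(unpad(cipher.decrypt(raw[16:]), AES.block_size))

token = encrypt({"client_id": "abc", "client_secret": "xyz"}, "my-secret")
assert decrypt(token, "my-secret") == {"client_id": "abc", "client_secret": "xyz"}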
+ + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + """ + return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) diff --git a/api/core/tools/utils/uuid_utils.py b/api/core/tools/utils/uuid_utils.py index 3046c08c89..bdcc33259d 100644 --- a/api/core/tools/utils/uuid_utils.py +++ b/api/core/tools/utils/uuid_utils.py @@ -1,7 +1,9 @@ import uuid -def is_valid_uuid(uuid_str: str) -> bool: +def is_valid_uuid(uuid_str: str | None) -> bool: + if uuid_str is None or len(uuid_str) == 0: + return False try: uuid.UUID(uuid_str) return True diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 7661e1e6a5..83f5f558d5 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController): """ workflow: Workflow | None = ( db.session.query(Workflow) - .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) + .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .first() ) @@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController): db_providers: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == self.provider_id, ) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 57c93d1d45..8b89c2a7a9 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -8,7 +8,12 @@ from flask_login import current_user from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.errors import ToolInvokeError from extensions.ext_database import db from factories.file_factory import build_from_mapping @@ -137,12 +142,12 @@ class WorkflowTool(Tool): if not version: workflow = ( db.session.query(Workflow) - .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .where(Workflow.app_id == app_id, Workflow.version != "draft") .order_by(Workflow.created_at.desc()) .first() ) else: - workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() + workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first() if not workflow: raise ValueError("workflow not found or not published") @@ -153,7 +158,7 @@ class WorkflowTool(Tool): """ get the app by app id """ - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("app not found") diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6cf09e0372..13274f4e0e 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,9 +1,9 @@ import json import sys from collections.abc import Mapping, Sequence -from typing import Any +from typing import Annotated, Any, TypeAlias -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, 
ConfigDict, Discriminator, Tag, field_validator from core.file import File @@ -11,6 +11,11 @@ from .types import SegmentType class Segment(BaseModel): + """Segment is runtime type used during the execution of workflow. + + Note: this class is abstract, you should use subclasses of this class instead. + """ + model_config = ConfigDict(frozen=True) value_type: SegmentType @@ -73,7 +78,7 @@ class StringSegment(Segment): class FloatSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.FLOAT value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. @@ -92,7 +97,7 @@ class FloatSegment(Segment): class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.INTEGER value: int @@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment): @property def text(self) -> str: return "" + + +def get_segment_discriminator(v: Any) -> SegmentType | None: + if isinstance(v, Segment): + return v.value_type + elif isinstance(v, dict): + value_type = v.get("value_type") + if value_type is None: + return None + try: + seg_type = SegmentType(value_type) + except ValueError: + return None + return seg_type + else: + # return None if the discriminator value isn't found + return None + + +# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Segment` for type hinting when serialization is not required. +# +# Note: +# - All variants in `SegmentUnion` must inherit from the `Segment` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +# - `SegmentGroup`, which is not added to the variable pool. +# - `Variable` and its subclasses, which are handled by `VariableUnion`. +SegmentUnion: TypeAlias = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 68d3d82883..e79b2410bf 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,8 +1,27 @@ +from collections.abc import Mapping from enum import StrEnum +from typing import Any, Optional + +from core.file.models import File + + +class ArrayValidation(StrEnum): + """Strategy for validating array elements""" + + # Skip element validation (only check array container) + NONE = "none" + + # Validate the first element (if array is non-empty) + FIRST = "first" + + # Validate all elements in the array. 
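SegmentUnion relies on Pydantic's callable-discriminator feature: get_segment_discriminator reads value_type from either a model instance or a plain dict, and the Tag annotations tell Pydantic which subclass each value selects during deserialization. A minimal self-contained version of the same pattern (the two segment classes here are illustrative, not the real ones):

from typing import Annotated, Any, Literal, Optional, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class StringSeg(BaseModel):
    value_type: Literal["string"] = "string"
    value: str

class IntSeg(BaseModel):
    value_type: Literal["integer"] = "integer"
    value: int

def seg_discriminator(v: Any) -> Optional[str]:
    if isinstance(v, BaseModel):
        return getattr(v, "value_type", None)
    if isinstance(v, dict):
        return v.get("value_type")
    return None                        # unknown shapes fail validation

SegUnion = Annotated[
    Union[Annotated[StringSeg, Tag("string")], Annotated[IntSeg, Tag("integer")]],
    Discriminator(seg_discriminator),
]

adapter = TypeAdapter(SegUnion)
assert isinstance(adapter.validate_python({"value_type": "integer", "value": 3}), IntSeg)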
+ ALL = "all" class SegmentType(StrEnum): NUMBER = "number" + INTEGER = "integer" + FLOAT = "float" STRING = "string" OBJECT = "object" SECRET = "secret" @@ -19,16 +38,139 @@ class SegmentType(StrEnum): GROUP = "group" - def is_array_type(self): + def is_array_type(self) -> bool: return self in _ARRAY_TYPES + @classmethod + def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + """ + Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. + + Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. + For example, this may occur if the input is a generic Python object of type `object`. + """ + + if isinstance(value, list): + elem_types: set[SegmentType] = set() + for i in value: + segment_type = cls.infer_segment_type(i) + if segment_type is None: + return None + + elem_types.add(segment_type) + + if len(elem_types) != 1: + if elem_types.issubset(_NUMERICAL_TYPES): + return SegmentType.ARRAY_NUMBER + return SegmentType.ARRAY_ANY + elif all(i.is_array_type() for i in elem_types): + return SegmentType.ARRAY_ANY + match elem_types.pop(): + case SegmentType.STRING: + return SegmentType.ARRAY_STRING + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return SegmentType.ARRAY_NUMBER + case SegmentType.OBJECT: + return SegmentType.ARRAY_OBJECT + case SegmentType.FILE: + return SegmentType.ARRAY_FILE + case SegmentType.NONE: + return SegmentType.ARRAY_ANY + case _: + # This should be unreachable. + raise ValueError(f"not supported value {value}") + if value is None: + return SegmentType.NONE + elif isinstance(value, int) and not isinstance(value, bool): + return SegmentType.INTEGER + elif isinstance(value, float): + return SegmentType.FLOAT + elif isinstance(value, str): + return SegmentType.STRING + elif isinstance(value, dict): + return SegmentType.OBJECT + elif isinstance(value, File): + return SegmentType.FILE + else: + return None + + def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: + if not isinstance(value, list): + return False + # Skip element validation if array is empty + if len(value) == 0: + return True + if self == SegmentType.ARRAY_ANY: + return True + element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] + + if array_validation == ArrayValidation.NONE: + return True + elif array_validation == ArrayValidation.FIRST: + return element_type.is_valid(value[0]) + else: + return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value) + + def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: + """ + Check if a value matches the segment type. + Users of `SegmentType` should call this method, instead of using + `isinstance` manually. 
+ + Args: + value: The value to validate + array_validation: Validation strategy for array types (ignored for non-array types) + + Returns: + True if the value matches the type under the given validation strategy + """ + if self.is_array_type(): + return self._validate_array(value, array_validation) + elif self == SegmentType.NUMBER: + return isinstance(value, (int, float)) + elif self == SegmentType.STRING: + return isinstance(value, str) + elif self == SegmentType.OBJECT: + return isinstance(value, dict) + elif self == SegmentType.SECRET: + return isinstance(value, str) + elif self == SegmentType.FILE: + return isinstance(value, File) + elif self == SegmentType.NONE: + return value is None + else: + raise AssertionError("this statement should be unreachable.") + + def exposed_type(self) -> "SegmentType": + """Returns the type exposed to the frontend. + + The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. + """ + if self in (SegmentType.INTEGER, SegmentType.FLOAT): + return SegmentType.NUMBER + return self + + +_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { + # ARRAY_ANY does not have correpond element type. + SegmentType.ARRAY_STRING: SegmentType.STRING, + SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, + SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, + SegmentType.ARRAY_FILE: SegmentType.FILE, +} _ARRAY_TYPES = frozenset( - [ + list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) + + [ SegmentType.ARRAY_ANY, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_FILE, + ] +) + + +_NUMERICAL_TYPES = frozenset( + [ + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, ] ) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682e..a31ebc848e 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -from typing import cast +from typing import Annotated, TypeAlias, cast from uuid import uuid4 -from pydantic import Field +from pydantic import Discriminator, Field, Tag from core.helper import encrypter @@ -20,6 +20,7 @@ from .segments import ( ObjectSegment, Segment, StringSegment, + get_segment_discriminator, ) from .types import SegmentType @@ -27,6 +28,10 @@ from .types import SegmentType class Variable(Segment): """ A variable is a segment that has a name. + + It is mainly used to store segments and their selector in VariablePool. + + Note: this class is abstract, you should use subclasses of this class instead. """ id: str = Field( @@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + + +# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Variable` for type hinting when serialization is not required. +# +# Note: +# - All variants in `VariableUnion` must inherit from the `Variable` class. 
+# - The union must include all non-abstract subclasses of `Segment`, except: +VariableUnion: TypeAlias = Annotated[ + ( + Annotated[NoneVariable, Tag(SegmentType.NONE)] + | Annotated[StringVariable, Tag(SegmentType.STRING)] + | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] + | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] + | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] + | Annotated[FileVariable, Tag(SegmentType.FILE)] + | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] + | Annotated[SecretVariable, Tag(SegmentType.SECRET)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index e6813a3997..12b5203ca3 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -232,14 +232,14 @@ class WorkflowLoggingCallback(WorkflowCallback): Publish loop started """ self.print_text("\n[LoopRunStartedEvent]", color="blue") - self.print_text(f"Loop Node ID: {event.loop_id}", color="blue") + self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None: """ Publish loop next """ self.print_text("\n[LoopRunNextEvent]", color="blue") - self.print_text(f"Loop Node ID: {event.loop_id}", color="blue") + self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") self.print_text(f"Loop Index: {event.index}", color="blue") def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None: @@ -250,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback): "\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]", color="blue", ) - self.print_text(f"Node ID: {event.loop_id}", color="blue") + self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: """Print text with highlighting and no end characters.""" diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 80dda2632d..fbb8df6b01 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,7 +1,7 @@ import re from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Union +from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field @@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories import variable_factory VariableValue = Union[str, int, float, dict, list, File] @@ -23,31 +24,31 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the 
first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: dict[str, dict[int, Segment]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. + + # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. user_inputs: Mapping[str, Any] = Field( description="User inputs", default_factory=dict, ) - system_variables: Mapping[SystemVariableKey, Any] = Field( + system_variables: SystemVariable = Field( description="System variables", - default_factory=dict, ) - environment_variables: Sequence[Variable] = Field( + environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", default_factory=list, ) - conversation_variables: Sequence[Variable] = Field( + conversation_variables: Sequence[VariableUnion] = Field( description="Conversation variables.", default_factory=list, ) def model_post_init(self, context: Any, /) -> None: - for key, value in self.system_variables.items(): - self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) + # Create a mapping from field names to SystemVariableKey enum values + self._add_system_variables(self.system_variables) # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) @@ -83,8 +84,22 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]][hash_key] = variable + key, hash_key = self._selector_to_keys(selector) + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. 
+ self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + + @classmethod + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: + return selector[0], hash(tuple(selector[1:])) + + def _has(self, selector: Sequence[str]) -> bool: + key, hash_key = self._selector_to_keys(selector) + if key not in self.variable_dictionary: + return False + if hash_key not in self.variable_dictionary[key]: + return False + return True def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -102,8 +117,8 @@ class VariablePool(BaseModel): if len(selector) < MIN_SELECTORS_LENGTH: return None - hash_key = hash(tuple(selector[1:])) - value = self.variable_dictionary[selector[0]].get(hash_key) + key, hash_key = self._selector_to_keys(selector) + value: Segment | None = self.variable_dictionary[key].get(hash_key) if value is None: selector, attr = selector[:-1], selector[-1] @@ -136,8 +151,8 @@ class VariablePool(BaseModel): if len(selector) == 1: self.variable_dictionary[selector[0]] = {} return - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]].pop(hash_key, None) + key, hash_key = self._selector_to_keys(selector) + self.variable_dictionary[key].pop(hash_key, None) def convert_template(self, template: str, /): parts = VARIABLE_PATTERN.split(template) @@ -154,3 +169,20 @@ class VariablePool(BaseModel): if isinstance(segment, FileSegment): return segment return None + + def _add_system_variables(self, system_variable: SystemVariable): + sys_var_mapping = system_variable.to_dict() + for key, value in sys_var_mapping.items(): + if value is None: + continue + selector = (SYSTEM_VARIABLE_NODE_ID, key) + # If the system variable already exists, do not add it again. + # This ensures that we can keep the id of the system variables intact. 
+ if self._has(selector): + continue + self.add(selector, value) # type: ignore + + @classmethod + def empty(cls) -> "VariablePool": + """Create an empty variable pool.""" + return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py deleted file mode 100644 index 8896416f12..0000000000 --- a/api/core/workflow/entities/workflow_entities.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode -from models.enums import UserFrom -from models.workflow import Workflow, WorkflowType - -from .node_entities import NodeRunResult -from .variable_pool import VariablePool - - -class WorkflowNodeAndResult: - node: BaseNode - result: Optional[NodeRunResult] = None - - def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): - self.node = node - self.result = result - - -class WorkflowRunState: - tenant_id: str - app_id: str - workflow_id: str - workflow_type: WorkflowType - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int - - start_at: float - variable_pool: VariablePool - - total_tokens: int = 0 - - workflow_nodes_and_results: list[WorkflowNodeAndResult] - - class NodeRun(BaseModel): - node_id: str - iteration_node_id: str - loop_node_id: str - - workflow_node_runs: list[NodeRun] - workflow_node_steps: int - - current_iteration_state: Optional[BaseIterationState] - current_loop_state: Optional[BaseLoopState] - - def __init__( - self, - workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int, - ): - self.workflow_id = workflow.id - self.tenant_id = workflow.tenant_id - self.app_id = workflow.app_id - self.workflow_type = WorkflowType.value_of(workflow.type) - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth - - self.start_at = start_at - self.variable_pool = variable_pool - - self.total_tokens = 0 - - self.workflow_node_steps = 1 - self.workflow_node_runs = [] - self.current_iteration_state = None - self.current_loop_state = None diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index bd4ccc1072..594bb2b32e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_instance: BaseNode, error: str): - self.node_instance = node_instance - self.error = error - super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") + def __init__(self, node: BaseNode, err_msg: str): + self._node = node + self._error = err_msg + super().__init__(f"Node {node.title} run failed: {err_msg}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index 2fee3d7fad..12e1de464b 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -1,3 +1,4 @@ from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState +from .graph_engine import GraphEngine -__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] +__all__ = ["Graph", "GraphEngine", "GraphInitParams", 
"GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 8e5b1e7142..362777a199 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -334,7 +334,7 @@ class Graph(BaseModel): parallel = GraphParallel( start_from_node_id=start_node_id, - parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, ) parallel_mapping[parallel.id] = parallel diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index afc09bfac5..a62ffe46c9 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel): """total tokens""" llm_usage: LLMUsage = LLMUsage.empty_usage() """llm usage info""" + + # The `outputs` field stores the final output values generated by executing workflows or chatflows. + # + # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent + # after a serialization and deserialization round trip. outputs: dict[str, Any] = {} - """outputs""" node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 61a7a26652..b315129763 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -12,7 +12,7 @@ from typing import Any, Optional, cast from flask import Flask, current_app from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom @@ -103,7 +101,7 @@ class GraphEngine: call_depth: int, graph: Graph, graph_config: Mapping[str, Any], - variable_pool: VariablePool, + graph_runtime_state: GraphRuntimeState, max_execution_steps: int, max_execution_time: int, thread_pool_id: Optional[str] = None, @@ -140,7 +138,7 @@ class GraphEngine: call_depth=call_depth, ) - self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + self.graph_runtime_state = graph_runtime_state self.max_execution_steps = max_execution_steps 
self.max_execution_time = max_execution_time @@ -260,12 +258,16 @@ class GraphEngine: # convert to specific node node_type = NodeType(node_config.get("data", {}).get("type")) node_version = node_config.get("data", {}).get("version", "1") + + # Import here to avoid circular import + from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None # init workflow run state - node_instance = node_cls( # type: ignore + node = node_cls( id=route_node_state.id, config=node_config, graph_init_params=self.init_params, @@ -274,11 +276,11 @@ class GraphEngine: previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - node_instance = cast(BaseNode[BaseNodeData], node_instance) + node.init_node_data(node_config.get("data", {})) try: # run node generator = self._run_node( - node_instance=node_instance, + node=node, route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -306,16 +308,16 @@ class GraphEngine: route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( error=str(e), - id=node_instance.id, + id=node.id, node_id=next_node_id, node_type=node_type, - node_data=node_instance.node_data, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) raise e @@ -337,7 +339,7 @@ class GraphEngine: edge = edge_mappings[0] if ( previous_route_node_state.status == RouteNodeState.Status.EXCEPTION - and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and edge.run_condition is None ): break @@ -413,8 +415,8 @@ class GraphEngine: next_node_id = final_node_id elif ( - node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - and node_instance.should_continue_on_error + node.continue_on_error + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION ): break @@ -597,7 +599,7 @@ class GraphEngine: def _run_node( self, - node_instance: BaseNode[BaseNodeData], + node: BaseNode, route_node_state: RouteNodeState, parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, @@ -611,29 +613,29 @@ class GraphEngine: # trigger node run start event agent_strategy = ( AgentNodeStrategyInit( - name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name, - icon=cast(AgentNode, node_instance).agent_strategy_icon, + name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name, + icon=cast(AgentNode, node).agent_strategy_icon, ) - if node_instance.node_type == NodeType.AGENT + if node.type_ == NodeType.AGENT else None ) yield NodeRunStartedEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, 
parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, - node_version=node_instance.version(), + node_version=node.version(), ) - max_retries = node_instance.node_data.retry_config.max_retries - retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + max_retries = node.retry_config.max_retries + retry_interval = node.retry_config.retry_interval_seconds retries = 0 should_continue_retry = True while should_continue_retry and retries <= max_retries: @@ -642,7 +644,7 @@ class GraphEngine: retry_start_at = datetime.now(UTC).replace(tzinfo=None) # yield control to other threads time.sleep(0.001) - event_stream = node_instance.run() + event_stream = node.run() for event in event_stream: if isinstance(event, GraphEngineEvent): # add parallel info to iteration event @@ -658,21 +660,21 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: if ( retries == max_retries - and node_instance.node_type == NodeType.HTTP_REQUEST + and node.type_ == NodeType.HTTP_REQUEST and run_result.outputs - and not node_instance.should_continue_on_error + and not node.continue_on_error ): run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED - if node_instance.should_retry and retries < max_retries: + if node.retry and retries < max_retries: retries += 1 route_node_state.node_run_result = run_result yield NodeRunRetryEvent( id=str(uuid.uuid4()), - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, @@ -680,17 +682,17 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, - node_version=node_instance.version(), + node_version=node.version(), ) time.sleep(retry_interval) break route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: - if node_instance.should_continue_on_error: + if node.continue_on_error: # if run failed, handle error run_result = self._handle_continue_on_error( - node_instance, + node, event.run_result, self.graph_runtime_state.variable_pool, handle_exceptions=handle_exceptions, @@ -701,44 +703,44 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) yield NodeRunExceptionEvent( error=run_result.error or "System Error", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False else: yield NodeRunFailedEvent( error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - 
node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if ( - node_instance.should_continue_on_error - and self.graph.edge_mapping.get(node_instance.node_id) - and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH + node.continue_on_error + and self.graph.edge_mapping.get(node.node_id) + and node.error_strategy is ErrorStrategy.FAIL_BRANCH ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get( @@ -758,7 +760,7 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) @@ -783,26 +785,26 @@ class GraphEngine: run_result.metadata = metadata_dict yield NodeRunSucceededEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False break elif isinstance(event, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), chunk_content=event.chunk_content, from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, @@ -810,14 +812,14 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), retriever_resources=event.retriever_resources, context=event.context, route_node_state=route_node_state, @@ -825,7 +827,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) except GenerateTaskStoppedError: # trigger node run failed event @@ -833,20 +835,20 @@ class GraphEngine: route_node_state.failed_reason = "Workflow stopped." 
yield NodeRunFailedEvent( error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) return except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") + logger.exception(f"Node {node.title} run failed") raise e def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): @@ -886,22 +888,14 @@ class GraphEngine: def _handle_continue_on_error( self, - node_instance: BaseNode[BaseNodeData], + node: BaseNode, error_result: NodeRunResult, variable_pool: VariablePool, handle_exceptions: list[str] = [], ) -> NodeRunResult: - """ - handle continue on error when self._should_continue_on_error is True - - - :param error_result (NodeRunResult): error run result - :param variable_pool (VariablePool): variable pool - :return: excption run result - """ # add error message and error type to variable pool - variable_pool.add([node_instance.node_id, "error_message"], error_result.error) - variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + variable_pool.add([node.node_id, "error_message"], error_result.error) + variable_pool.add([node.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions handle_exceptions.append(error_result.error or "") node_error_args: dict[str, Any] = { @@ -909,21 +903,21 @@ class GraphEngine: "error": error_result.error, "inputs": error_result.inputs, "metadata": { - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy, }, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node.default_value_dict, "error_message": error_result.error, "error_type": error_result.error_type, }, ) - elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: - if self.graph.edge_mapping.get(node_instance.node_id): + elif node.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node.node_id): node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED return NodeRunResult( **node_error_args, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 987f670acb..c83303034e 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,58 +2,99 @@ import json from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast +from packaging.version import Version +from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter +from core.agent.strategy.plugin import PluginAgentStrategy +from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import 
TokenBufferMemory from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.tool_manager import ToolManager -from core.variables.segments import StringSegment +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.variables.segments import ArrayFileSegment, StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import RunCompletedEvent -from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy +from models import ToolFile from models.model import Conversation +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .exc import ( + AgentInputTypeError, + AgentInvocationError, + AgentMessageTransformError, + AgentVariableNotFoundError, + AgentVariableTypeError, + ToolFileNotFoundError, +) -class AgentNode(ToolNode): +class AgentNode(BaseNode): """ Agent Node """ - _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + _node_data: AgentNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = AgentNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data @classmethod def version(cls) -> str: return "1" def _run(self) -> Generator: - """ - Run the agent node - """ - node_data = cast(AgentNodeData, self.node_data) - try: strategy = get_plugin_agent_strategy( tenant_id=self.tenant_id, - 
agent_strategy_provider_name=node_data.agent_strategy_provider_name, - agent_strategy_name=node_data.agent_strategy_name, + agent_strategy_provider_name=self._node_data.agent_strategy_provider_name, + agent_strategy_name=self._node_data.agent_strategy_name, ) except Exception as e: yield RunCompletedEvent( @@ -71,14 +112,17 @@ class AgentNode(ToolNode): parameters = self._generate_agent_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self._node_data, + strategy=strategy, ) parameters_for_log = self._generate_agent_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self._node_data, for_log=True, + strategy=strategy, ) + credentials = self._generate_credentials(parameters=parameters) # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) @@ -89,34 +133,42 @@ class AgentNode(ToolNode): user_id=self.user_id, app_id=self.app_id, conversation_id=conversation_id.text if conversation_id else None, + credentials=credentials, ) except Exception as e: + error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=f"Failed to invoke agent: {str(e)}", + error=str(error), ) ) return try: - # convert tool messages - yield from self._transform_message( - message_stream, - { + messages=message_stream, + tool_info={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, }, - parameters_for_log, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_type=self.type_, + node_id=self.node_id, + node_execution_id=self.id, ) except PluginDaemonClientSideError as e: + transform_error = AgentMessageTransformError( + f"Failed to transform agent message: {str(e)}", original_error=e + ) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=f"Failed to transform agent message: {str(e)}", + error=str(transform_error), ) ) @@ -127,6 +179,7 @@ class AgentNode(ToolNode): variable_pool: VariablePool, node_data: AgentNodeData, for_log: bool = False, + strategy: PluginAgentStrategy, ) -> dict[str, Any]: """ Generate parameters based on the given tool parameters, variable pool, and node data. 
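A note on the parameter handling in the surrounding hunks: `_generate_agent_parameters` resolves each configured agent input against the variable pool before the strategy is invoked, branching on the input type. A simplified sketch of that dispatch, with the helper name `resolve_agent_input` invented for illustration (the real method additionally handles tool lists, model selectors, and the `for_log` variant):

```python
from typing import Any

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.agent.exc import AgentInputTypeError, AgentVariableNotFoundError


def resolve_agent_input(input_type: str, value: Any, variable_pool: VariablePool) -> Any:
    """Illustrative only: mirrors the 'variable' / 'mixed' / 'constant' branches in the diff."""
    if input_type == "variable":
        # `value` is a selector such as ["node_id", "output"].
        variable = variable_pool.get(value)
        if variable is None:
            raise AgentVariableNotFoundError(str(value))
        return variable.value
    if input_type in {"mixed", "constant"}:
        # Templates such as "{{#node_id.text#}}" are rendered against the pool.
        return variable_pool.convert_template(str(value)).text
    raise AgentInputTypeError(input_type)
```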
@@ -152,7 +205,7 @@ class AgentNode(ToolNode): if agent_input.type == "variable": variable = variable_pool.get(agent_input.value) # type: ignore if variable is None: - raise ValueError(f"Variable {agent_input.value} does not exist") + raise AgentVariableNotFoundError(str(agent_input.value)) parameter_value = variable.value elif agent_input.type in {"mixed", "constant"}: # variable_pool.convert_template expects a string template, @@ -174,12 +227,12 @@ class AgentNode(ToolNode): except json.JSONDecodeError: parameter_value = parameter_value else: - raise ValueError(f"Unknown agent input type '{agent_input.type}'") + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) value = [tool for tool in value if tool.get("enabled", False)] - + value = self._filter_mcp_type_tool(strategy, value) for tool in value: if "schemas" in tool: tool.pop("schemas") @@ -213,12 +266,20 @@ class AgentNode(ToolNode): tool_name=tool.get("tool_name", ""), tool_parameters=parameters, plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), ) extra = tool.get("extra", {}) + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + runtime_variable_pool = variable_pool tool_runtime = ToolManager.get_agent_tool_runtime( - self.tenant_id, self.app_id, entity, self.invoke_from + self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool ) if tool_runtime.entity.description: tool_runtime.entity.description.llm = ( @@ -243,11 +304,12 @@ class AgentNode(ToolNode): { **tool_runtime.entity.model_dump(mode="json"), "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), "provider_type": provider_type.value, } ) value = tool_value - if parameter.type == "model-selector": + if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) model_instance, model_schema = self._fetch_model(value) # memory config @@ -272,25 +334,41 @@ class AgentNode(ToolNode): return result + def _generate_credentials( + self, + parameters: dict[str, Any], + ) -> InvokeCredentials: + """ + Generate credentials based on the given agent parameters. 
+ """ + + credentials = InvokeCredentials() + + # generate credentials for tools selector + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if tool.get("credential_id"): + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + except ValidationError: + continue + return credentials + @classmethod def _extract_variable_selector_to_variable_mapping( cls, *, graph_config: Mapping[str, Any], node_id: str, - node_data: BaseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(AgentNodeData, node_data) + # Create typed NodeData from dict + typed_node_data = AgentNodeData.model_validate(node_data) + result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - input = node_data.agent_parameters[parameter_name] + for parameter_name in typed_node_data.agent_parameters: + input = typed_node_data.agent_parameters[parameter_name] if input.type in ["mixed", "constant"]: selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: @@ -315,7 +393,7 @@ class AgentNode(ToolNode): plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self.node_data).agent_strategy_provider_name + == cast(AgentNodeData, self._node_data).agent_strategy_provider_name ) icon = current_plugin.declaration.icon except StopIteration: @@ -370,3 +448,249 @@ class AgentNode(ToolNode): except ValueError: model_schema.features.remove(feature) return model_schema + + def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Filter MCP type tool + :param strategy: plugin agent strategy + :param tool: tool + :return: filtered tool dict + """ + meta_version = strategy.meta_version + if meta_version and Version(meta_version) > Version("0.0.1"): + return tools + else: + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json_list: list[dict] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage: LLMUsage | None = None + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = 
str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == NodeType.AGENT: + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(msg_metadata) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + if message.message.json_object is not None: + json_list.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, File) + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + 
manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + # check if the agent log is already in the list + for log in agent_logs: + if log.id == agent_log.id: + # update the log + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output: list[dict[str, Any]] = [] + + # Step 1: append each agent log as its own dict. + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] + if json_list: + json_output.extend(json_list) + else: + json_output.append({"data": []}) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 075a41fb2f..11b11068e7 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData): agent_strategy_name: str agent_strategy_label: str # redundancy memory: MemoryConfig | None = None + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. 
+ tool_node_version: str | None = None class AgentInput(BaseModel): value: Union[list[str], list[ToolSelector], Any] diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py new file mode 100644 index 0000000000..d5955bdd7d --- /dev/null +++ b/api/core/workflow/nodes/agent/exc.py @@ -0,0 +1,124 @@ +from typing import Optional + + +class AgentNodeError(Exception): + """Base exception for all agent node errors.""" + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) + + +class AgentStrategyError(AgentNodeError): + """Exception raised when there's an error with the agent strategy.""" + + def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + self.strategy_name = strategy_name + self.provider_name = provider_name + super().__init__(message) + + +class AgentStrategyNotFoundError(AgentStrategyError): + """Exception raised when the specified agent strategy is not found.""" + + def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + super().__init__( + f"Agent strategy '{strategy_name}' not found" + + (f" for provider '{provider_name}'" if provider_name else ""), + strategy_name, + provider_name, + ) + + +class AgentInvocationError(AgentNodeError): + """Exception raised when there's an error invoking the agent.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(message) + + +class AgentParameterError(AgentNodeError): + """Exception raised when there's an error with agent parameters.""" + + def __init__(self, message: str, parameter_name: Optional[str] = None): + self.parameter_name = parameter_name + super().__init__(message) + + +class AgentVariableError(AgentNodeError): + """Exception raised when there's an error with variables in the agent node.""" + + def __init__(self, message: str, variable_name: Optional[str] = None): + self.variable_name = variable_name + super().__init__(message) + + +class AgentVariableNotFoundError(AgentVariableError): + """Exception raised when a variable is not found in the variable pool.""" + + def __init__(self, variable_name: str): + super().__init__(f"Variable '{variable_name}' does not exist", variable_name) + + +class AgentInputTypeError(AgentNodeError): + """Exception raised when an unknown agent input type is encountered.""" + + def __init__(self, input_type: str): + super().__init__(f"Unknown agent input type '{input_type}'") + + +class ToolFileError(AgentNodeError): + """Exception raised when there's an error with a tool file.""" + + def __init__(self, message: str, file_id: Optional[str] = None): + self.file_id = file_id + super().__init__(message) + + +class ToolFileNotFoundError(ToolFileError): + """Exception raised when a tool file is not found.""" + + def __init__(self, file_id: str): + super().__init__(f"Tool file '{file_id}' does not exist", file_id) + + +class AgentMessageTransformError(AgentNodeError): + """Exception raised when there's an error transforming agent messages.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(message) + + +class AgentModelError(AgentNodeError): + """Exception raised when there's an error with the model used by the agent.""" + + def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + self.model_name = model_name + self.provider = provider + 
super().__init__(message) + + +class AgentMemoryError(AgentNodeError): + """Exception raised when there's an error with the agent's memory.""" + + def __init__(self, message: str, conversation_id: Optional[str] = None): + self.conversation_id = conversation_id + super().__init__(message) + + +class AgentVariableTypeError(AgentNodeError): + """Exception raised when a variable has an unexpected type.""" + + def __init__( + self, + message: str, + variable_name: Optional[str] = None, + expected_type: Optional[str] = None, + actual_type: Optional[str] = None, + ): + self.variable_name = variable_name + self.expected_type = expected_type + self.actual_type = actual_type + super().__init__(message) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 38c2bcbdf5..84bbabca73 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult @@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import ( VarGenerateRouteChunk, ) from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser -class AnswerNode(BaseNode[AnswerNodeData]): - _node_data_cls = AnswerNodeData +class AnswerNode(BaseNode): _node_type = NodeType.ANSWER + _node_data: AnswerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = AnswerNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]): :return: """ # generate routes - generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data) answer = "" files = [] @@ -60,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: AnswerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) + # Create typed NodeData from dict + typed_node_data = AnswerNodeData.model_validate(node_data) + + variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git 
a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index d853eb71be..dcfed5eed2 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -122,13 +122,13 @@ class RetryConfig(BaseModel): class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + version: str = "1" error_strategy: Optional[ErrorStrategy] = None default_value: Optional[list[DefaultValue]] = None - version: str = "1" retry_config: RetryConfig = RetryConfig() @property - def default_value_dict(self): + def default_value_dict(self) -> dict[str, Any]: if self.default_value: return {item.key: item.value for item in self.default_value} return {} diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 6973401429..fb5ec55453 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,28 +1,22 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from .entities import BaseNodeData - if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.graph_engine.entities.event import InNodeEvent - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState logger = logging.getLogger(__name__) -GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) - -class BaseNode(Generic[GenericNodeData]): - _node_data_cls: type[GenericNodeData] +class BaseNode: _node_type: ClassVar[NodeType] def __init__( @@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]): self.node_id = node_id - node_data = self._node_data_cls.model_validate(config.get("data", {})) - self.node_data = node_data + @abstractmethod + def init_node_data(self, data: Mapping[str, Any]) -> None: ... 
@abstractmethod def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: @@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]): if not node_id: raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - node_data = cls._node_data_cls(**config.get("data", {})) + # Pass raw dict data instead of creating NodeData instance data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) + graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) ) return data @@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: GenericNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ return {} @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ return {} @property - def node_type(self) -> NodeType: - """ - Get node type - :return: - """ + def type_(self) -> NodeType: return self._node_type @classmethod @@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @property - def should_continue_on_error(self) -> bool: - """judge if should continue on error - - Returns: - bool: if should continue on error - """ - return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE + def continue_on_error(self) -> bool: + return False @property - def should_retry(self) -> bool: - """judge if should retry + def retry(self) -> bool: + return False - Returns: - bool: if should retry - """ - return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE + # Abstract methods that subclasses must implement to provide access + # to BaseNodeData properties in a type-safe way + + @abstractmethod + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + ... + + @abstractmethod + def _get_retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + ... + + @abstractmethod + def _get_title(self) -> str: + """Get the node title.""" + ... + + @abstractmethod + def _get_description(self) -> Optional[str]: + """Get the node description.""" + ... + + @abstractmethod + def _get_default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + ... + + @abstractmethod + def get_base_node_data(self) -> BaseNodeData: + """Get the BaseNodeData object for this node.""" + ... 
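The abstract accessors above, together with the public delegating properties that follow, define the contract every concrete node now implements: validate typed node data in `init_node_data`, then expose it through the `_get_*` hooks. A minimal hypothetical subclass following the same pattern this diff applies to AnswerNode, CodeNode, EndNode, and the rest (`ExampleNodeData` and `ExampleNode` are illustrative names, `_node_type` is set to an existing NodeType purely for the sketch, and `_run` is stubbed):

```python
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType


class ExampleNodeData(BaseNodeData):
    greeting: str = "hello"


class ExampleNode(BaseNode):
    _node_type = NodeType.CODE  # illustrative only
    _node_data: ExampleNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # Validate the raw config dict into the node's typed data model.
        self._node_data = ExampleNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def _run(self):
        ...  # a real node returns a NodeRunResult or yields node events

    @classmethod
    def version(cls) -> str:
        return "1"
```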
+ + # Public interface properties that delegate to abstract methods + @property + def error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + return self._get_error_strategy() + + @property + def retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + return self._get_retry_config() + + @property + def title(self) -> str: + """Get the node title.""" + return self._get_title() + + @property + def description(self) -> Optional[str]: + """Get the node description.""" + return self._get_description() + + @property + def default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + return self._get_default_value_dict() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 22ed9e2651..fdf3932827 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,5 @@ from collections.abc import Mapping, Sequence +from decimal import Decimal from typing import Any, Optional from configs import dify_config @@ -10,8 +11,9 @@ from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.code.entities import CodeNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .exc import ( CodeNodeError, @@ -20,10 +22,32 @@ from .exc import ( ) -class CodeNode(BaseNode[CodeNodeData]): - _node_data_cls = CodeNodeData +class CodeNode(BaseNode): _node_type = NodeType.CODE + _node_data: CodeNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = CodeNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -46,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]): def _run(self) -> NodeRunResult: # Get code language - code_language = self.node_data.code_language - code = self.node_data.code + code_language = self._node_data.code_language + code = self._node_data.code # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if isinstance(variable, ArrayFileSegment): @@ -67,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]): ) # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self._node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), 
error_type=type(e).__name__ @@ -114,8 +138,10 @@ class CodeNode(BaseNode[CodeNodeData]): ) if isinstance(value, float): + decimal_value = Decimal(str(value)).normalize() + precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] # raise error if precision is too high - if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + if precision > dify_config.CODE_MAX_PRECISION: raise OutputValidationError( f"Output variable `{variable}` has too high precision," f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." @@ -331,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: CodeNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = CodeNodeData.model_validate(node_data) + return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 8e6150f9cc..ab5964ebd4 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast import chardet import docx @@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) -class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): +class DocumentExtractorNode(BaseNode): """ Extracts text content from various file types. Supports plain text, PDF, and DOC/DOCX files. 
""" - _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + _node_data: DocumentExtractorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = DocumentExtractorNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" def _run(self): - variable_selector = self.node_data.variable_selector + variable_selector = self._node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) if variable is None: @@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: DocumentExtractorNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - return {node_id + ".files": node_data.variable_selector} + # Create typed NodeData from dict + typed_node_data = DocumentExtractorNodeData.model_validate(node_data) + + return {node_id + ".files": typed_node_data.variable_selector} def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 17a0b3adeb..f86f2e8129 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,14 +1,40 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType -class EndNode(BaseNode[EndNodeData]): - _node_data_cls = EndNodeData +class EndNode(BaseNode): _node_type = NodeType.END + _node_data: EndNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = EndNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]): Run node :return: """ - output_variables = self.node_data.outputs + output_variables = self._node_data.outputs outputs = {} for variable_selector in output_variables: diff --git 
a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 73b43eeaf7..7cf9ab9107 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum): class FailBranchSourceHandle(StrEnum): FAILED = "fail-branch" SUCCESS = "success-branch" - - -CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] -RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 971e0f73e7..6799d5c63c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser from factories import file_factory @@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) -class HttpRequestNode(BaseNode[HttpRequestNodeData]): - _node_data_cls = HttpRequestNodeData +class HttpRequestNode(BaseNode): _node_type = NodeType.HTTP_REQUEST + _node_data: HttpRequestNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = HttpRequestNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { @@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): process_data = {} try: http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), + node_data=self._node_data, + timeout=self._get_request_timeout(self._node_data), variable_pool=self.graph_runtime_state.variable_pool, max_retries=0, ) @@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.should_continue_on_error or self.should_retry): + if not response.response.is_success and (self.continue_on_error or self.retry): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, outputs={ @@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: HttpRequestNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = HttpRequestNodeData.model_validate(node_data) + selectors: list[VariableSelector] = [] - selectors 
+= variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) + if typed_node_data.body: + body_type = typed_node_data.body.type + data = typed_node_data.body.data match body_type: case "binary": if len(data) != 1: @@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): files.append(file) return ArrayFileSegment(value=files) + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 22b748030c..86e703dc68 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, Optional from typing_extensions import deprecated @@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor -class IfElseNode(BaseNode[IfElseNodeData]): - _node_data_cls = IfElseNodeData +class IfElseNode(BaseNode): _node_type = NodeType.IF_ELSE + _node_data: IfElseNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IfElseNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]): condition_processor = ConditionProcessor() try: # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: + if self._node_data.cases: + for case in self._node_data.cases: input_conditions, group_result, final_result = condition_processor.process_conditions( variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions, @@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]): 
input_conditions, group_result, final_result = _should_not_use_old_function( condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", + conditions=self._node_data.conditions or [], + operator=self._node_data.logical_operator or "and", ) selected_case_id = "true" if final_result else "false" @@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IfElseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = IfElseNodeData.model_validate(node_data) + var_mapping: dict[str, list[str]] = {} - for case in node_data.cases or []: + for case in typed_node_data.cases or []: for condition in case.conditions: key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) var_mapping[key] = condition.variable_selector diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 151efc28ec..5842c8d64b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,5 +1,6 @@ import contextvars import logging +import time import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait @@ -35,7 +36,8 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from factories.variable_factory import build_segment @@ -55,14 +57,36 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class IterationNode(BaseNode[IterationNodeData]): +class IterationNode(BaseNode): """ Iteration Node. """ - _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + _node_data: IterationNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IterationNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: return { @@ -82,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]): """ Run the node. 
""" - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") + raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") @@ -115,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]): graph_config = self.graph_config - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") - root_node_id = self.node_data.start_node_id + root_node_id = self._node_data.start_node_id # init graph iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) @@ -133,8 +157,11 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -146,7 +173,7 @@ class IterationNode(BaseNode[IterationNodeData]): call_depth=self.workflow_call_depth, graph=iteration_graph, graph_config=graph_config, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=self.thread_pool_id, @@ -157,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunStartedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"iterator_length": len(iterator_list_value)}, @@ -168,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=0, pre_iteration_output=None, duration=None, @@ -177,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]): iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: - if self.node_data.is_parallel: + if self._node_data.is_parallel: futures: list[Future] = [] q: Queue = Queue() thread_pool = GraphEngineThreadPool( - max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT + max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT ) for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( @@ -238,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph=iteration_graph, iter_run_map=iter_run_map, ) - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + if self._node_data.error_handle_mode == 
ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] # Flatten the list of lists @@ -249,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -274,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -301,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IterationNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = IterationNodeData.model_validate(node_data) + variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, + f"{node_id}.input_selector": typed_node_data.iterator_selector, } # init graph - iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not iteration_graph: raise IterationGraphNotFoundError("iteration graph not found") @@ -371,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]): """ if not isinstance(event, BaseNodeEvent): return event - if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent): event.parallel_mode_run_id = parallel_mode_run_id iter_metadata = { @@ -434,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]): elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -452,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -474,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]): event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id ) if isinstance(event, NodeRunFailedEvent): - if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -487,15 +510,15 @@ class 
IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -508,30 +531,64 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": None}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, + elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), ) + outputs[current_index] = None + + # clean nodes resources + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # iteration run failed + if self._node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + + # stop the iterator + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return yield metadata_event - current_output_segment = variable_pool.get(self.node_data.output_selector) + current_output_segment = variable_pool.get(self._node_data.output_selector) if current_output_segment is None: raise IterationNodeError("iteration output selector not found") current_iteration_output = current_output_segment.value @@ -550,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=current_iteration_output or None, @@ -563,8 +620,8 @@ class 
IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": None}, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 9900aa225d..b82c29291a 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.iteration.entities import IterationStartNodeData -class IterationStartNode(BaseNode[IterationStartNodeData]): +class IterationStartNode(BaseNode): """ Iteration Start Node. """ - _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + _node_data: IterationStartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IterationStartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 19bdee4fe2..f1767bdf9e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,10 +1,10 @@ from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Literal, Optional from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import VisionConfig +from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): @@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel): weights: Optional[WeightedScoreConfig] = None -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} - - class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. 
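Illustrative sketch (not part of this diff): the same mapping-extraction change recurs in every node touched here, so its general shape is summarized once. _extract_variable_selector_to_variable_mapping now receives the raw "data" mapping from the graph config and validates it into the node's typed model itself, rather than being handed a pre-built GenericNodeData instance. The sketch below is self-contained and assumes nothing beyond pydantic; ExampleNodeData, ExampleNode and the "query" key are hypothetical stand-ins for the real node classes.

from collections.abc import Mapping, Sequence
from typing import Any

from pydantic import BaseModel


class ExampleNodeData(BaseModel):
    # Hypothetical node data with a single selector field, analogous to
    # KnowledgeRetrievalNodeData.query_variable_selector in the next file.
    query_variable_selector: Sequence[str]


class ExampleNode:
    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        # New pattern: build the typed NodeData from the raw dict here.
        typed_node_data = ExampleNodeData.model_validate(node_data)
        return {f"{node_id}.query": typed_node_data.query_variable_selector}


# Callers (see the BaseNode change near the top of this section) now pass
# config.get("data", {}) straight through instead of instantiating the model:
mapping = ExampleNode._extract_variable_selector_to_variable_mapping(
    graph_config={},
    node_id="knowledge_retrieval",
    node_data={"query_variable_selector": ["sys", "query"]},
)
print(mapping)  # {'knowledge_retrieval.query': ['sys', 'query']}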
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b34d62d669..34b0afc75d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast @@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.entities.message_entities import ( + PromptMessageRole, +) +from core.model_runtime.entities.model_entities import ( + ModelFeature, + ModelType, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.variables import StringSegment +from core.variables import ( + StringSegment, +) from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, +) from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_ASSISTANT_PROMPT_1, METADATA_FILTER_ASSISTANT_PROMPT_2, @@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from services.feature_service import FeatureService -from .entities import KnowledgeRetrievalNodeData, ModelConfig +from .entities import KnowledgeRetrievalNodeData from .exc import ( InvalidModelTypeError, KnowledgeRetrievalNodeError, @@ -56,6 +68,10 @@ from .exc import ( 
ModelQuotaExceededError, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + logger = logging.getLogger(__name__) default_retrieval_model = { @@ -67,18 +83,76 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(LLMNode): - _node_data_cls = KnowledgeRetrievalNodeData # type: ignore +class KnowledgeRetrievalNode(BaseNode): _node_type = NodeType.KNOWLEDGE_RETRIEVAL + _node_data: KnowledgeRetrievalNodeData + + # Instance attributes specific to LLMNode. + # Output variable for file + _file_outputs: list["File"] + + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = KnowledgeRetrievalNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls): return "1" def _run(self) -> NodeRunResult: # type: ignore - node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) if not isinstance(variable, StringSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: - results = self._fetch_dataset_retriever(node_data=node_data, query=query) + results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -144,6 +218,8 @@ class KnowledgeRetrievalNode(LLMNode): error=str(e), error_type=type(e).__name__, ) + finally: + db.session.close() def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] @@ -152,7 +228,7 @@ class KnowledgeRetrievalNode(LLMNode): # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == 
False, @@ -166,11 +242,14 @@ class KnowledgeRetrievalNode(LLMNode): results = ( db.session.query(Dataset) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) - .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) - .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) + .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) .all() ) + # avoid blocking at retrieval + db.session.close() + for dataset in results: # pass if dataset is not available if not dataset: @@ -291,7 +370,7 @@ class KnowledgeRetrievalNode(LLMNode): dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore document = ( db.session.query(Document) - .filter( + .where( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, @@ -336,7 +415,7 @@ class KnowledgeRetrievalNode(LLMNode): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: - document_query = db.session.query(Document).filter( + document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", Document.enabled == True, @@ -383,7 +462,7 @@ class KnowledgeRetrievalNode(LLMNode): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value ).value[0] - if expected_value.value_type == "number": # type: ignore + if expected_value.value_type in {"number", "integer", "float"}: # type: ignore expected_value = expected_value.value # type: ignore elif expected_value.value_type == "string": # type: ignore expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore @@ -414,9 +493,9 @@ class KnowledgeRetrievalNode(LLMNode): node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "and" ): # type: ignore - document_query = document_query.filter(and_(*filters)) + document_query = document_query.where(and_(*filters)) else: - document_query = document_query.filter(or_(*filters)) + document_query = document_query.where(or_(*filters)) documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore @@ -428,22 +507,19 @@ class KnowledgeRetrievalNode(LLMNode): self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> list[dict[str, Any]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] - # get metadata model config - metadata_model_config = node_data.metadata_model_config - if metadata_model_config is None: + if node_data.metadata_model_config is None: raise ValueError("metadata_model_config is required") - # get metadata model instance - # fetch model config - model_instance, model_config = self.get_model_config(metadata_model_config) + # get metadata model instance and fetch model config + model_instance, model_config = self.get_model_config(node_data.metadata_model_config) # fetch prompt messages prompt_template = self._get_prompt_template( node_data=node_data, 
metadata_fields=all_metadata_fields, query=query or "", ) - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query=query, memory=None, @@ -453,16 +529,23 @@ class KnowledgeRetrievalNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=self.graph_runtime_state.variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" try: # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.metadata_model_config, # type: ignore + generator = LLMNode.invoke_llm( + node_data_model=node_data.metadata_model_config, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=None, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: @@ -552,17 +635,13 @@ class KnowledgeRetrievalNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: KnowledgeRetrievalNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) + variable_mapping = {} - variable_mapping[node_id + ".query"] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: @@ -624,7 +703,7 @@ class KnowledgeRetrievalNode(LLMNode): ) def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): - model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore + model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore input_text = query prompt_messages: list[LLMNodeChatModelMessage] = [] diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3c9ba44cf1..b91fc622f6 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ -from collections.abc import Callable, Sequence -from typing import Any, Literal, Union +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Literal, Optional, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import ListOperatorNodeData from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError -class ListOperatorNode(BaseNode[ListOperatorNodeData]): - _node_data_cls = ListOperatorNodeData +class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR + 
_node_data: ListOperatorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ListOperatorNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): process_data: dict[str, list] = {} outputs: dict[str, Any] = {} - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" + error_message = f"Variable not found for selector: {self._node_data.variable}" return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) @@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( - f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" ) return NodeRunResult( @@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): try: # Filter - if self.node_data.filter_by.enabled: + if self._node_data.filter_by.enabled: variable = self._apply_filter(variable) # Extract - if self.node_data.extract_by.enabled: + if self._node_data.extract_by.enabled: variable = self._extract_slice(variable) # Order - if self.node_data.order_by.enabled: + if self._node_data.order_by.enabled: variable = self._apply_order(variable) # Slice - if self.node_data.limit.enabled: + if self._node_data.limit.enabled: variable = self._apply_slice(variable) outputs = { @@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: filter_func: Callable[[Any], bool] result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: + for condition in self._node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") @@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + result = _order_string(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + result = _order_number(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": 
result}) elif isinstance(variable, ArrayFileSegment): result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value ) variable = variable.model_copy(update={"value": result}) return variable @@ -152,20 +175,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): def _apply_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - result = variable.value[: self.node_data.limit.size] + result = variable.value[: self._node_data.limit.size] return variable.model_copy(update={"value": result}) def _extract_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) + value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) if value < 1: raise ValueError(f"Invalid serial index: must be >= 1, got {value}") + if value > len(variable.value): + raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}") value -= 1 - if len(variable.value) > int(value): - result = variable.value[value] - else: - result = "" + result = variable.value[value] return variable.model_copy(update={"value": [result]}) diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 36d0688807..4bb62d35a2 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Optional from pydantic import BaseModel, Field, field_validator @@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: dict | None = None + structured_output: Mapping[str, Any] | None = None # We used 'structured_output_enabled' in the past, but it's not a good name. 
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b5225ce548..90a0397b67 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ( ModelInvokeCompletedEvent, NodeEvent, @@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState logger = logging.getLogger(__name__) -class LLMNode(BaseNode[LLMNodeData]): - _node_data_cls = LLMNodeData +class LLMNode(BaseNode): _node_type = NodeType.LLM + _node_data: LLMNodeData + # Instance attributes specific to LLMNode. # Output variable for file _file_outputs: list["File"] @@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LLMNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]): try: # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) + self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) + inputs = self._fetch_inputs(node_data=self._node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) + jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data) # merge inputs inputs.update(jinja_inputs) @@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]): files = ( llm_utils.fetch_files( variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, + selector=self._node_data.vision.configs.variable_selector, ) - if self.node_data.vision.enabled + if self._node_data.vision.enabled else [] ) @@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]): node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data=self.node_data) + generator = self._fetch_context(node_data=self._node_data) 
context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -189,53 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]): node_inputs["#context#"] = context # fetch model config - model_instance, model_config = self._fetch_model_config(self.node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=self._node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, app_id=self.app_id, - node_data_memory=self.node_data.memory, + node_data_memory=self._node_data.memory, model_instance=model_instance, ) query = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template + if self._node_data.memory: + query = self._node_data.memory.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): query = query_variable.text - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, context=context, memory=memory, model_config=model_config, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, + prompt_template=self._node_data.prompt_template, + memory_config=self._node_data.memory, + vision_enabled=self._node_data.vision.enabled, + vision_detail=self._node_data.vision.configs.detail, variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, + jinja2_variables=self._node_data.prompt_config.jinja2_variables, + tenant_id=self.tenant_id, ) - process_data = { - "model_mode": model_config.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages - ), - "model_provider": model_config.provider, - "model_name": model_config.model, - } - # handle invoke result - generator = self._invoke_llm( - node_data_model=self.node_data.model, + generator = LLMNode.invoke_llm( + node_data_model=self._node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) structured_output: LLMStructuredOutput | None = None @@ -253,6 +275,17 @@ class LLMNode(BaseNode[LLMNodeData]): elif isinstance(event, LLMStructuredOutput): structured_output = event + process_data = { + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages + ), + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, + "model_provider": model_config.provider, + "model_name": model_config.model, + } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} if structured_output: outputs["structured_output"] = structured_output.structured_output @@ -294,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) - def _invoke_llm( - self, + @staticmethod + def invoke_llm( + *, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, + user_id: str, + structured_output_enabled: bool, + 
structured_output: Optional[Mapping[str, Any]] = None, + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: model_schema = model_instance.model_type_instance.get_model_schema( node_data_model.name, model_instance.credentials @@ -307,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") - if self.node_data.structured_output_enabled: - output_schema = self._fetch_structured_output_schema() + if structured_output_enabled: + output_schema = LLMNode.fetch_structured_output_schema( + structured_output=structured_output or {}, + ) invoke_result = invoke_llm_with_structured_output( provider=model_instance.provider, model_schema=model_schema, @@ -318,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) else: invoke_result = model_instance.invoke_llm( @@ -326,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) - return self._handle_invoke_result(invoke_result=invoke_result) + return LLMNode.handle_invoke_result( + invoke_result=invoke_result, + file_saver=file_saver, + file_outputs=file_outputs, + node_id=node_id, + ) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] + @staticmethod + def handle_invoke_result( + *, + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: # For blocking mode if isinstance(invoke_result, LLMResult): - event = self._handle_blocking_result(invoke_result=invoke_result) + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=file_saver, + file_outputs=file_outputs, + ) yield event return @@ -354,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]): yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=contents, + file_saver=file_saver, + file_outputs=file_outputs, + ): full_text_buffer.write(text_part) - yield RunStreamChunkEvent( - chunk_content=text_part, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"]) # Update the whole metadata if not model and result.model: @@ -376,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]): yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) - def _image_file_to_markdown(self, file: "File", /): + @staticmethod + def _image_file_to_markdown(file: "File", /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -506,7 +565,7 @@ class LLMNode(BaseNode[LLMNodeData]): retriever_resources=original_retriever_resource, context=context_str.strip() ) - def _convert_to_original_retriever_resource(self, context_dict: dict): + def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: if ( "metadata" in context_dict and "_source" in context_dict["metadata"] 
@@ -537,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]): return None + @staticmethod def _fetch_model_config( - self, node_data_model: ModelConfig + *, + node_data_model: ModelConfig, + tenant_id: str, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model, model_config_with_cred = llm_utils.fetch_model_config( - tenant_id=self.tenant_id, node_data_model=node_data_model + tenant_id=tenant_id, node_data_model=node_data_model ) completion_params = model_config_with_cred.parameters @@ -554,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]): node_data_model.completion_params = completion_params return model, model_config_with_cred - def _fetch_prompt_messages( - self, + @staticmethod + def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence["File"], @@ -568,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]): vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], + tenant_id: str, ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -600,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]): edition_type="basic", ) prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -729,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) model = ModelManager().get_model_instance( - tenant_id=self.tenant_id, + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model, @@ -748,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LLMNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - prompt_template = node_data.prompt_template + # Create typed NodeData from dict + typed_node_data = LLMNodeData.model_validate(node_data) + prompt_template = typed_node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list) and all( isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template @@ -771,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - memory = node_data.memory + memory = typed_node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template @@ -779,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector + if typed_node_data.context.enabled: + variable_mapping["#context#"] = typed_node_data.context.variable_selector - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector + if typed_node_data.vision.enabled: + variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector - if node_data.memory: + if typed_node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] - if node_data.prompt_config: + if typed_node_data.prompt_config: enable_jinja = False if 
isinstance(prompt_template, list): @@ -801,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]): enable_jinja = True if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: + for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} @@ -833,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]): }, } - def _handle_list_messages( - self, + @staticmethod + def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str], @@ -847,7 +912,7 @@ class LLMNode(BaseNode[LLMNodeData]): if message.edition_type == "jinja2": result_text = _render_jinja2_message( template=message.jinja2_text or "", - jinjia2_variables=jinja2_variables, + jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) prompt_message = _combine_message_content_with_role( @@ -895,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages - def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + @staticmethod + def handle_blocking_result( + *, + invoke_result: LLMResult, + saver: LLMFileSaver, + file_outputs: list["File"], + ) -> ModelInvokeCompletedEvent: buffer = io.StringIO() - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=invoke_result.message.content, + file_saver=saver, + file_outputs=file_outputs, + ): buffer.write(text_part) return ModelInvokeCompletedEvent( @@ -906,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]): finish_reason=None, ) - def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": + @staticmethod + def save_multimodal_image_output( + *, + content: ImagePromptMessageContent, + file_saver: LLMFileSaver, + ) -> "File": """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -916,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]): Currently, only image files are supported. """ - # Inject the saver somehow... - _saver = self._llm_file_saver - - # If this if content.url != "": - saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) + saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) else: - saved_file = _saver.save_binary_string( + saved_file = file_saver.save_binary_string( data=base64.b64decode(content.base64_data), mime_type=content.mime_type, file_type=FileType.IMAGE, ) - self._file_outputs.append(saved_file) return saved_file def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: """ Fetch model schema """ - model_name = self.node_data.model.name + model_name = self._node_data.model.name model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name @@ -946,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]): model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_schema - def _fetch_structured_output_schema(self) -> dict[str, Any]: + @staticmethod + def fetch_structured_output_schema( + *, + structured_output: Mapping[str, Any], + ) -> dict[str, Any]: """ Fetch the structured output schema from the node data. 
Returns: dict[str, Any]: The structured output schema """ - if not self.node_data.structured_output: + if not structured_output: raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) if not structured_output_schema: raise LLMNodeError("Please provide a valid structured output schema") @@ -967,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") + @staticmethod def _save_multimodal_output_and_convert_result_to_markdown( - self, + *, contents: str | list[PromptMessageContentUnionTypes] | None, + file_saver: LLMFileSaver, + file_outputs: list["File"], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. @@ -992,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(item, TextPromptMessageContent): yield item.data elif isinstance(item, ImagePromptMessageContent): - file = self._save_multimodal_image_output(item) - self._file_outputs.append(file) - yield self._image_file_to_markdown(file) + file = LLMNode.save_multimodal_image_output( + content=item, + file_saver=file_saver, + ) + file_outputs.append(file) + yield LLMNode._image_file_to_markdown(file) else: logger.warning("unknown item type encountered, type=%s", type(item)) yield str(item) @@ -1002,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]): logger.warning("unknown contents type encountered, type=%s", type(contents)) yield str(contents) + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled + def _combine_message_content_with_role( *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole @@ -1019,20 +1112,20 @@ def _combine_message_content_with_role( def _render_jinja2_message( *, template: str, - jinjia2_variables: Sequence[VariableSelector], + jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ): if not template: return "" - jinjia2_inputs = {} - for jinja2_variable in jinjia2_variables: + jinja2_inputs = {} + for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) - jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" code_execute_resp = CodeExecutor.execute_workflow_code_template( language=CodeLanguage.JINJA2, code=template, - inputs=jinjia2_inputs, + inputs=jinja2_inputs, ) result_text = code_execute_resp["result"] return result_text @@ -1128,7 +1221,7 @@ def _handle_completion_template( if template.edition_type == "jinja2": result_text = _render_jinja2_message( template=template.jinja2_text or "", - jinjia2_variables=jinja2_variables, + jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) else: diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3f4a5edab9..d04e0bfae1 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,11 +1,29 @@ from collections.abc import Mapping -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional 
-from pydantic import BaseModel, Field +from pydantic import AfterValidator, BaseModel, Field +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.utils.condition.entities import Condition +_VALID_VAR_TYPE = frozenset( + [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + ] +) + + +def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: + if seg_type not in _VALID_VAR_TYPE: + raise ValueError(f"unsupported loop variable type: {seg_type}") + return seg_type + class LoopVariableData(BaseModel): """ @@ -13,7 +31,7 @@ class LoopVariableData(BaseModel): """ label: str - var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] value: Optional[Any | list[str]] = None diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index b144021bab..53cadc5251 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopEndNodeData -class LoopEndNode(BaseNode[LoopEndNodeData]): +class LoopEndNode(BaseNode): """ Loop End Node.
""" - _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + _node_data: LoopEndNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopEndNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 368d662a75..655de9362f 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,19 +1,15 @@ import json import logging +import time from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast from configs import dify_config from core.variables import ( - ArrayNumberSegment, - ArrayObjectSegment, - ArrayStringSegment, IntegerSegment, - ObjectSegment, Segment, SegmentType, - StringSegment, ) from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -34,10 +30,12 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor +from factories.variable_factory import TypeMismatchError, build_segment_with_type if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -46,14 +44,36 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(BaseNode[LoopNodeData]): +class LoopNode(BaseNode): """ Loop Node. 
""" - _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + _node_data: LoopNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -61,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]): def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator + loop_count = self._node_data.loop_count + break_conditions = self._node_data.break_conditions + logical_operator = self._node_data.logical_operator inputs = {"loop_count": loop_count} - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise ValueError(f"field start_node_id in loop {self.node_id} not found") # Initialize graph - loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id) + loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -81,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]): # Initialize loop variables loop_variable_selectors = {} - if self.node_data.loop_variables: - for loop_variable in self.node_data.loop_variables: + if self._node_data.loop_variables: + for loop_variable in self._node_data.loop_variables: value_processor = { "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), "variable": lambda var=loop_variable: variable_pool.get(var.value), @@ -101,8 +121,11 @@ class LoopNode(BaseNode[LoopNodeData]): loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -114,7 +137,7 @@ class LoopNode(BaseNode[LoopNodeData]): call_depth=self.workflow_call_depth, graph=loop_graph, graph_config=self.graph_config, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=self.thread_pool_id, @@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunStartedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"loop_length": loop_count}, @@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunSucceededEvent( loop_id=self.id, loop_node_id=self.node_id, - 
loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, steps=loop_count, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, @@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]): WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, inputs=inputs, ) ) @@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=loop_count, @@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]): _outputs[loop_variable_key] = None _outputs["loop_round"] = current_index + 1 - self.node_data.outputs = _outputs + self._node_data.outputs = _outputs if check_break_result: return {"check_break_result": True} @@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunNextEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, index=next_index, - pre_loop_output=self.node_data.outputs, + pre_loop_output=self._node_data.outputs, ) return {"check_break_result": False} @@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LoopNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = LoopNodeData.model_validate(node_data) + variable_mapping = {} # init graph - loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) - for loop_variable in node_data.loop_variables or []: + for loop_variable in typed_node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping @@ -501,23 +520,21 @@ class LoopNode(BaseNode[LoopNodeData]): return variable_mapping @staticmethod - def 
_get_segment_for_constant(var_type: str, value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" - segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = { - "string": (StringSegment, SegmentType.STRING), - "number": (IntegerSegment, SegmentType.NUMBER), - "object": (ObjectSegment, SegmentType.OBJECT), - "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING), - "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER), - "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT), - } if var_type in ["array[string]", "array[number]", "array[object]"]: - if value: + if value and isinstance(value, str): value = json.loads(value) else: value = [] - segment_info = segment_mapping.get(var_type) - if not segment_info: - raise ValueError(f"Invalid variable type: {var_type}") - segment_class, value_type = segment_info - return segment_class(value=value, value_type=value_type) + try: + return build_segment_with_type(var_type, value) + except TypeMismatchError as type_exc: + # Attempt to parse the value as a JSON-encoded string, if applicable. + if not isinstance(value, str): + raise + try: + value = json.loads(value) + except ValueError: + raise type_exc + return build_segment_with_type(var_type, value) diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index f5e38b7516..29b45ea0c3 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopStartNodeData -class LoopStartNode(BaseNode[LoopStartNodeData]): +class LoopStartNode(BaseNode): """ Loop Start Node. """ - _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + _node_data: LoopStartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopStartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 67cc884f20..294b47670b 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -73,6 +73,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { }, NodeType.TOOL: { LATEST_VERSION: ToolNode, + # This is an issue that caused problems before. 
+ # Logically, we shouldn't use two different versions to point to the same class here, + # but in order to maintain compatibility with historical data, this approach has been retained. + "2": ToolNode, "1": ToolNode, }, NodeType.VARIABLE_AGGREGATOR: { @@ -122,6 +126,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { }, NodeType.AGENT: { LATEST_VERSION: AgentNode, + # This is an issue that caused problems before. + # Logically, we shouldn't use two different versions to point to the same class here, + # but in order to maintain compatibility with historical data, this approach has been retained. + "2": AgentNode, "1": AgentNode, }, } diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 8d6c2d0a5c..a23d284626 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -29,8 +29,9 @@ from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser from factories.variable_factory import build_segment_with_type @@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode): Parameter Extractor Node. """ - # FIXME: figure out why here is different from super class - _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR + _node_data: ParameterExtractorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ParameterExtractorNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + _model_instance: Optional[ModelInstance] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None @@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode): """ Run the node. 
""" - node_data = cast(ParameterExtractorNodeData, self.node_data) + node_data = cast(ParameterExtractorNodeData, self._node_data) variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" @@ -253,7 +275,12 @@ class ParameterExtractorNode(BaseNode): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, + outputs={ + "__is_success": 1 if not error else 0, + "__reason": error, + "__usage": jsonable_encoder(usage), + **result, + }, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, @@ -393,7 +420,7 @@ class ParameterExtractorNode(BaseNode): """ Generate prompt engineering prompt. """ - model_mode = ModelMode.value_of(data.model.mode) + model_mode = ModelMode(data.model.mode) if model_mode == ModelMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( @@ -689,7 +716,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ) -> list[ChatModelMessage]: - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -716,7 +743,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -822,19 +849,15 @@ class ParameterExtractorNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} + # Create typed NodeData from dict + typed_node_data = ParameterExtractorNodeData.model_validate(node_data) - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} + + if typed_node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index a518167cc6..15012fa48d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from 
core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( LLMNode, @@ -20,6 +23,7 @@ from core.workflow.nodes.llm import ( LLMNodeCompletionModelPromptTemplate, llm_utils, ) +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.utils.variable_template_parser import VariableTemplateParser from libs.json_in_md_parser import parse_and_check_json_markdown @@ -35,17 +39,77 @@ from .template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_3, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState -class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData # type: ignore + +class QuestionClassifierNode(BaseNode): _node_type = NodeType.QUESTION_CLASSIFIER + _node_data: QuestionClassifierNodeData + + _file_outputs: list["File"] + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. 
+ self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = QuestionClassifierNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls): return "1" def _run(self): - node_data = cast(QuestionClassifierNodeData, self.node_data) + node_data = cast(QuestionClassifierNodeData, self._node_data) variable_pool = self.graph_runtime_state.variable_pool # extract variables @@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode): query = variable.value if variable else None variables = {"query": query} # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode): # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, # two consecutive user prompts will be generated, causing model's error. # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. 
- prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query="", memory=memory, @@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" @@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode): try: # handle invoke result - generator = self._invoke_llm( + generator = LLMNode.invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=False, + structured_output=None, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: @@ -145,7 +219,11 @@ class QuestionClassifierNode(LLMNode): "model_provider": model_config.provider, "model_name": model_config.model, } - outputs = {"class_name": category_name, "class_id": category_id} + outputs = { + "class_name": category_name, + "class_id": category_id, + "usage": jsonable_encoder(usage), + } return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -179,23 +257,18 @@ class QuestionClassifierNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Any, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(QuestionClassifierNodeData, node_data) - variable_mapping = {"query": node_data.query_variable_selector} - variable_selectors = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) + # Create typed NodeData from dict + typed_node_data = QuestionClassifierNodeData.model_validate(node_data) + + variable_mapping = {"query": typed_node_data.query_variable_selector} + variable_selectors: list[VariableSelector] = [] + if typed_node_data.instruction: + variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) variable_mapping = {node_id + "." 
+ key: value for key, value in variable_mapping.items()} @@ -261,7 +334,7 @@ class QuestionClassifierNode(LLMNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 5ee9bc331f..9e401e76bb 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,22 +1,48 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.start.entities import StartNodeData -class StartNode(BaseNode[StartNodeData]): - _node_data_cls = StartNodeData +class StartNode(BaseNode): _node_type = NodeType.START + _node_data: StartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = StartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables + system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() # TODO: System variables should be directly accessible, no need for special handling # Set system variables as node outputs. 
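A recurring change across the node files in this patch is that node classes stop being generic over a `_node_data_cls` and instead validate a plain mapping into a typed Pydantic model through `init_node_data`, exposing it via small accessor hooks such as `_get_title` and `_get_description`. The following is a minimal, self-contained sketch of that pattern only; `ExampleNodeData`, `ExampleNode`, and `run_node` are illustrative names and omit the real `BaseNode` machinery, so this is not the actual Dify implementation.

```python
# Minimal sketch (assumed names) of the node-data pattern used throughout this patch:
# each node validates a raw mapping into a typed Pydantic model via init_node_data()
# and exposes it through small accessor hooks instead of a generic _node_data_cls.
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class ExampleNodeData(BaseModel):
    title: str
    desc: Optional[str] = None
    template: str = ""


class ExampleNode:
    _node_data: ExampleNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # model_validate() replaces the old `_node_data_cls(**data)` construction
        # and raises a pydantic ValidationError on malformed graph config.
        self._node_data = ExampleNodeData.model_validate(data)

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc


def run_node(config: Mapping[str, Any]) -> str:
    # Hypothetical driver showing the call order: construct, then init_node_data().
    node = ExampleNode()
    node.init_node_data(config)
    return node._get_title()


if __name__ == "__main__":
    print(run_node({"title": "Start", "template": "{{ arg1 }}"}))
```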
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index ba573074c3..1962c82db1 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): - _node_data_cls = TemplateTransformNodeData +class TemplateTransformNode(BaseNode): _node_type = NodeType.TEMPLATE_TRANSFORM + _node_data: TemplateTransformNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = TemplateTransformNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): def _run(self) -> NodeRunResult: # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variables[variable_name] = value.to_object() if value else None # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables + language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables ) except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = TemplateTransformNodeData.model_validate(node_data) + return { node_id + "." 
+ variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 21023d4ab7..f0a44d919b 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") @@ -41,6 +42,10 @@ class ToolNodeData(BaseNodeData, ToolEntity): def check_type(cls, value, validation_info: ValidationInfo): typ = value value = validation_info.data.get("value") + + if value is None: + return typ + if typ == "mixed" and not isinstance(value, str): raise ValueError("value must be a string") elif typ == "variable": @@ -54,3 +59,26 @@ class ToolNodeData(BaseNodeData, ToolEntity): return typ tool_parameters: dict[str, ToolInput] + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. + tool_node_version: str | None = None + + @field_validator("tool_parameters", mode="before") + @classmethod + def filter_none_tool_inputs(cls, value): + if not isinstance(value, dict): + return value + + return { + key: tool_input + for key, tool_input in value.items() + if tool_input is not None and cls._has_valid_value(tool_input) + } + + @staticmethod + def _has_valid_value(tool_input): + """Check if the value is valid""" + if isinstance(tool_input, dict): + return tool_input.get("value") is not None + return getattr(tool_input, "value", None) is not None diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index a4be02d863..f437ac841d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,12 +1,11 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod -from core.model_runtime.entities.llm_entities import LLMUsage from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -19,9 +18,9 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey -from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -37,14 +36,18 @@ from .exc import ( ) -class 
ToolNode(BaseNode[ToolNodeData]): +class ToolNode(BaseNode): """ Tool Node """ - _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + _node_data: ToolNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ToolNodeData.model_validate(data) + @classmethod def version(cls) -> str: return "1" @@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]): Run the tool node """ - node_data = cast(ToolNodeData, self.node_data) + node_data = cast(ToolNodeData, self._node_data) # fetch tool icon tool_info = { @@ -67,8 +70,15 @@ class ToolNode(BaseNode[ToolNodeData]): try: from core.tools.tool_manager import ToolManager + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from + self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool ) except ToolNodeError as e: yield RunCompletedEvent( @@ -87,15 +97,14 @@ class ToolNode(BaseNode[ToolNodeData]): parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, ) parameters_for_log = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, for_log=True, ) - # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) @@ -124,7 +133,14 @@ class ToolNode(BaseNode[ToolNodeData]): try: # convert tool messages - yield from self._transform_message(message_stream, tool_info, parameters_for_log) + yield from self._transform_message( + messages=message_stream, + tool_info=tool_info, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_id=self.node_id, + ) except (PluginDaemonClientSideError, ToolInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -191,6 +207,9 @@ class ToolNode(BaseNode[ToolNodeData]): messages: Generator[ToolInvokeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_id: str, ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] @@ -198,8 +217,8 @@ class ToolNode(BaseNode[ToolNodeData]): # transform message and handle file storage message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=None, ) @@ -207,9 +226,6 @@ class ToolNode(BaseNode[ToolNodeData]): files: list[File] = [] json: list[dict] = [] - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - llm_usage: LLMUsage | None = None variables: dict[str, Any] = {} for message in message_stream: @@ -242,7 +258,7 @@ class ToolNode(BaseNode[ToolNodeData]): } file = file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=tenant_id, ) files.append(file) 
elif message.type == ToolInvokeMessage.MessageType.BLOB: @@ -265,50 +281,49 @@ class ToolNode(BaseNode[ToolNodeData]): files.append( file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=tenant_id, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: assert isinstance(message.message, ToolInvokeMessage.TextMessage) text += message.message.text - yield RunStreamChunkEvent( - chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if self.node_type == NodeType.AGENT: - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(msg_metadata) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } - json.append(message.message.json_object) + # JSON message handling for tool node + if message.message.json_object is not None: + json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text - yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.VARIABLE: assert isinstance(message.message, ToolInvokeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") if variable_name not in variables: variables[variable_name] = "" variables[variable_name] += variable_value yield RunStreamChunkEvent( - chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + chunk_content=variable_value, from_variable_selector=[node_id, variable_name] ) else: variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None - assert isinstance(message.meta, File) + assert isinstance(message.meta, dict) + # Validate that meta contains a 'file' key + if "file" not in message.meta: + raise ToolNodeError("File message is missing 'file' key in meta") + + # Validate that the file is an instance of File + if not isinstance(message.meta["file"], File): + raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -317,7 +332,7 @@ class ToolNode(BaseNode[ToolNodeData]): dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) + plugins = manager.list_plugins(tenant_id) try: current_plugin = next( plugin @@ -327,59 +342,42 @@ class ToolNode(BaseNode[ToolNodeData]): icon = current_plugin.declaration.icon except 
StopIteration: pass + icon_dark = None try: builtin_tool = next( provider for provider in BuiltinToolManageService.list_builtin_tools( - self.user_id, - self.tenant_id, + user_id, + tenant_id, ) if provider.name == dict_metadata["provider"] ) icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark except StopIteration: pass dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - id=message.message.id, - node_execution_id=self.id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=self.node_id, - ) - # check if the agent log is already in the list - for log in agent_logs: - if log.id == agent_log.id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output: list[dict[str, Any]] = [] - yield agent_log + # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] + if json: + json_output.extend(json) + else: + json_output.append({"data": []}) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, metadata={ - **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, - llm_usage=llm_usage, ) ) @@ -389,7 +387,7 @@ class ToolNode(BaseNode[ToolNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ToolNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -398,9 +396,12 @@ class ToolNode(BaseNode[ToolNodeData]): :param node_data: node data :return: """ + # Create typed NodeData from dict + typed_node_data = ToolNodeData.model_validate(node_data) + result = {} - for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] + for parameter_name in typed_node_data.tool_parameters: + input = typed_node_data.tool_parameters[parameter_name] if input.type == "mixed": assert isinstance(input.value, str) selectors = VariableTemplateParser(input.value).extract_variable_selectors() @@ -414,3 +415,29 @@ class ToolNode(BaseNode[ToolNodeData]): result = {node_id + "." 
+ key: value for key, value in result.items()} return result + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 96bb3e793a..98127bbeb6 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,17 +1,41 @@ from collections.abc import Mapping +from typing import Any, Optional from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData -class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAggregatorNode(BaseNode): _node_type = NodeType.VARIABLE_AGGREGATOR + _node_data: VariableAssignerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: + if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled: + for selector in self._node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: outputs = {"output": variable} @@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): inputs = {".".join(selector[1:]): variable.to_object()} break else: - for group in self.node_data.advanced_settings.groups: + for group in self._node_data.advanced_settings.groups: for selector in group.variables: variable = self.graph_runtime_state.variable_pool.get(selector) diff --git 
a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index be5083c9c1..51383fa588 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory @@ -22,11 +23,33 @@ if TYPE_CHECKING: _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] -class VariableAssignerNode(BaseNode[VariableAssignerData]): - _node_data_cls = VariableAssignerData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + _node_data: VariableAssignerData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + def __init__( self, id: str, @@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + # Create typed NodeData from dict + typed_node_data = VariableAssignerData.model_validate(node_data) - selector_key = ".".join(node_data.input_variable_selector) + mapping = {} + assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] + if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: + selector_key = ".".join(typed_node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = typed_node_data.assigned_variable_selector + + selector_key = ".".join(typed_node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector + mapping[key] = typed_node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: - assigned_variable_selector = self.node_data.assigned_variable_selector + assigned_variable_selector = self._node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = 
self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") - match self.node_data.write_mode: + match self._node_data.write_mode: case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_value = original_variable.value + [income_value.value] @@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") + raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) @@ -130,6 +156,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): def get_zero_value(t: SegmentType): + # TODO(QuantumGhost): this should be a method of `SegmentType`. match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: return variable_factory.build_segment([]) @@ -137,6 +164,10 @@ def get_zero_value(t: SegmentType): return variable_factory.build_segment({}) case SegmentType.STRING: return variable_factory.build_segment("") + case SegmentType.INTEGER: + return variable_factory.build_segment(0) + case SegmentType.FLOAT: + return variable_factory.build_segment(0.0) case SegmentType.NUMBER: return variable_factory.build_segment(0) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 3797bfa77a..7f760e5baa 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -1,5 +1,6 @@ from core.variables import SegmentType +# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy. 
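# Editor's illustration (not part of the diff): with the new INTEGER and FLOAT branches added to
# `get_zero_value` in variable_assigner/v1/node.py above, the zero values line up roughly as:
#
#     get_zero_value(SegmentType.INTEGER).value  # -> 0
#     get_zero_value(SegmentType.FLOAT).value    # -> 0.0
#     get_zero_value(SegmentType.STRING).value   # -> ""
#
# assuming `variable_factory.build_segment` returns a segment exposing the wrapped value via
# `.value`, as it is used elsewhere in this diff.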
EMPTY_VALUE_MAPPING = { SegmentType.STRING: "", SegmentType.NUMBER: 0, diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index 8fb2a27388..7a20975b15 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation): case Operation.OVER_WRITE | Operation.CLEAR: return True case Operation.SET: - return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} + return variable_type in { + SegmentType.OBJECT, + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, + } case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: # Only number variable can be added, subtracted, multiplied or divided - return variable_type == SegmentType.NUMBER + return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} case Operation.APPEND | Operation.EXTEND: # Only array variable can be appended or extended return variable_type in { @@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat match variable_type: case SegmentType.STRING | SegmentType.OBJECT: return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return operation in { Operation.OVER_WRITE, Operation.SET, @@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va case SegmentType.STRING: return isinstance(value, str) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: if not isinstance(value, int | float): return False if operation == Operation.DIVIDE and value == 0: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 9292da6f1c..c0215cae71 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json -from collections.abc import Callable, Mapping, MutableMapping, Sequence -from typing import Any, TypeAlias, cast +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory @@ -28,8 +29,6 @@ from .exc import ( VariableNotFoundError, ) -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): 
selector_node_id = item.variable_selector[0] @@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ mapping[key] = selector -class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER + _node_data: VariableAssignerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: return conversation_variable_updater_factory() @@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = VariableAssignerNodeData.model_validate(node_data) + var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: + for item in typed_node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping def _run(self) -> NodeRunResult: - inputs = self.node_data.model_dump() + inputs = self._node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] try: - for item in self.node_data.items: + for item in self._node_data.items: variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) # ==================== Validation Part diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py index 5917310c8b..bcbd253392 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/core/workflow/repositories/workflow_execution_repository.py @@ -1,4 +1,4 @@ -from typing import Optional, Protocol +from typing import Protocol from core.workflow.entities.workflow_execution import WorkflowExecution @@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol): execution: The WorkflowExecution instance to save or update """ ... - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - ... diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index 1908a6b190..8bf81f5442 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol): """ ... 
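# Editor's sketch (not part of the diff): the hunk below trims the node-execution repository
# protocol, dropping `get_by_node_execution_id`, `get_running_executions` and `clear`; lookups of
# in-flight executions now go through the caches added to WorkflowCycleManager further down.
# A minimal in-memory test double under that assumption, keeping only the methods visible in this
# hunk (the optional ordering argument is hypothetical):
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution  # path as used elsewhere in this diff

class _InMemoryNodeExecutionRepository:
    def __init__(self) -> None:
        self._by_run: dict[str, list[WorkflowNodeExecution]] = {}

    def save(self, execution: WorkflowNodeExecution) -> None:
        # Group saved executions by their workflow run for later retrieval.
        if execution.workflow_execution_id:
            self._by_run.setdefault(execution.workflow_execution_id, []).append(execution)

    def get_by_workflow_run(self, workflow_run_id: str, order_config=None) -> list[WorkflowNodeExecution]:
        return self._by_run.get(workflow_run_id, [])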
- def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - ... - def get_by_workflow_run( self, workflow_run_id: str, @@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol): A list of NodeExecution instances """ ... - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - ... - - def clear(self) -> None: - """ - Clear all NodeExecution records based on implementation-specific criteria. - - This method is intended to be used for bulk deletion operations, such as removing - all records associated with a specific app_id and tenant_id in multi-tenant implementations. - """ - ... diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py new file mode 100644 index 0000000000..df90c16596 --- /dev/null +++ b/api/core/workflow/system_variable.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from core.file.models import File +from core.workflow.enums import SystemVariableKey + + +class SystemVariable(BaseModel): + """A model for managing system variables. + + Fields with a value of `None` are treated as absent and will not be included + in the variable pool. + """ + + model_config = ConfigDict( + extra="forbid", + serialize_by_alias=True, + validate_by_alias=True, + ) + + user_id: str | None = None + + # Ideally, `app_id` and `workflow_id` should be required and not `None`. + # However, there are scenarios in the codebase where these fields are not set. + # To maintain compatibility, they are marked as optional here. + app_id: str | None = None + workflow_id: str | None = None + + files: Sequence[File] = Field(default_factory=list) + + # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. + # To maintain compatibility with existing workflows, it must be serialized + # as `workflow_run_id` in dictionaries or JSON objects, and also referenced + # as `workflow_run_id` in the variable pool. + workflow_execution_id: str | None = Field( + validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), + serialization_alias="workflow_run_id", + default=None, + ) + # Chatflow related fields. + query: str | None = None + conversation_id: str | None = None + dialogue_count: int | None = None + + @model_validator(mode="before") + @classmethod + def validate_json_fields(cls, data): + if isinstance(data, dict): + # For JSON validation, only allow workflow_run_id + if "workflow_execution_id" in data and "workflow_run_id" not in data: + # This is likely from direct instantiation, allow it + return data + elif "workflow_execution_id" in data and "workflow_run_id" in data: + # Both present, remove workflow_execution_id + data = data.copy() + data.pop("workflow_execution_id") + return data + return data + + @classmethod + def empty(cls) -> "SystemVariable": + return cls() + + def to_dict(self) -> dict[SystemVariableKey, Any]: + # NOTE: This method is provided for compatibility with legacy code. 
+ # New code should use the `SystemVariable` object directly instead of converting + # it to a dictionary, as this conversion results in the loss of type information + # for each key, making static analysis more difficult. + + d: dict[SystemVariableKey, Any] = { + SystemVariableKey.FILES: self.files, + } + if self.user_id is not None: + d[SystemVariableKey.USER_ID] = self.user_id + if self.app_id is not None: + d[SystemVariableKey.APP_ID] = self.app_id + if self.workflow_id is not None: + d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id + if self.workflow_execution_id is not None: + d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id + if self.query is not None: + d[SystemVariableKey.QUERY] = self.query + if self.conversation_id is not None: + d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id + if self.dialogue_count is not None: + d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count + return d diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 0aab2426af..03f670707e 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from dataclasses import dataclass -from datetime import UTC, datetime +from datetime import datetime from typing import Any, Optional, Union from uuid import uuid4 @@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.enums import SystemVariableKey from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from libs.datetime_utils import naive_utc_now @@ -43,7 +44,7 @@ class WorkflowCycleManager: self, *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], - workflow_system_variables: dict[SystemVariableKey, Any], + workflow_system_variables: SystemVariable, workflow_info: CycleManagerWorkflowInfo, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, @@ -54,19 +55,15 @@ class WorkflowCycleManager: self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + # Initialize caches for workflow execution cycle + # These caches avoid redundant repository calls during a single workflow execution + self._workflow_execution_cache: dict[str, WorkflowExecution] = {} + self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} + def handle_workflow_run_start(self) -> WorkflowExecution: - inputs = {**self._application_generate_entity.inputs} - for key, value in (self._workflow_system_variables or {}).items(): - if key.value == "conversation": - continue - inputs[f"sys.{key.value}"] = value + inputs = self._prepare_workflow_inputs() + execution_id = self._get_or_generate_execution_id() - # handle special values - inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) - - # init workflow run - # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) execution = WorkflowExecution.new( id_=execution_id, 
workflow_id=self._workflow_info.workflow_id, @@ -74,12 +71,10 @@ class WorkflowCycleManager: workflow_version=self._workflow_info.version, graph=self._workflow_info.graph_data, inputs=inputs, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) - self._workflow_execution_repository.save(execution) - - return execution + return self._save_and_cache_workflow_execution(execution) def handle_workflow_run_success( self, @@ -90,26 +85,19 @@ class WorkflowCycleManager: outputs: Mapping[str, Any] | None = None, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, + external_trace_id: Optional[str] = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - # outputs = WorkflowEntry.handle_special_values(outputs) + self._update_workflow_execution_completion( + workflow_execution, + status=WorkflowExecutionStatus.SUCCEEDED, + outputs=outputs, + total_tokens=total_tokens, + total_steps=total_steps, + ) - workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED - workflow_execution.outputs = outputs or {} - workflow_execution.total_tokens = total_tokens - workflow_execution.total_steps = total_steps - workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id) self._workflow_execution_repository.save(workflow_execution) return workflow_execution @@ -124,26 +112,20 @@ class WorkflowCycleManager: exceptions_count: int = 0, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, + external_trace_id: Optional[str] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) - execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED - execution.outputs = outputs or {} - execution.total_tokens = total_tokens - execution.total_steps = total_steps - execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - execution.exceptions_count = exceptions_count + self._update_workflow_execution_completion( + execution, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + outputs=outputs, + total_tokens=total_tokens, + total_steps=total_steps, + exceptions_count=exceptions_count, + ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id) self._workflow_execution_repository.save(execution) return execution @@ -159,43 +141,23 @@ class WorkflowCycleManager: conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, exceptions_count: int = 0, + external_trace_id: Optional[str] = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) now = naive_utc_now() - workflow_execution.status = WorkflowExecutionStatus(status.value) - workflow_execution.error_message = error_message - workflow_execution.total_tokens = total_tokens - workflow_execution.total_steps = total_steps - 
workflow_execution.finished_at = now - workflow_execution.exceptions_count = exceptions_count - - # Use the instance repository to find running executions for a workflow run - running_node_executions = self._workflow_node_execution_repository.get_running_executions( - workflow_run_id=workflow_execution.id_ + self._update_workflow_execution_completion( + workflow_execution, + status=status, + total_tokens=total_tokens, + total_steps=total_steps, + error_message=error_message, + exceptions_count=exceptions_count, + finished_at=now, ) - # Update the domain models - for node_execution in running_node_executions: - if node_execution.node_execution_id: - # Update the domain model - node_execution.status = WorkflowNodeExecutionStatus.FAILED - node_execution.error = error_message - node_execution.finished_at = now - node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() - - # Update the repository with the domain model - self._workflow_node_execution_repository.save(node_execution) - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._fail_running_node_executions(workflow_execution.id_, error_message, now) + self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id) self._workflow_execution_repository.save(workflow_execution) return workflow_execution @@ -208,8 +170,200 @@ class WorkflowCycleManager: ) -> WorkflowNodeExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - # Create a domain model - created_at = datetime.now(UTC).replace(tzinfo=None) + domain_execution = self._create_node_execution_from_event( + workflow_execution=workflow_execution, + event=event, + status=WorkflowNodeExecutionStatus.RUNNING, + ) + + return self._save_and_cache_node_execution(domain_execution) + + def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: + domain_execution = self._get_node_execution_from_cache(event.node_execution_id) + + self._update_node_execution_completion( + domain_execution, + event=event, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + + self._workflow_node_execution_repository.save(domain_execution) + return domain_execution + + def handle_workflow_node_execution_failed( + self, + *, + event: QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeInLoopFailedEvent + | QueueNodeExceptionEvent, + ) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param event: queue node failed event + :return: + """ + domain_execution = self._get_node_execution_from_cache(event.node_execution_id) + + status = ( + WorkflowNodeExecutionStatus.EXCEPTION + if isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.FAILED + ) + + self._update_node_execution_completion( + domain_execution, + event=event, + status=status, + error=event.error, + handle_special_values=True, + ) + + self._workflow_node_execution_repository.save(domain_execution) + return domain_execution + + def handle_workflow_node_execution_retried( + self, *, workflow_execution_id: str, event: QueueNodeRetryEvent + ) -> WorkflowNodeExecution: + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) + + domain_execution = self._create_node_execution_from_event( + workflow_execution=workflow_execution, + event=event, + 
status=WorkflowNodeExecutionStatus.RETRY, + error=event.error, + created_at=event.start_at, + ) + + # Handle inputs and outputs + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = event.outputs + metadata = self._merge_event_metadata(event) + + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata) + + return self._save_and_cache_node_execution(domain_execution) + + def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: + # Check cache first + if id in self._workflow_execution_cache: + return self._workflow_execution_cache[id] + + raise WorkflowRunNotFoundError(id) + + def _prepare_workflow_inputs(self) -> dict[str, Any]: + """Prepare workflow inputs by merging application inputs with system variables.""" + inputs = {**self._application_generate_entity.inputs} + + if self._workflow_system_variables: + for field_name, value in self._workflow_system_variables.to_dict().items(): + if field_name != SystemVariableKey.CONVERSATION_ID: + inputs[f"sys.{field_name}"] = value + + return dict(WorkflowEntry.handle_special_values(inputs) or {}) + + def _get_or_generate_execution_id(self) -> str: + """Get execution ID from system variables or generate a new one.""" + if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id: + return str(self._workflow_system_variables.workflow_execution_id) + return str(uuid4()) + + def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution: + """Save workflow execution to repository and cache it.""" + self._workflow_execution_repository.save(execution) + self._workflow_execution_cache[execution.id_] = execution + return execution + + def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution: + """Save node execution to repository and cache it if it has an ID.""" + self._workflow_node_execution_repository.save(execution) + if execution.node_execution_id: + self._node_execution_cache[execution.node_execution_id] = execution + return execution + + def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution: + """Get node execution from cache or raise error if not found.""" + domain_execution = self._node_execution_cache.get(node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {node_execution_id}") + return domain_execution + + def _update_workflow_execution_completion( + self, + execution: WorkflowExecution, + *, + status: WorkflowExecutionStatus, + total_tokens: int, + total_steps: int, + outputs: Mapping[str, Any] | None = None, + error_message: Optional[str] = None, + exceptions_count: int = 0, + finished_at: Optional[datetime] = None, + ) -> None: + """Update workflow execution with completion data.""" + execution.status = status + execution.outputs = outputs or {} + execution.total_tokens = total_tokens + execution.total_steps = total_steps + execution.finished_at = finished_at or naive_utc_now() + execution.exceptions_count = exceptions_count + if error_message: + execution.error_message = error_message + + def _add_trace_task_if_needed( + self, + trace_manager: Optional[TraceQueueManager], + workflow_execution: WorkflowExecution, + conversation_id: Optional[str], + external_trace_id: Optional[str], + ) -> None: + """Add trace task if trace manager is provided.""" + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + 
workflow_execution=workflow_execution, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + external_trace_id=external_trace_id, + ) + ) + + def _fail_running_node_executions( + self, + workflow_execution_id: str, + error_message: str, + now: datetime, + ) -> None: + """Fail all running node executions for a workflow.""" + running_node_executions = [ + node_exec + for node_exec in self._node_execution_cache.values() + if node_exec.workflow_execution_id == workflow_execution_id + and node_exec.status == WorkflowNodeExecutionStatus.RUNNING + ] + + for node_execution in running_node_executions: + if node_execution.node_execution_id: + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = error_message + node_execution.finished_at = now + node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() + self._workflow_node_execution_repository.save(node_execution) + + def _create_node_execution_from_event( + self, + *, + workflow_execution: WorkflowExecution, + event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent], + status: WorkflowNodeExecutionStatus, + error: Optional[str] = None, + created_at: Optional[datetime] = None, + ) -> WorkflowNodeExecution: + """Create a node execution from an event.""" + now = naive_utc_now() + created_at = created_at or now + metadata = { WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, @@ -226,152 +380,76 @@ class WorkflowCycleManager: node_id=event.node_id, node_type=event.node_type, title=event.node_data.title, - status=WorkflowNodeExecutionStatus.RUNNING, + status=status, metadata=metadata, created_at=created_at, + error=error, ) - # Use the instance repository to save the domain model - self._workflow_node_execution_repository.save(domain_execution) + if status == WorkflowNodeExecutionStatus.RETRY: + domain_execution.finished_at = now + domain_execution.elapsed_time = (now - created_at).total_seconds() return domain_execution - def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {event.node_execution_id}") - - # Process data - inputs = event.inputs - process_data = event.process_data - outputs = event.outputs - - # Convert metadata keys to strings - execution_metadata_dict = {} - if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - event.start_at).total_seconds() - - # Update domain model - domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED - domain_execution.update_from_mapping( - inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict - ) - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = elapsed_time - - # Update the repository with the domain model - self._workflow_node_execution_repository.save(domain_execution) - - return domain_execution - - def handle_workflow_node_execution_failed( + def _update_node_execution_completion( self, + domain_execution: WorkflowNodeExecution, *, - event: QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | 
QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - ) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param event: queue node failed event - :return: - """ - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {event.node_execution_id}") - - # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = event.outputs - - # Convert metadata keys to strings - execution_metadata_dict = {} - if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - - finished_at = datetime.now(UTC).replace(tzinfo=None) + event: Union[ + QueueNodeSucceededEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ], + status: WorkflowNodeExecutionStatus, + error: Optional[str] = None, + handle_special_values: bool = False, + ) -> None: + """Update node execution with completion data.""" + finished_at = naive_utc_now() elapsed_time = (finished_at - event.start_at).total_seconds() + # Process data + if handle_special_values: + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) + else: + inputs = event.inputs + process_data = event.process_data + + outputs = event.outputs + + # Convert metadata + execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {} + if event.execution_metadata: + execution_metadata_dict.update(event.execution_metadata) + # Update domain model - domain_execution.status = ( - WorkflowNodeExecutionStatus.FAILED - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION - ) - domain_execution.error = event.error + domain_execution.status = status domain_execution.update_from_mapping( - inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + inputs=inputs, + process_data=process_data, + outputs=outputs, + metadata=execution_metadata_dict, ) domain_execution.finished_at = finished_at domain_execution.elapsed_time = elapsed_time - # Update the repository with the domain model - self._workflow_node_execution_repository.save(domain_execution) + if error: + domain_execution.error = error - return domain_execution - - def handle_workflow_node_execution_retried( - self, *, workflow_execution_id: str, event: QueueNodeRetryEvent - ) -> WorkflowNodeExecution: - workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - created_at = event.start_at - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - created_at).total_seconds() - inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = event.outputs - - # Convert metadata keys to strings + def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]: + """Merge event metadata with origin metadata.""" origin_metadata = { WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, } - # Convert execution metadata keys to strings execution_metadata_dict: 
dict[WorkflowNodeExecutionMetadataKey, str | None] = {} if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value + execution_metadata_dict.update(event.execution_metadata) - merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata - - # Create a domain model - domain_execution = WorkflowNodeExecution( - id=str(uuid4()), - workflow_id=workflow_execution.workflow_id, - workflow_execution_id=workflow_execution.id_, - predecessor_node_id=event.predecessor_node_id, - node_execution_id=event.node_execution_id, - node_id=event.node_id, - node_type=event.node_type, - title=event.node_data.title, - status=WorkflowNodeExecutionStatus.RETRY, - created_at=created_at, - finished_at=finished_at, - elapsed_time=elapsed_time, - error=event.error, - index=event.node_run_index, - ) - - # Update with mappings - domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) - - # Use the instance repository to save the domain model - self._workflow_node_execution_repository.save(domain_execution) - - return domain_execution - - def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: - execution = self._workflow_execution_repository.get(id) - if not execution: - raise WorkflowRunNotFoundError(id) - return execution + return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c0e98db3db..c8082ebf50 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.workflow.callbacks import WorkflowCallback @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom @@ -69,6 +70,7 @@ class WorkflowEntry: raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) # init workflow run state + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.graph_engine = GraphEngine( tenant_id=tenant_id, app_id=app_id, @@ -80,7 +82,7 @@ class WorkflowEntry: call_depth=call_depth, graph=graph, graph_config=graph_config, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=thread_pool_id, @@ -144,7 +146,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node_instance = node_cls( + node = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -161,6 +163,7 @@ class WorkflowEntry: graph=graph, 
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), ) + node.init_node_data(node_config_data) try: # variable selector to variable mapping @@ -188,17 +191,11 @@ class WorkflowEntry: try: # run node - generator = node_instance.run() + generator = node.run() except Exception as e: - logger.exception( - "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", - workflow.id, - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) - return node_instance, generator + logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) + return node, generator @classmethod def run_free_node( @@ -253,14 +250,14 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=[], ) node_cls = cast(type[BaseNode], node_cls) # init workflow run state - node_instance: BaseNode = node_cls( + node: BaseNode = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -277,6 +274,7 @@ class WorkflowEntry: graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), ) + node.init_node_data(node_data) try: # variable selector to variable mapping @@ -295,17 +293,12 @@ class WorkflowEntry: ) # run node - generator = node_instance.run() + generator = node.run() - return node_instance, generator + return node, generator except Exception as e: - logger.exception( - "error while running node_instance, node_id=%s, type=%s, version=%s", - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 0123fdac18..2c634d25ec 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -1,4 +1,3 @@ -import json from collections.abc import Mapping from typing import Any @@ -8,18 +7,6 @@ from core.file.models import File from core.variables import Segment -class WorkflowRuntimeTypeEncoder(json.JSONEncoder): - def default(self, o: Any): - if isinstance(o, Segment): - return o.value - elif isinstance(o, File): - return o.to_dict() - elif isinstance(o, BaseModel): - return o.model_dump(mode="json") - else: - return super().default(o) - - class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: result = self._to_json_encodable_recursive(value) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 18d4f4885d..4de9a25c2f 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -5,6 +5,11 @@ set -e if [[ "${MIGRATION_ENABLED}" == "true" ]]; then echo "Running migrations" flask upgrade-db + # Pure migration mode + if [[ "${MODE}" == "migration" ]]; then + echo "Migration completed, exiting normally" + exit 0 + fi fi if [[ "${MODE}" == "worker" ]]; then @@ -22,7 +27,7 @@ 
if [[ "${MODE}" == "worker" ]]; then exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion} + -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin} elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 8a677f6b6f..dc50ca8d96 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -1,4 +1,3 @@ -import datetime import logging import time @@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Document @@ -22,7 +22,7 @@ def handle(sender, **kwargs): document = ( db.session.query(Document) - .filter( + .where( Document.id == document_id, Document.dataset_id == dataset_id, ) @@ -33,7 +33,7 @@ def handle(sender, **kwargs): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 249bd14429..6c9fc0bf1d 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -20,6 +20,7 @@ def handle(sender, **kwargs): provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, + credential_id=tool_entity.credential_id, ) manager = ToolParameterConfigurationManager( tenant_id=app.tenant_id, diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 14396e9920..b8b5a89dc5 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -13,7 +13,7 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: @@ -27,7 +27,7 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).filter( + db.session.query(AppDatasetJoin).where( AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index dd2efed94b..cf4ba69833 100644 --- 
a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -15,7 +15,7 @@ def handle(sender, **kwargs): published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: @@ -29,7 +29,7 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).filter( + db.session.query(AppDatasetJoin).where( AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 316be12f5c..a4d013ffc0 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -10,6 +10,7 @@ def init_app(app: DifyApp): from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp from controllers.inner_api import bp as inner_api_bp + from controllers.mcp import bp as mcp_bp from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp @@ -46,3 +47,4 @@ def init_app(app: DifyApp): app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) + app.register_blueprint(mcp_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 6279b1ad36..2c2846ba26 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -64,49 +64,62 @@ def init_app(app: DifyApp) -> Celery: celery_app.set_default() app.extensions["celery"] = celery_app - imports = [ - "schedule.clean_embedding_cache_task", - "schedule.clean_unused_datasets_task", - "schedule.create_tidb_serverless_task", - "schedule.update_tidb_serverless_status_task", - "schedule.clean_messages", - "schedule.mail_clean_document_notify_task", - "schedule.queue_monitor_task", - ] + imports = [] day = dify_config.CELERY_BEAT_SCHEDULER_TIME - beat_schedule = { - "clean_embedding_cache_task": { + + # if you add a new task, please add the switch to CeleryScheduleTasksConfig + beat_schedule = {} + if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK: + imports.append("schedule.clean_embedding_cache_task") + beat_schedule["clean_embedding_cache_task"] = { "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", "schedule": timedelta(days=day), - }, - "clean_unused_datasets_task": { + } + if dify_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK: + imports.append("schedule.clean_unused_datasets_task") + beat_schedule["clean_unused_datasets_task"] = { "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", "schedule": timedelta(days=day), - }, - "create_tidb_serverless_task": { + } + if dify_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK: + imports.append("schedule.create_tidb_serverless_task") + beat_schedule["create_tidb_serverless_task"] = { "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task", "schedule": crontab(minute="0", hour="*"), - }, - "update_tidb_serverless_status_task": { + } + if dify_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: + imports.append("schedule.update_tidb_serverless_status_task") + beat_schedule["update_tidb_serverless_status_task"] = { "task": 
"schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", "schedule": timedelta(minutes=10), - }, - "clean_messages": { + } + if dify_config.ENABLE_CLEAN_MESSAGES: + imports.append("schedule.clean_messages") + beat_schedule["clean_messages"] = { "task": "schedule.clean_messages.clean_messages", "schedule": timedelta(days=day), - }, - # every Monday - "mail_clean_document_notify_task": { + } + if dify_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: + imports.append("schedule.mail_clean_document_notify_task") + beat_schedule["mail_clean_document_notify_task"] = { "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", "schedule": crontab(minute="0", hour="10", day_of_week="1"), - }, - "datasets-queue-monitor": { + } + if dify_config.ENABLE_DATASETS_QUEUE_MONITOR: + imports.append("schedule.queue_monitor_task") + beat_schedule["datasets-queue-monitor"] = { "task": "schedule.queue_monitor_task.queue_monitor_task", "schedule": timedelta( minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 ), - }, - } + } + if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: + imports.append("schedule.check_upgradable_plugin_task") + beat_schedule["check_upgradable_plugin_task"] = { + "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task", + "schedule": crontab(minute="*/15"), + } + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index ddc2158a02..600e336c19 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -18,6 +18,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + setup_system_tool_oauth_client, upgrade_db, vdb_migrate, ) @@ -40,6 +41,7 @@ def init_app(app: DifyApp): clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, remove_orphaned_files_on_storage, + setup_system_tool_oauth_client, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 3b4d787d01..9b18e25eaa 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -10,7 +10,7 @@ from dify_app import DifyApp from extensions.ext_database import db from libs.passport import PassportService from models.account import Account, Tenant, TenantAccountJoin -from models.model import EndUser +from models.model import AppMCPServer, EndUser from services.account_service import AccountService login_manager = flask_login.LoginManager() @@ -40,9 +40,9 @@ def load_user_from_request(request_from_flask_login): if workspace_id: tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == workspace_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role == "owner") + .where(Tenant.id == workspace_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role == "owner") .one_or_none() ) if tenant_account_join: @@ -70,7 +70,22 @@ def load_user_from_request(request_from_flask_login): end_user_id = decoded.get("end_user_id") if not end_user_id: raise Unauthorized("Invalid Authorization token.") - end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first() + if not end_user: + raise NotFound("End user not found.") + return end_user + elif request.blueprint == "mcp": + 
server_code = request.view_args.get("server_code") if request.view_args else None + if not server_code: + raise Unauthorized("Invalid Authorization token.") + app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() + if not app_mcp_server: + raise NotFound("App MCP server not found.") + end_user = ( + db.session.query(EndUser) + .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") + .first() + ) if not end_user: raise NotFound("End user not found.") return end_user diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 23cf4c5cab..b027a165f9 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore from configs import dify_config from dify_app import DifyApp +from libs.helper import extract_tenant_id from models import Account, EndUser @@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]): if user: try: current_span = get_current_span() - if isinstance(user, Account) and user.current_tenant_id: - tenant_id = user.current_tenant_id - elif isinstance(user, EndUser): - tenant_id = user.tenant_id - else: + tenant_id = extract_tenant_id(user) + if not tenant_id: return if current_span: current_span.set_attribute("service.tenant.id", tenant_id) @@ -195,13 +193,22 @@ def init_app(app: DifyApp): insecure=True, ) else: + headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None + + trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT + if not trace_endpoint: + trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces" exporter = HTTPSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=trace_endpoint, + headers=headers, ) + + metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT + if not metric_endpoint: + metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics" metric_exporter = HTTPMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=metric_endpoint, + headers=headers, ) else: exporter = ConsoleSpanExporter() diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index c283b1b7ca..be2f6115f7 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,6 +1,10 @@ +import functools +import logging +from collections.abc import Callable from typing import Any, Union import redis +from redis import RedisError from redis.cache import CacheConfig from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection @@ -9,6 +13,8 @@ from redis.sentinel import Sentinel from configs import dify_config from dify_app import DifyApp +logger = logging.getLogger(__name__) + class RedisClientWrapper: """ @@ -115,3 +121,25 @@ def init_app(app: DifyApp): redis_client.initialize(redis.Redis(connection_pool=pool)) app.extensions["redis"] = redis_client + + +def redis_fallback(default_return: Any = None): + """ + decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. + + Args: + default_return: The value to return when a Redis operation fails. Defaults to None. 
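+
+    Example (illustrative usage; the function and key names are hypothetical):
+        @redis_fallback(default_return=0)
+        def get_counter(key: str) -> int:
+            return int(redis_client.get(key) or 0)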
+ """ + + def decorator(func: Callable): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except RedisError as e: + logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True) + return default_return + + return wrapper + + return decorator diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 7448fd4a6b..81eec94da4 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from datetime import UTC, datetime, timedelta +from datetime import timedelta from typing import Optional from azure.identity import ChainedTokenCredential, DefaultAzureCredential @@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc from configs import dify_config from extensions.ext_redis import redis_client from extensions.storage.base_storage import BaseStorage +from libs.datetime_utils import naive_utc_now class AzureBlobStorage(BaseStorage): @@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage): account_key=self.account_key or "", resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), + expiry=naive_utc_now() + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/factories/agent_factory.py b/api/factories/agent_factory.py index 4b12afb528..2570bc22f1 100644 --- a/api/factories/agent_factory.py +++ b/api/factories/agent_factory.py @@ -10,6 +10,6 @@ def get_plugin_agent_strategy( agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name) for agent_strategy in agent_provider.declaration.strategies: if agent_strategy.identity.name == agent_strategy_name: - return PluginAgentStrategy(tenant_id, agent_strategy) + return PluginAgentStrategy(tenant_id, agent_strategy, agent_provider.meta.version) raise ValueError(f"Agent strategy {agent_strategy_name} not found") diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 25d1390492..512a9cb608 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -148,9 +148,7 @@ def _build_from_local_file( if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("id"), @@ -199,9 +197,7 @@ def _build_from_remote_url( raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") file_type = ( - FileType(specified_type) - if specified_type and specified_type != FileType.CUSTOM.value - else detected_file_type + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type ) return File( @@ -265,13 +261,11 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - tool_file = ( - db.session.query(ToolFile) - .filter( + tool_file = db.session.scalar( + select(ToolFile).where( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, ) - .first() ) if tool_file is None: @@ -279,16 +273,14 @@ def _build_from_tool_file( extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) + detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) specified_type = mapping.get("type") if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("id"), diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 250ee4695e..39ebd009d5 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = StringVariable.model_validate(mapping) case SegmentType.SECRET: result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): + case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.INTEGER result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): + case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") @@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType: def build_segment(value: Any, /) -> Segment: + # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` + # below if value is None: return NoneSegment() if isinstance(value, str): @@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): + if all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) + elif len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + match types.pop(): case SegmentType.STRING: return ArrayStringSegment(value=value) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | 
SegmentType.FLOAT: return ArrayNumberSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) @@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment: raise ValueError(f"not supported value {value}") +_segment_factory: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.OBJECT: ObjectSegment, + # Array types + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, +} + + def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. @@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: if segment_type == SegmentType.NONE: return NoneSegment() else: - raise TypeMismatchError(f"Expected {segment_type}, but got None") + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: @@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: elif segment_type == SegmentType.ARRAY_FILE: return ArrayFileSegment(value=value) else: - raise TypeMismatchError(f"Expected {segment_type}, but got empty list") - - # Build segment using existing logic to infer actual type - inferred_segment = build_segment(value) - inferred_type = inferred_segment.value_type + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + inferred_type = SegmentType.infer_segment_type(value) # Type compatibility checking + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) if inferred_type == segment_type: - return inferred_segment - - # Type mismatch - raise error with descriptive message - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but value '{value}' " - f"(type: {type(value).__name__}) corresponds to {inferred_type}" - ) + segment_class = _segment_factory[segment_type] + return segment_class(value_type=segment_type, value=value) + elif segment_type == SegmentType.NUMBER and inferred_type in ( + SegmentType.INTEGER, + SegmentType.FLOAT, + ): + segment_class = _segment_factory[inferred_type] + return segment_class(value_type=inferred_type, value=value) + else: + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") def segment_to_variable( @@ -247,6 +278,6 @@ def segment_to_variable( name=name, description=description, value=segment.value, - selector=selector, + selector=list(selector), ), ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py new file mode 100644 index 0000000000..8288bd54a3 --- /dev/null +++ b/api/fields/_value_type_serializer.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from core.variables.segments import Segment +from core.variables.types import SegmentType + + +class _VarTypedDict(TypedDict, total=False): + value_type: SegmentType + + +def serialize_value_type(v: _VarTypedDict | Segment) -> str: + if isinstance(v, Segment): + return v.value_type.exposed_type().value + else: + return 
v["value_type"].exposed_type().value diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 500ca47c7e..b6d85e0e24 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,8 +1,21 @@ +import json + from flask_restful import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField + +class JsonStringField(fields.Raw): + def format(self, value): + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + return value + + app_detail_kernel_fields = { "id": fields.String, "name": fields.String, @@ -175,6 +188,7 @@ app_detail_fields_with_site = { "site": fields.Nested(site_fields), "api_base_url": fields.String, "use_icon_as_answer_icon": fields.Boolean, + "max_active_requests": fields.Integer, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, @@ -218,3 +232,14 @@ app_import_fields = { app_import_check_dependencies_fields = { "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)), } + +app_server_fields = { + "id": fields.String, + "name": fields.String, + "server_code": fields.String, + "description": fields.String, + "status": fields.String, + "parameters": JsonStringField, + "created_at": TimestampField, + "updated_at": TimestampField, +} diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 71785e7d67..c5a0c9a49d 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -2,10 +2,12 @@ from flask_restful import fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.String, "description": fields.String, "created_at": TimestampField, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9f1bef3b36..930e59cc1c 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) @@ -17,16 +19,23 @@ class EnvironmentVariableField(fields.Raw): "name": value.name, "value": encrypter.obfuscated_token(value.value), "value_type": value.value_type.value, + "description": value.description, } if isinstance(value, Variable): return { "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.value, + "value_type": value.value_type.exposed_type().value, + "description": value.description, } if isinstance(value, dict): - value_type = value.get("value_type") + value_type_str = value.get("value_type") + if not isinstance(value_type_str, str): + raise TypeError( + f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}" + ) + value_type = SegmentType(value_type_str).exposed_type() if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: raise ValueError(f"Unsupported environment variable value type: {value_type}") return value @@ -35,7 +44,7 @@ class EnvironmentVariableField(fields.Raw): 
conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.Raw, "description": fields.String, } diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py new file mode 100644 index 0000000000..bfbf41a073 --- /dev/null +++ b/api/libs/email_i18n.py @@ -0,0 +1,461 @@ +""" +Email Internationalization Module + +This module provides a centralized, elegant way to handle email internationalization +in Dify. It follows Domain-Driven Design principles with proper type hints and +eliminates the need for repetitive language switching logic. +""" + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Protocol + +from flask import render_template +from pydantic import BaseModel, Field + +from extensions.ext_mail import mail +from services.feature_service import BrandingModel, FeatureService + + +class EmailType(Enum): + """Enumeration of supported email types.""" + + RESET_PASSWORD = "reset_password" + INVITE_MEMBER = "invite_member" + EMAIL_CODE_LOGIN = "email_code_login" + CHANGE_EMAIL_OLD = "change_email_old" + CHANGE_EMAIL_NEW = "change_email_new" + OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm" + OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify" + OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify" + ACCOUNT_DELETION_SUCCESS = "account_deletion_success" + ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification" + ENTERPRISE_CUSTOM = "enterprise_custom" + QUEUE_MONITOR_ALERT = "queue_monitor_alert" + DOCUMENT_CLEAN_NOTIFY = "document_clean_notify" + + +class EmailLanguage(Enum): + """Supported email languages with fallback handling.""" + + EN_US = "en-US" + ZH_HANS = "zh-Hans" + + @classmethod + def from_language_code(cls, language_code: str) -> "EmailLanguage": + """Convert a language code to EmailLanguage with fallback to English.""" + if language_code == "zh-Hans": + return cls.ZH_HANS + return cls.EN_US + + +@dataclass(frozen=True) +class EmailTemplate: + """Immutable value object representing an email template configuration.""" + + subject: str + template_path: str + branded_template_path: str + + +@dataclass(frozen=True) +class EmailContent: + """Immutable value object containing rendered email content.""" + + subject: str + html_content: str + template_context: dict[str, Any] + + +class EmailI18nConfig(BaseModel): + """Configuration for email internationalization.""" + + model_config = {"frozen": True, "extra": "forbid"} + + templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = Field( + default_factory=dict, description="Mapping of email types to language-specific templates" + ) + + def get_template(self, email_type: EmailType, language: EmailLanguage) -> EmailTemplate: + """Get template configuration for specific email type and language.""" + type_templates = self.templates.get(email_type) + if not type_templates: + raise ValueError(f"No templates configured for email type: {email_type}") + + template = type_templates.get(language) + if not template: + # Fallback to English if specific language not found + template = type_templates.get(EmailLanguage.EN_US) + if not template: + raise ValueError(f"No template found for {email_type} in {language} or English") + + return template + + +class EmailRenderer(Protocol): + """Protocol for email template renderers.""" + + def render_template(self, template_path: str, **context: Any) -> str: + """Render email template with 
given context.""" + ... + + +class FlaskEmailRenderer: + """Flask-based email template renderer.""" + + def render_template(self, template_path: str, **context: Any) -> str: + """Render email template using Flask's render_template.""" + return render_template(template_path, **context) + + +class BrandingService(Protocol): + """Protocol for branding service abstraction.""" + + def get_branding_config(self) -> BrandingModel: + """Get current branding configuration.""" + ... + + +class FeatureBrandingService: + """Feature service based branding implementation.""" + + def get_branding_config(self) -> BrandingModel: + """Get branding configuration from feature service.""" + return FeatureService.get_system_features().branding + + +class EmailSender(Protocol): + """Protocol for email sending abstraction.""" + + def send_email(self, to: str, subject: str, html_content: str) -> None: + """Send email with given parameters.""" + ... + + +class FlaskMailSender: + """Flask-Mail based email sender.""" + + def send_email(self, to: str, subject: str, html_content: str) -> None: + """Send email using Flask-Mail.""" + if mail.is_inited(): + mail.send(to=to, subject=subject, html=html_content) + + +class EmailI18nService: + """ + Main service for internationalized email handling. + + This service provides a clean API for sending internationalized emails + with proper branding support and template management. + """ + + def __init__( + self, + config: EmailI18nConfig, + renderer: EmailRenderer, + branding_service: BrandingService, + sender: EmailSender, + ) -> None: + self._config = config + self._renderer = renderer + self._branding_service = branding_service + self._sender = sender + + def send_email( + self, + email_type: EmailType, + language_code: str, + to: str, + template_context: Optional[dict[str, Any]] = None, + ) -> None: + """ + Send internationalized email with branding support. + + Args: + email_type: Type of email to send + language_code: Target language code + to: Recipient email address + template_context: Additional context for template rendering + """ + if template_context is None: + template_context = {} + + language = EmailLanguage.from_language_code(language_code) + email_content = self._render_email_content(email_type, language, template_context) + + self._sender.send_email(to=to, subject=email_content.subject, html_content=email_content.html_content) + + def send_change_email( + self, + language_code: str, + to: str, + code: str, + phase: str, + ) -> None: + """ + Send change email notification with phase-specific handling. + + Args: + language_code: Target language code + to: Recipient email address + code: Verification code + phase: Either 'old_email' or 'new_email' + """ + if phase == "old_email": + email_type = EmailType.CHANGE_EMAIL_OLD + elif phase == "new_email": + email_type = EmailType.CHANGE_EMAIL_NEW + else: + raise ValueError(f"Invalid phase: {phase}. Must be 'old_email' or 'new_email'") + + self.send_email( + email_type=email_type, + language_code=language_code, + to=to, + template_context={ + "to": to, + "code": code, + }, + ) + + def send_raw_email( + self, + to: str | list[str], + subject: str, + html_content: str, + ) -> None: + """ + Send a raw email directly without template processing. + + This method is provided for backward compatibility with legacy email + sending that uses pre-rendered HTML content (e.g., enterprise emails + with custom templates). 
+ + Args: + to: Recipient email address(es) + subject: Email subject + html_content: Pre-rendered HTML content + """ + if isinstance(to, list): + for recipient in to: + self._sender.send_email(to=recipient, subject=subject, html_content=html_content) + else: + self._sender.send_email(to=to, subject=subject, html_content=html_content) + + def _render_email_content( + self, + email_type: EmailType, + language: EmailLanguage, + template_context: dict[str, Any], + ) -> EmailContent: + """Render email content with branding and internationalization.""" + template_config = self._config.get_template(email_type, language) + branding = self._branding_service.get_branding_config() + + # Determine template path based on branding + template_path = template_config.branded_template_path if branding.enabled else template_config.template_path + + # Prepare template context with branding information + full_context = { + **template_context, + "branding_enabled": branding.enabled, + "application_title": branding.application_title if branding.enabled else "Dify", + } + + # Render template + html_content = self._renderer.render_template(template_path, **full_context) + + # Apply templating to subject with all context variables + subject = template_config.subject + try: + subject = subject.format(**full_context) + except KeyError: + # If template variables are missing, fall back to basic formatting + if branding.enabled and "{application_title}" in subject: + subject = subject.format(application_title=branding.application_title) + + return EmailContent( + subject=subject, + html_content=html_content, + template_context=full_context, + ) + + +def create_default_email_config() -> EmailI18nConfig: + """Create default email i18n configuration with all supported templates.""" + templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = { + EmailType.RESET_PASSWORD: { + EmailLanguage.EN_US: EmailTemplate( + subject="Set Your {application_title} Password", + template_path="reset_password_mail_template_en-US.html", + branded_template_path="without-brand/reset_password_mail_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="设置您的 {application_title} 密码", + template_path="reset_password_mail_template_zh-CN.html", + branded_template_path="without-brand/reset_password_mail_template_zh-CN.html", + ), + }, + EmailType.INVITE_MEMBER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Join {application_title} Workspace Now", + template_path="invite_member_mail_template_en-US.html", + branded_template_path="without-brand/invite_member_mail_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="立即加入 {application_title} 工作空间", + template_path="invite_member_mail_template_zh-CN.html", + branded_template_path="without-brand/invite_member_mail_template_zh-CN.html", + ), + }, + EmailType.EMAIL_CODE_LOGIN: { + EmailLanguage.EN_US: EmailTemplate( + subject="{application_title} Login Code", + template_path="email_code_login_mail_template_en-US.html", + branded_template_path="without-brand/email_code_login_mail_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="{application_title} 登录验证码", + template_path="email_code_login_mail_template_zh-CN.html", + branded_template_path="without-brand/email_code_login_mail_template_zh-CN.html", + ), + }, + EmailType.CHANGE_EMAIL_OLD: { + EmailLanguage.EN_US: EmailTemplate( + subject="Check your current email", + template_path="change_mail_confirm_old_template_en-US.html", + 
branded_template_path="without-brand/change_mail_confirm_old_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="检测您现在的邮箱", + template_path="change_mail_confirm_old_template_zh-CN.html", + branded_template_path="without-brand/change_mail_confirm_old_template_zh-CN.html", + ), + }, + EmailType.CHANGE_EMAIL_NEW: { + EmailLanguage.EN_US: EmailTemplate( + subject="Confirm your new email address", + template_path="change_mail_confirm_new_template_en-US.html", + branded_template_path="without-brand/change_mail_confirm_new_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="确认您的邮箱地址变更", + template_path="change_mail_confirm_new_template_zh-CN.html", + branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html", + ), + }, + EmailType.OWNER_TRANSFER_CONFIRM: { + EmailLanguage.EN_US: EmailTemplate( + subject="Verify Your Request to Transfer Workspace Ownership", + template_path="transfer_workspace_owner_confirm_template_en-US.html", + branded_template_path="without-brand/transfer_workspace_owner_confirm_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="验证您转移工作空间所有权的请求", + template_path="transfer_workspace_owner_confirm_template_zh-CN.html", + branded_template_path="without-brand/transfer_workspace_owner_confirm_template_zh-CN.html", + ), + }, + EmailType.OWNER_TRANSFER_OLD_NOTIFY: { + EmailLanguage.EN_US: EmailTemplate( + subject="Workspace ownership has been transferred", + template_path="transfer_workspace_old_owner_notify_template_en-US.html", + branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="工作区所有权已转移", + template_path="transfer_workspace_old_owner_notify_template_zh-CN.html", + branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html", + ), + }, + EmailType.OWNER_TRANSFER_NEW_NOTIFY: { + EmailLanguage.EN_US: EmailTemplate( + subject="You are now the owner of {WorkspaceName}", + template_path="transfer_workspace_new_owner_notify_template_en-US.html", + branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="您现在是 {WorkspaceName} 的所有者", + template_path="transfer_workspace_new_owner_notify_template_zh-CN.html", + branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html", + ), + }, + EmailType.ACCOUNT_DELETION_SUCCESS: { + EmailLanguage.EN_US: EmailTemplate( + subject="Your Dify.AI Account Has Been Successfully Deleted", + template_path="delete_account_success_template_en-US.html", + branded_template_path="delete_account_success_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="您的 Dify.AI 账户已成功删除", + template_path="delete_account_success_template_zh-CN.html", + branded_template_path="delete_account_success_template_zh-CN.html", + ), + }, + EmailType.ACCOUNT_DELETION_VERIFICATION: { + EmailLanguage.EN_US: EmailTemplate( + subject="Dify.AI Account Deletion and Verification", + template_path="delete_account_code_email_template_en-US.html", + branded_template_path="delete_account_code_email_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="Dify.AI 账户删除和验证", + template_path="delete_account_code_email_template_zh-CN.html", + branded_template_path="delete_account_code_email_template_zh-CN.html", + ), + }, + EmailType.QUEUE_MONITOR_ALERT: { + EmailLanguage.EN_US: EmailTemplate( + 
subject="Alert: Dataset Queue pending tasks exceeded the limit", + template_path="queue_monitor_alert_email_template_en-US.html", + branded_template_path="queue_monitor_alert_email_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="警报:数据集队列待处理任务超过限制", + template_path="queue_monitor_alert_email_template_zh-CN.html", + branded_template_path="queue_monitor_alert_email_template_zh-CN.html", + ), + }, + EmailType.DOCUMENT_CLEAN_NOTIFY: { + EmailLanguage.EN_US: EmailTemplate( + subject="Dify Knowledge base auto disable notification", + template_path="clean_document_job_mail_template-US.html", + branded_template_path="clean_document_job_mail_template-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="Dify 知识库自动禁用通知", + template_path="clean_document_job_mail_template_zh-CN.html", + branded_template_path="clean_document_job_mail_template_zh-CN.html", + ), + }, + } + + return EmailI18nConfig(templates=templates) + + +# Singleton instance for application-wide use +def get_default_email_i18n_service() -> EmailI18nService: + """Get configured email i18n service with default dependencies.""" + config = create_default_email_config() + renderer = FlaskEmailRenderer() + branding_service = FeatureBrandingService() + sender = FlaskMailSender() + + return EmailI18nService( + config=config, + renderer=renderer, + branding_service=branding_service, + sender=sender, + ) + + +# Global instance +_email_i18n_service: Optional[EmailI18nService] = None + + +def get_email_i18n_service() -> EmailI18nService: + """Get global email i18n service instance.""" + global _email_i18n_service + if _email_i18n_service is None: + _email_i18n_service = get_default_email_i18n_service() + return _email_i18n_service diff --git a/api/libs/helper.py b/api/libs/helper.py index 3f2a630956..00772d530a 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client if TYPE_CHECKING: from models.account import Account + from models.model import EndUser + + +def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: + """ + Extract tenant_id from Account or EndUser object. + + Args: + user: Account or EndUser object + + Returns: + tenant_id string if available, None otherwise + + Raises: + ValueError: If user is neither Account nor EndUser + """ + from models.account import Account + from models.model import EndUser + + if isinstance(user, Account): + return user.current_tenant_id + elif isinstance(user, EndUser): + return user.tenant_id + else: + raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.") def run(script): @@ -123,25 +148,6 @@ class StrLen: return value -class FloatRange: - """Restrict input to an float in a range (inclusive)""" - - def __init__(self, low, high, argument="argument"): - self.low = low - self.high = high - self.argument = argument - - def __call__(self, value): - value = _get_float(value) - if value < self.low or value > self.high: - error = "Invalid {arg}: {val}. 
{arg} must be within the range {lo} - {hi}".format( - arg=self.argument, val=value, lo=self.low, hi=self.high - ) - raise ValueError(error) - - return value - - class DatetimeString: def __init__(self, format, argument="argument"): self.format = format diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py deleted file mode 100644 index fa29671034..0000000000 --- a/api/libs/jsonutil.py +++ /dev/null @@ -1,11 +0,0 @@ -import json - -from pydantic import BaseModel - - -class PydanticModelEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, BaseModel): - return o.model_dump() - else: - super().default(o) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 218109522d..987c5d7135 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,11 +1,12 @@ -import datetime import urllib.parse from typing import Any import requests from flask_login import current_user +from sqlalchemy import select from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -61,21 +62,17 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, - ) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) - .first() ) if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -101,21 +98,17 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, - ) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) - .first() ) if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -129,18 +122,15 @@ class NotionOAuth(OAuthDataSource): def sync_data_source(self, binding_id: str): # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, - ) + data_source_binding = db.session.scalar( + 
select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, ) - .first() ) + if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token) @@ -154,7 +144,7 @@ class NotionOAuth(OAuthDataSource): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: raise ValueError("Data source binding not found") diff --git a/api/libs/passport.py b/api/libs/passport.py index 8df4f529bc..fe8fc33b5f 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -14,9 +14,11 @@ class PassportService: def verify(self, token): try: return jwt.decode(token, self.sk, algorithms=["HS256"]) + except jwt.exceptions.ExpiredSignatureError: + raise Unauthorized("Token has expired.") except jwt.exceptions.InvalidSignatureError: raise Unauthorized("Invalid token signature.") except jwt.exceptions.DecodeError: raise Unauthorized("Invalid token.") - except jwt.exceptions.ExpiredSignatureError: - raise Unauthorized("Token has expired.") + except jwt.exceptions.PyJWTError: # Catch-all for other JWT errors + raise Unauthorized("Invalid token.") diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 637bcc4a1d..ed7a0eb116 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,4 +1,6 @@ import hashlib +import os +from typing import Union from Crypto.Cipher import AES from Crypto.PublicKey import RSA @@ -9,14 +11,14 @@ from extensions.ext_storage import storage from libs import gmpy2_pkcs10aep_cipher -def generate_key_pair(tenant_id): +def generate_key_pair(tenant_id: str) -> str: private_key = RSA.generate(2048) public_key = private_key.publickey() pem_private = private_key.export_key() pem_public = public_key.export_key() - filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" + filepath = os.path.join("privkeys", tenant_id, "private.pem") storage.save(filepath, pem_private) @@ -26,7 +28,7 @@ def generate_key_pair(tenant_id): prefix_hybrid = b"HYBRID:" -def encrypt(text, public_key): +def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: if isinstance(public_key, str): public_key = public_key.encode() @@ -38,15 +40,15 @@ def encrypt(text, public_key): rsa_key = RSA.import_key(public_key) cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key) - enc_aes_key = cipher_rsa.encrypt(aes_key) + enc_aes_key: bytes = cipher_rsa.encrypt(aes_key) encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext return prefix_hybrid + encrypted_data -def get_decrypt_decoding(tenant_id): - filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" +def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: + filepath = os.path.join("privkeys", tenant_id, "private.pem") cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) private_key = redis_client.get(cache_key) @@ -64,7 +66,7 @@ def get_decrypt_decoding(tenant_id): return rsa_key, cipher_rsa -def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): +def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str: if encrypted_text.startswith(prefix_hybrid): encrypted_text = 
encrypted_text[len(prefix_hybrid) :] @@ -83,10 +85,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): return decrypted_text.decode() -def decrypt(encrypted_text, tenant_id): +def decrypt(encrypted_text: bytes, tenant_id: str) -> str: rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id) - return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa) + return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa) class PrivkeyNotFoundError(Exception): diff --git a/api/libs/uuid_utils.py b/api/libs/uuid_utils.py new file mode 100644 index 0000000000..a8190011ed --- /dev/null +++ b/api/libs/uuid_utils.py @@ -0,0 +1,164 @@ +import secrets +import struct +import time +import uuid + +# Reference for UUIDv7 specification: +# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7 + +# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian). +# +# For details on the `struct.pack` format, refer to: +# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment +_PACK_TIMESTAMP = ">Q" + +# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7) +# into an unsigned 16-bit integer (big-endian). +_PACK_RAND_A = ">H" + + +def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes: + """Create UUIDv7 byte structure with given timestamp and random bytes. + + This is a private helper function that handles the common logic for creating + UUIDv7 byte structure according to RFC 9562 specification. + + UUIDv7 Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + The function performs the following operations: + 1. Creates a 128-bit (16-byte) UUID structure + 2. Packs the timestamp into the first 48 bits (6 bytes) + 3. Sets the version bits to 7 (0111) in the correct position + 4. Sets the variant bits to 10 (binary) in the correct position + 5. Fills the remaining bits with the provided random bytes + + Args: + timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits). + random_bytes: Random bytes to use for the random portions (must be 10 bytes). + First 2 bytes are used for random data A (12 bits after version). + Last 8 bytes are used for random data B (62 bits after variant). + + Returns: + A 16-byte bytes object representing the complete UUIDv7 structure. + + Note: + This function assumes the random_bytes parameter is exactly 10 bytes. + The caller is responsible for providing appropriate random data. + """ + # Create the 128-bit UUID structure + uuid_bytes = bytearray(16) + + # Pack timestamp (48 bits) into first 6 bytes + uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian + + # Next 16 bits: random data A (12 bits) + version (4 bits) + # Take first 2 random bytes and set version to 7 + rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0] + # Clear the highest 4 bits to make room for the version field + # by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111). + rand_a = rand_a & 0x0FFF + # Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000). 
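+    # e.g. rand_a = 0x1234 -> 0x1234 & 0x0FFF = 0x0234 -> 0x0234 | 0x7000 = 0x7234 (version nibble is now 7)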
+ rand_a = rand_a | 0x7000 + uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a) + + # Last 64 bits: random data B (62 bits) + variant (2 bits) + # Use remaining 8 random bytes and set variant to 10 (binary) + uuid_bytes[8:16] = random_bytes[2:10] + + # Set variant bits (first 2 bits of byte 8 should be '10') + uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx + + return bytes(uuid_bytes) + + +def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID: + """Generate a UUID version 7 according to RFC 9562 specification. + + UUIDv7 features a time-ordered value field derived from the widely + implemented and well known Unix Epoch timestamp source, the number of + milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded. + + Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + Args: + timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified. + Should be an integer representing milliseconds since Unix epoch. + + Returns: + A UUID object representing a UUIDv7. + + Example: + >>> import time + >>> # Generate UUIDv7 with current time + >>> uuid_current = uuidv7() + >>> # Generate UUIDv7 with specific timestamp + >>> uuid_specific = uuidv7(int(time.time() * 1000)) + """ + if timestamp_ms is None: + timestamp_ms = int(time.time() * 1000) + + # Generate 10 random bytes for the random portions + random_bytes = secrets.token_bytes(10) + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + return uuid.UUID(bytes=uuid_bytes) + + +def uuidv7_timestamp(id_: uuid.UUID) -> int: + """Extract the timestamp from a UUIDv7. + + UUIDv7 contains a 48-bit timestamp field representing milliseconds since + the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and + returns that timestamp as an integer representing milliseconds since the epoch. + + Args: + id_: A UUID object that should be a UUIDv7 (version 7). + + Returns: + The timestamp as an integer representing milliseconds since Unix epoch. + + Raises: + ValueError: If the provided UUID is not version 7. + + Example: + >>> uuid_v7 = uuidv7() + >>> timestamp = uuidv7_timestamp(uuid_v7) + >>> print(f"UUID was created at: {timestamp} ms") + """ + # Verify this is a UUIDv7 + if id_.version != 7: + raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}") + + # Extract the UUID bytes + uuid_bytes = id_.bytes + + # Extract the first 48 bits (6 bytes) as the timestamp in milliseconds + # Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long) + timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6] + ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0] + + # Return timestamp directly in milliseconds as integer + assert isinstance(ts_in_ms, int) + return ts_in_ms + + +def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID: + """Generate a non-random uuidv7 with the given timestamp (first 48 bits) and + all random bits to 0. As the smallest possible uuidv7 for that timestamp, + it may be used as a boundary for partitions. 
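+
+    Example (illustrative):
+        >>> lower = uuidv7_boundary(int(time.time() * 1000))
+        >>> # any uuidv7() generated at or after this millisecond compares >= lower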
+ """ + # Use zero bytes for all random portions + zero_random_bytes = b"\x00" * 10 + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes) + + return uuid.UUID(bytes=uuid_bytes) diff --git a/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py b/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py new file mode 100644 index 0000000000..0548bf05ef --- /dev/null +++ b/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py @@ -0,0 +1,64 @@ +"""add mcp server tool and app server + +Revision ID: 58eb7bdb93fe +Revises: 0ab65e1cc7fa +Create Date: 2025-06-25 09:36:07.510570 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '58eb7bdb93fe' +down_revision = '0ab65e1cc7fa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_mcp_servers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=False), + sa.Column('server_code', sa.String(length=255), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('parameters', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'), + sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'), + sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code') + ) + op.create_table('tool_mcp_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=40), nullable=False), + sa.Column('server_identifier', sa.String(length=24), nullable=False), + sa.Column('server_url', sa.Text(), nullable=False), + sa.Column('server_url_hash', sa.String(length=64), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('encrypted_credentials', sa.Text(), nullable=True), + sa.Column('authed', sa.Boolean(), nullable=False), + sa.Column('tools', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'name', name='unique_mcp_provider_name'), + sa.UniqueConstraint('tenant_id', 'server_identifier', name='unique_mcp_provider_server_identifier'), + sa.UniqueConstraint('tenant_id', 'server_url_hash', name='unique_mcp_provider_server_url') + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated 
by Alembic - please adjust! ### + op.drop_table('tool_mcp_providers') + op.drop_table('app_mcp_servers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py new file mode 100644 index 0000000000..2bbbb3d28e --- /dev/null +++ b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py @@ -0,0 +1,86 @@ +"""add uuidv7 function in SQL + +Revision ID: 1c9ba48be8e4 +Revises: 58eb7bdb93fe +Create Date: 2025-07-02 23:32:38.484499 + +""" + +""" +The functions in this files comes from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications. + +LICENSE: + +# Copyright and License + +Copyright (c) 2024, Daniel Vérité + +Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. + +In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage. + +Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications. +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1c9ba48be8e4' +down_revision = '58eb7bdb93fe' +branch_labels: None = None +depends_on: None = None + + +def upgrade(): + # This implementation differs slightly from the original uuidv7 function in + # https://github.com/dverite/postgres-uuidv7-sql/. + # The ability to specify source timestamp has been removed because its type signature is incompatible with + # PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be + # generated and controlled within the application layer. 
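+    # Note: the SQL function below mirrors libs/uuid_utils.uuidv7(): the first 48 bits carry the
+    # millisecond timestamp, the remaining bits come from gen_random_uuid(), and the version
+    # nibble is forced to 7.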
+ op.execute(sa.text(r""" +/* Main function to generate a uuidv7 value with millisecond precision */ +CREATE FUNCTION uuidv7() RETURNS uuid +AS +$$ + -- Replace the first 48 bits of a uuidv4 with the current + -- number of milliseconds since 1970-01-01 UTC + -- and set the "ver" field to 7 by setting additional bits +SELECT encode( + set_bit( + set_bit( + overlay(uuid_send(gen_random_uuid()) placing + substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from + 3) + from 1 for 6), + 52, 1), + 53, 1), 'hex')::uuid; +$$ LANGUAGE SQL VOLATILE PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7 IS + 'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness'; +""")) + + op.execute(sa.text(r""" +CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid +AS +$$ + /* uuid fields: version=0b0111, variant=0b10 */ +SELECT encode( + overlay('\x00000000000070008000000000000000'::bytea + placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3) + from 1 for 6), + 'hex')::uuid; +$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS + 'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.'; +""" +)) + + +def downgrade(): + op.execute(sa.text("DROP FUNCTION uuidv7")) + op.execute(sa.text("DROP FUNCTION uuidv7_boundary")) diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py new file mode 100644 index 0000000000..df4fbf0a0e --- /dev/null +++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py @@ -0,0 +1,62 @@ +"""tool oauth + +Revision ID: 71f5020c6470 +Revises: 4474872b0ee6 +Create Date: 2025-06-24 17:05:43.118647 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '71f5020c6470' +down_revision = '1c9ba48be8e4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_oauth_system_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') + ) + op.create_table('tool_oauth_tenant_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client') + ) + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False)) + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) + batch_op.drop_column('credential_type') + batch_op.drop_column('is_default') + batch_op.drop_column('name') + + op.drop_table('tool_oauth_tenant_clients') + op.drop_table('tool_oauth_system_clients') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py new file mode 100644 index 0000000000..3bdbafda7c --- /dev/null +++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py @@ -0,0 +1,51 @@ +"""update models + +Revision ID: 1a83934ad6d1 +Revises: 71f5020c6470 +Create Date: 2025-07-21 09:35:48.774794 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1a83934ad6d1' +down_revision = '71f5020c6470' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.VARCHAR(length=24), + type_=sa.String(length=64), + existing_nullable=False) + + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=128), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.String(length=128), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.String(length=64), + type_=sa.VARCHAR(length=24), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py new file mode 100644 index 0000000000..76d0cb2940 --- /dev/null +++ b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py @@ -0,0 +1,34 @@ +"""oauth_refresh_token + +Revision ID: 375fe79ead14 +Revises: 1a83934ad6d1 +Create Date: 2025-07-22 00:19:45.599636 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '375fe79ead14' +down_revision = '1a83934ad6d1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_column('expires_at') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_07_23_1508-8bcc02c9bd07_add_tenant_plugin_autoupgrade_table.py b/api/migrations/versions/2025_07_23_1508-8bcc02c9bd07_add_tenant_plugin_autoupgrade_table.py new file mode 100644 index 0000000000..4ff0402a97 --- /dev/null +++ b/api/migrations/versions/2025_07_23_1508-8bcc02c9bd07_add_tenant_plugin_autoupgrade_table.py @@ -0,0 +1,42 @@ +"""add_tenant_plugin_autoupgrade_table + +Revision ID: 8bcc02c9bd07 +Revises: 375fe79ead14 +Create Date: 2025-07-23 15:08:50.161441 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8bcc02c9bd07' +down_revision = '375fe79ead14' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tenant_plugin_auto_upgrade_strategies', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False), + sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False), + sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False), + sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False), + sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'), + sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_table('tenant_plugin_auto_upgrade_strategies') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 83b50eb099..1b4bdd32e4 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -34,6 +34,7 @@ from .model import ( App, AppAnnotationHitHistory, AppAnnotationSetting, + AppMCPServer, AppMode, AppModelConfig, Conversation, @@ -103,6 +104,7 @@ __all__ = [ "AppAnnotationHitHistory", "AppAnnotationSetting", "AppDatasetJoin", + "AppMCPServer", # Added "AppMode", "AppModelConfig", "BuiltinToolProvider", diff --git a/api/models/account.py b/api/models/account.py index 7ffeefa980..d63c5d7fb5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,9 +1,10 @@ import enum import json +from datetime import datetime from typing import Optional, cast from flask_login import UserMixin # type: ignore -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -85,21 +86,23 @@ class Account(UserMixin, Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name = db.Column(db.String(255), nullable=False) - email = db.Column(db.String(255), nullable=False) - password = db.Column(db.String(255), nullable=True) - password_salt = db.Column(db.String(255), nullable=True) - avatar = db.Column(db.String(255)) - interface_language = db.Column(db.String(255)) - interface_theme = db.Column(db.String(255)) - timezone = db.Column(db.String(255)) - last_login_at = db.Column(db.DateTime) - last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) - initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + name: Mapped[str] = mapped_column(db.String(255)) + email: Mapped[str] = mapped_column(db.String(255)) + password: Mapped[Optional[str]] = mapped_column(db.String(255)) + password_salt: 
Mapped[Optional[str]] = mapped_column(db.String(255)) + avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + interface_language: Mapped[Optional[str]] = mapped_column(db.String(255)) + interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + timezone: Mapped[Optional[str]] = mapped_column(db.String(255)) + last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + last_active_at: Mapped[datetime] = mapped_column( + db.DateTime, server_default=func.current_timestamp(), nullable=False + ) + status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying")) + initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): @@ -116,7 +119,7 @@ class Account(UserMixin, Base): @current_tenant.setter def current_tenant(self, tenant: "Tenant"): - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() + ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1)) if ta: self.role = TenantAccountRole(ta.role) self._current_tenant = tenant @@ -132,9 +135,9 @@ class Account(UserMixin, Base): tuple[Tenant, TenantAccountJoin], ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == tenant_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.account_id == self.id) + .where(Tenant.id == tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.account_id == self.id) .one_or_none() ), ) @@ -143,7 +146,7 @@ class Account(UserMixin, Base): return tenant, join = tenant_account_join - self.role = join.role + self.role = TenantAccountRole(join.role) self._current_tenant = tenant @property @@ -158,11 +161,11 @@ class Account(UserMixin, Base): def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) - .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) .one_or_none() ) if account_integrate: - return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() + return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none() return None # check current_user.current_tenant.current_role in ['admin', 'owner'] @@ -196,19 +199,19 @@ class Tenant(Base): __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name = db.Column(db.String(255), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(db.String(255)) encrypt_public_key = db.Column(db.Text) - plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - custom_config = db.Column(db.Text) - created_at = 
db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + custom_config: Mapped[Optional[str]] = mapped_column(db.Text) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( db.session.query(Account) - .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) .all() ) @@ -230,14 +233,14 @@ class TenantAccountJoin(Base): db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - role = db.Column(db.String(16), nullable=False, server_default="normal") - invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + account_id: Mapped[str] = mapped_column(StringUUID) + current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + role: Mapped[str] = mapped_column(db.String(16), server_default="normal") + invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) class AccountIntegrate(Base): @@ -248,13 +251,13 @@ class AccountIntegrate(Base): db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - account_id = db.Column(StringUUID, nullable=False) - provider = db.Column(db.String(16), nullable=False) - open_id = db.Column(db.String(255), nullable=False) - encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + account_id: Mapped[str] = mapped_column(StringUUID) + provider: Mapped[str] = mapped_column(db.String(16)) + open_id: Mapped[str] = mapped_column(db.String(255)) + encrypted_token: Mapped[str] = mapped_column(db.String(255)) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) class InvitationCode(Base): @@ -265,15 +268,15 @@ class InvitationCode(Base): 
db.Index("invitation_codes_code_idx", "code", "status"), ) - id = db.Column(db.Integer, nullable=False) - batch = db.Column(db.String(255), nullable=False) - code = db.Column(db.String(32), nullable=False) - status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) - used_at = db.Column(db.DateTime) - used_by_tenant_id = db.Column(StringUUID) - used_by_account_id = db.Column(StringUUID) - deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + id: Mapped[int] = mapped_column(db.Integer) + batch: Mapped[str] = mapped_column(db.String(255)) + code: Mapped[str] = mapped_column(db.String(32)) + status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying")) + used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) + used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) + deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -299,3 +302,35 @@ class TenantPluginPermission(Base): db.String(16), nullable=False, server_default="everyone" ) debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + + +class TenantPluginAutoUpgradeStrategy(Base): + class StrategySetting(enum.StrEnum): + DISABLED = "disabled" + FIX_ONLY = "fix_only" + LATEST = "latest" + + class UpgradeMode(enum.StrEnum): + ALL = "all" + PARTIAL = "partial" + EXCLUDE = "exclude" + + __tablename__ = "tenant_plugin_auto_upgrade_strategies" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), + db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only") + upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day + upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude") + exclude_plugins: Mapped[list[str]] = mapped_column( + db.ARRAY(db.String(255)), nullable=False + ) # plugin_id (author/name) + include_plugins: Mapped[list[str]] = mapped_column( + db.ARRAY(db.String(255)), nullable=False + ) # plugin_id (author/name) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 5a70e18622..3cef5a0fb2 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,6 +1,7 @@ import enum from sqlalchemy import func +from sqlalchemy.orm import mapped_column from .base import Base from .engine import db @@ -21,9 +22,9 @@ class APIBasedExtension(Base): db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - name = 
db.Column(db.String(255), nullable=False) - api_endpoint = db.Column(db.String(255), nullable=False) - api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + api_endpoint = mapped_column(db.String(255), nullable=False) + api_key = mapped_column(db.Text, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 1ec27203a0..d877540213 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -8,12 +8,13 @@ import os import pickle import re import time +from datetime import datetime from json import JSONDecodeError -from typing import Any, cast +from typing import Any, Optional, cast -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource @@ -45,29 +46,29 @@ class Dataset(Base): INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) - data_source_type = db.Column(db.String(255)) - indexing_technique = db.Column(db.String(255), nullable=True) - index_struct = db.Column(db.Text, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - embedding_model = db.Column(db.String(255), nullable=True) - embedding_model_provider = db.Column(db.String(255), nullable=True) - collection_binding_id = db.Column(StringUUID, nullable=True) - retrieval_model = db.Column(JSONB, nullable=True) - built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + name: Mapped[str] = mapped_column(db.String(255)) + description = mapped_column(db.Text, nullable=True) + provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying")) + data_source_type = mapped_column(db.String(255)) + indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255)) + index_struct = mapped_column(db.Text, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = 
mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column + embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column + collection_binding_id = mapped_column(StringUUID, nullable=True) + retrieval_model = mapped_column(JSONB, nullable=True) + built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def dataset_keyword_table(self): dataset_keyword_table = ( - db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first() ) if dataset_keyword_table: return dataset_keyword_table @@ -94,7 +95,7 @@ class Dataset(Base): def latest_process_rule(self): return ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.dataset_id == self.id) + .where(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) .first() ) @@ -103,19 +104,19 @@ class Dataset(Base): def app_count(self): return ( db.session.query(func.count(AppDatasetJoin.id)) - .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) .scalar() ) @property def document_count(self): - return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() @property def available_document_count(self): return ( db.session.query(func.count(Document.id)) - .filter( + .where( Document.dataset_id == self.id, Document.indexing_status == "completed", Document.enabled == True, @@ -128,7 +129,7 @@ class Dataset(Base): def available_segment_count(self): return ( db.session.query(func.count(DocumentSegment.id)) - .filter( + .where( DocumentSegment.dataset_id == self.id, DocumentSegment.status == "completed", DocumentSegment.enabled == True, @@ -141,13 +142,13 @@ class Dataset(Base): return ( db.session.query(Document) .with_entities(func.coalesce(func.sum(Document.word_count), 0)) - .filter(Document.dataset_id == self.id) + .where(Document.dataset_id == self.id) .scalar() ) @property def doc_form(self): - document = db.session.query(Document).filter(Document.dataset_id == self.id).first() + document = db.session.query(Document).where(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -168,7 +169,7 @@ class Dataset(Base): tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) - .filter( + .where( TagBinding.target_id == self.id, TagBinding.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id, @@ -184,14 +185,14 @@ class Dataset(Base): if self.provider != "external": return None external_knowledge_binding = ( - db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first() + db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first() ) if not external_knowledge_binding: return None - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis) - .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) - .first() + external_knowledge_api = db.session.scalar( + select(ExternalKnowledgeApis).where( + 
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id + ) ) if not external_knowledge_api: return None @@ -204,7 +205,7 @@ class Dataset(Base): @property def doc_metadata(self): - dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() + dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all() doc_metadata = [ { @@ -255,7 +256,7 @@ class Dataset(Base): @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") - return f"Vector_index_{normalized_dataset_id}_Node" + return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node" class DatasetProcessRule(Base): @@ -265,12 +266,12 @@ class DatasetProcessRule(Base): db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - rules = db.Column(db.Text, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + dataset_id = mapped_column(StringUUID, nullable=False) + mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + rules = mapped_column(db.Text, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -309,62 +310,64 @@ class Document(Base): ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False) - data_source_type = db.Column(db.String(255), nullable=False) - data_source_info = db.Column(db.Text, nullable=True) - dataset_process_rule_id = db.Column(StringUUID, nullable=True) - batch = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - created_from = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False) + data_source_type = mapped_column(db.String(255), nullable=False) + data_source_info = mapped_column(db.Text, nullable=True) + dataset_process_rule_id = mapped_column(StringUUID, nullable=True) + batch = mapped_column(db.String(255), nullable=False) + name = mapped_column(db.String(255), nullable=False) + created_from = mapped_column(db.String(255), nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_api_request_id = mapped_column(StringUUID, 
nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at = db.Column(db.DateTime, nullable=True) + processing_started_at = mapped_column(db.DateTime, nullable=True) # parsing - file_id = db.Column(db.Text, nullable=True) - word_count = db.Column(db.Integer, nullable=True) - parsing_completed_at = db.Column(db.DateTime, nullable=True) + file_id = mapped_column(db.Text, nullable=True) + word_count = mapped_column(db.Integer, nullable=True) + parsing_completed_at = mapped_column(db.DateTime, nullable=True) # cleaning - cleaning_completed_at = db.Column(db.DateTime, nullable=True) + cleaning_completed_at = mapped_column(db.DateTime, nullable=True) # split - splitting_completed_at = db.Column(db.DateTime, nullable=True) + splitting_completed_at = mapped_column(db.DateTime, nullable=True) # indexing - tokens = db.Column(db.Integer, nullable=True) - indexing_latency = db.Column(db.Float, nullable=True) - completed_at = db.Column(db.DateTime, nullable=True) + tokens = mapped_column(db.Integer, nullable=True) + indexing_latency = mapped_column(db.Float, nullable=True) + completed_at = mapped_column(db.DateTime, nullable=True) # pause - is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) - paused_by = db.Column(StringUUID, nullable=True) - paused_at = db.Column(db.DateTime, nullable=True) + is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + paused_by = mapped_column(StringUUID, nullable=True) + paused_at = mapped_column(db.DateTime, nullable=True) # error - error = db.Column(db.Text, nullable=True) - stopped_at = db.Column(db.DateTime, nullable=True) + error = mapped_column(db.Text, nullable=True) + stopped_at = mapped_column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = db.Column(db.DateTime, nullable=True) - disabled_by = db.Column(StringUUID, nullable=True) - archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - archived_reason = db.Column(db.String(255), nullable=True) - archived_by = db.Column(StringUUID, nullable=True) - archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - doc_type = db.Column(db.String(40), nullable=True) - doc_metadata = db.Column(JSONB, nullable=True) - doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) - doc_language = db.Column(db.String(255), nullable=True) + indexing_status = mapped_column( + db.String(255), nullable=False, server_default=db.text("'waiting'::character varying") + ) + enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at = mapped_column(db.DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + archived_reason = mapped_column(db.String(255), nullable=True) + archived_by = mapped_column(StringUUID, nullable=True) + archived_at = mapped_column(db.DateTime, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = mapped_column(db.String(40), nullable=True) + doc_metadata = 
mapped_column(JSONB, nullable=True) + doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_language = mapped_column(db.String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -405,7 +408,7 @@ class Document(Base): data_source_info_dict = json.loads(self.data_source_info) file_detail = ( db.session.query(UploadFile) - .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .where(UploadFile.id == data_source_info_dict["upload_file_id"]) .one_or_none() ) if file_detail: @@ -438,24 +441,24 @@ class Document(Base): @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() + return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none() @property def segment_count(self): - return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count() + return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count() @property def hit_count(self): return ( db.session.query(DocumentSegment) .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) - .filter(DocumentSegment.document_id == self.id) + .where(DocumentSegment.document_id == self.id) .scalar() ) @property def uploader(self): - user = db.session.query(Account).filter(Account.id == self.created_by).first() + user = db.session.query(Account).where(Account.id == self.created_by).first() return user.name if user else None @property @@ -472,7 +475,7 @@ class Document(Base): document_metadatas = ( db.session.query(DatasetMetadata) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) - .filter( + .where( DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id ) .all() @@ -652,58 +655,58 @@ class DocumentSegment(Base): ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] - content = db.Column(db.Text, nullable=False) - answer = db.Column(db.Text, nullable=True) - word_count = db.Column(db.Integer, nullable=False) - tokens = db.Column(db.Integer, nullable=False) + content = mapped_column(db.Text, nullable=False) + answer = mapped_column(db.Text, nullable=True) + word_count: Mapped[int] + tokens: Mapped[int] # indexing fields - keywords = db.Column(db.JSON, nullable=True) - index_node_id = db.Column(db.String(255), nullable=True) - index_node_hash = db.Column(db.String(255), nullable=True) + keywords = mapped_column(db.JSON, nullable=True) + index_node_id = mapped_column(db.String(255), nullable=True) + index_node_hash = mapped_column(db.String(255), nullable=True) # basic fields - hit_count = db.Column(db.Integer, nullable=False, default=0) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = db.Column(db.DateTime, nullable=True) - disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, 
server_default=db.text("'waiting'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at = db.Column(db.DateTime, nullable=True) - completed_at = db.Column(db.DateTime, nullable=True) - error = db.Column(db.Text, nullable=True) - stopped_at = db.Column(db.DateTime, nullable=True) + hit_count = mapped_column(db.Integer, nullable=False, default=0) + enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at = mapped_column(db.DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying")) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at = mapped_column(db.DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + error = mapped_column(db.Text, nullable=True) + stopped_at = mapped_column(db.DateTime, nullable=True) @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def document(self): - return db.session.query(Document).filter(Document.id == self.document_id).first() + return db.session.scalar(select(Document).where(Document.id == self.document_id)) @property def previous_segment(self): - return ( - db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) - .first() + return db.session.scalar( + select(DocumentSegment).where( + DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1 + ) ) @property def next_segment(self): - return ( - db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) - .first() + return db.session.scalar( + select(DocumentSegment).where( + DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1 + ) ) @property @@ -714,7 +717,7 @@ class DocumentSegment(Base): if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: child_chunks = ( db.session.query(ChildChunk) - .filter(ChildChunk.segment_id == self.id) + .where(ChildChunk.segment_id == self.id) .order_by(ChildChunk.position.asc()) .all() ) @@ -731,7 +734,7 @@ class DocumentSegment(Base): if rules.parent_mode: child_chunks = ( db.session.query(ChildChunk) - .filter(ChildChunk.segment_id == self.id) + .where(ChildChunk.segment_id == self.id) .order_by(ChildChunk.position.asc()) .all() ) @@ -800,37 +803,37 @@ class ChildChunk(Base): ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - segment_id = db.Column(StringUUID, nullable=False) - position = 
db.Column(db.Integer, nullable=False) - content = db.Column(db.Text, nullable=False) - word_count = db.Column(db.Integer, nullable=False) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + segment_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False) + content = mapped_column(db.Text, nullable=False) + word_count = mapped_column(db.Integer, nullable=False) # indexing fields - index_node_id = db.Column(db.String(255), nullable=True) - index_node_hash = db.Column(db.String(255), nullable=True) - type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - indexing_at = db.Column(db.DateTime, nullable=True) - completed_at = db.Column(db.DateTime, nullable=True) - error = db.Column(db.Text, nullable=True) + index_node_id = mapped_column(db.String(255), nullable=True) + index_node_hash = mapped_column(db.String(255), nullable=True) + type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = mapped_column(db.DateTime, nullable=True) + completed_at = mapped_column(db.DateTime, nullable=True) + error = mapped_column(db.Text, nullable=True) @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first() @property def document(self): - return db.session.query(Document).filter(Document.id == self.document_id).first() + return db.session.query(Document).where(Document.id == self.document_id).first() @property def segment(self): - return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first() class AppDatasetJoin(Base): @@ -840,10 +843,10 @@ class AppDatasetJoin(Base): db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def app(self): @@ -857,14 +860,14 @@ class DatasetQuery(Base): 
db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) - dataset_id = db.Column(StringUUID, nullable=False) - content = db.Column(db.Text, nullable=False) - source = db.Column(db.String(255), nullable=False) - source_app_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + dataset_id = mapped_column(StringUUID, nullable=False) + content = mapped_column(db.Text, nullable=False) + source = mapped_column(db.String(255), nullable=False) + source_app_id = mapped_column(StringUUID, nullable=True) + created_by_role = mapped_column(db.String, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) class DatasetKeywordTable(Base): @@ -874,10 +877,10 @@ class DatasetKeywordTable(Base): db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - dataset_id = db.Column(StringUUID, nullable=False, unique=True) - keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column( + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + dataset_id = mapped_column(StringUUID, nullable=False, unique=True) + keyword_table = mapped_column(db.Text, nullable=False) + data_source_type = mapped_column( db.String(255), nullable=False, server_default=db.text("'database'::character varying") ) @@ -920,14 +923,14 @@ class Embedding(Base): db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - model_name = db.Column( + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = mapped_column( db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") ) - hash = db.Column(db.String(64), nullable=False) - embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + hash = mapped_column(db.String(64), nullable=False) + embedding = mapped_column(db.LargeBinary, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -943,12 +946,12 @@ class DatasetCollectionBinding(Base): db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) - 
collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + provider_name = mapped_column(db.String(255), nullable=False) + model_name = mapped_column(db.String(255), nullable=False) + type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + collection_name = mapped_column(db.String(64), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): @@ -960,15 +963,15 @@ class TidbAuthBinding(Base): db.Index("tidb_auth_bindings_created_at_idx", "created_at"), db.Index("tidb_auth_bindings_status_idx", "status"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - cluster_id = db.Column(db.String(255), nullable=False) - cluster_name = db.Column(db.String(255), nullable=False) - active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = db.Column(db.String(255), nullable=False) - password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + cluster_id = mapped_column(db.String(255), nullable=False) + cluster_name = mapped_column(db.String(255), nullable=False) + active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) + account = mapped_column(db.String(255), nullable=False) + password = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): @@ -977,10 +980,10 @@ class Whitelist(Base): db.PrimaryKeyConstraint("id", name="whitelists_pkey"), db.Index("whitelists_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + category = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): @@ -992,12 +995,12 @@ class DatasetPermission(Base): db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) - dataset_id = db.Column(StringUUID, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - tenant_id = db.Column(StringUUID, nullable=False) - has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, 
server_default=db.text("uuid_generate_v4()"), primary_key=True) + dataset_id = mapped_column(StringUUID, nullable=False) + account_id = mapped_column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): @@ -1008,15 +1011,15 @@ class ExternalKnowledgeApis(Base): db.Index("external_knowledge_apis_name_idx", "name"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.String(255), nullable=False) - tenant_id = db.Column(StringUUID, nullable=False) - settings = db.Column(db.Text, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(db.String(255), nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + settings = mapped_column(db.Text, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1041,11 +1044,11 @@ class ExternalKnowledgeApis(Base): def dataset_bindings(self): external_knowledge_bindings = ( db.session.query(ExternalKnowledgeBindings) - .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) + .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) .all() ) dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] - datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all() + datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() dataset_bindings = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) @@ -1063,15 +1066,15 @@ class ExternalKnowledgeBindings(Base): db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - external_knowledge_api_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - external_knowledge_id = db.Column(db.Text, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + external_knowledge_api_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, 
nullable=False) + external_knowledge_id = mapped_column(db.Text, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): @@ -1083,12 +1086,12 @@ class DatasetAutoDisableLog(Base): db.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class RateLimitLog(Base): @@ -1099,11 +1102,11 @@ class RateLimitLog(Base): db.Index("rate_limit_log_operation_idx", "operation"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - subscription_plan = db.Column(db.String(255), nullable=False) - operation = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + subscription_plan = mapped_column(db.String(255), nullable=False) + operation = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class DatasetMetadata(Base): @@ -1114,15 +1117,15 @@ class DatasetMetadata(Base): db.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + type = mapped_column(db.String(255), nullable=False) + name = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_by = mapped_column(StringUUID, 
nullable=False) + updated_by = mapped_column(StringUUID, nullable=True) class DatasetMetadataBinding(Base): @@ -1135,10 +1138,10 @@ class DatasetMetadataBinding(Base): db.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - metadata_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = db.Column(StringUUID, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + metadata_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 93737043d5..a78a91ebd5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -40,8 +40,8 @@ class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + version = mapped_column(db.String(255), nullable=False) + setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -50,7 +50,6 @@ class AppMode(StrEnum): CHAT = "chat" ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" - CHANNEL = "channel" @classmethod def value_of(cls, value: str) -> "AppMode": @@ -75,31 +74,31 @@ class App(Base): __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255), nullable=False) - icon_type = db.Column(db.String(255), nullable=True) # image, emoji + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + name: Mapped[str] = mapped_column(db.String(255)) + description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) + mode: Mapped[str] = mapped_column(db.String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji icon = db.Column(db.String(255)) - icon_background = db.Column(db.String(255)) - app_model_config_id = db.Column(StringUUID, nullable=True) - workflow_id = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - enable_site = db.Column(db.Boolean, nullable=False) - enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - is_demo 
= db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - tracing = db.Column(db.Text, nullable=True) - max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + app_model_config_id = mapped_column(StringUUID, nullable=True) + workflow_id = mapped_column(StringUUID, nullable=True) + status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + enable_site: Mapped[bool] = mapped_column(db.Boolean) + enable_api: Mapped[bool] = mapped_column(db.Boolean) + api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) + api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) + is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + tracing = mapped_column(db.Text, nullable=True) + max_active_requests: Mapped[Optional[int]] + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def desc_or_prompt(self): @@ -114,13 +113,13 @@ class App(Base): @property def site(self): - site = db.session.query(Site).filter(Site.app_id == self.id).first() + site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property def app_model_config(self): if self.app_model_config_id: - return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return None @@ -129,7 +128,7 @@ class App(Base): if self.workflow_id: from .workflow import Workflow - return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() return None @@ -139,7 +138,7 @@ class App(Base): @property def tenant(self): - tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @property @@ -283,7 +282,7 @@ class App(Base): tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) - .filter( + .where( TagBinding.target_id == self.id, TagBinding.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id, @@ -297,7 +296,7 @@ class App(Base): @property def author_name(self): if self.created_by: - account = 
db.session.query(Account).filter(Account.id == self.created_by).first() + account = db.session.query(Account).where(Account.id == self.created_by).first() if account: return account.name @@ -308,38 +307,38 @@ class AppModelConfig(Base): __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - provider = db.Column(db.String(255), nullable=True) - model_id = db.Column(db.String(255), nullable=True) - configs = db.Column(db.JSON, nullable=True) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - opening_statement = db.Column(db.Text) - suggested_questions = db.Column(db.Text) - suggested_questions_after_answer = db.Column(db.Text) - speech_to_text = db.Column(db.Text) - text_to_speech = db.Column(db.Text) - more_like_this = db.Column(db.Text) - model = db.Column(db.Text) - user_input_form = db.Column(db.Text) - dataset_query_variable = db.Column(db.String(255)) - pre_prompt = db.Column(db.Text) - agent_mode = db.Column(db.Text) - sensitive_word_avoidance = db.Column(db.Text) - retriever_resource = db.Column(db.Text) - prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) - chat_prompt_config = db.Column(db.Text) - completion_prompt_config = db.Column(db.Text) - dataset_configs = db.Column(db.Text) - external_data_tools = db.Column(db.Text) - file_upload = db.Column(db.Text) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + provider = mapped_column(db.String(255), nullable=True) + model_id = mapped_column(db.String(255), nullable=True) + configs = mapped_column(db.JSON, nullable=True) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + opening_statement = mapped_column(db.Text) + suggested_questions = mapped_column(db.Text) + suggested_questions_after_answer = mapped_column(db.Text) + speech_to_text = mapped_column(db.Text) + text_to_speech = mapped_column(db.Text) + more_like_this = mapped_column(db.Text) + model = mapped_column(db.Text) + user_input_form = mapped_column(db.Text) + dataset_query_variable = mapped_column(db.String(255)) + pre_prompt = mapped_column(db.Text) + agent_mode = mapped_column(db.Text) + sensitive_word_avoidance = mapped_column(db.Text) + retriever_resource = mapped_column(db.Text) + prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + chat_prompt_config = mapped_column(db.Text) + completion_prompt_config = mapped_column(db.Text) + dataset_configs = mapped_column(db.Text) + external_data_tools = mapped_column(db.Text) + file_upload = mapped_column(db.Text) @property def app(self): - app = db.session.query(App).filter(App.id == self.app_id).first() + app = db.session.query(App).where(App.id == self.app_id).first() return app @property @@ -373,7 +372,7 @@ 
class AppModelConfig(Base): @property def annotation_reply_dict(self) -> dict: annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail @@ -562,23 +561,23 @@ class RecommendedApp(Base): db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - description = db.Column(db.JSON, nullable=False) - copyright = db.Column(db.String(255), nullable=False) - privacy_policy = db.Column(db.String(255), nullable=False) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + description = mapped_column(db.JSON, nullable=False) + copyright = mapped_column(db.String(255), nullable=False) + privacy_policy = mapped_column(db.String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = db.Column(db.String(255), nullable=False) - position = db.Column(db.Integer, nullable=False, default=0) - is_listed = db.Column(db.Boolean, nullable=False, default=True) - install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category = mapped_column(db.String(255), nullable=False) + position = mapped_column(db.Integer, nullable=False, default=0) + is_listed = mapped_column(db.Boolean, nullable=False, default=True) + install_count = mapped_column(db.Integer, nullable=False, default=0) + language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): - app = db.session.query(App).filter(App.id == self.app_id).first() + app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -591,34 +590,26 @@ class InstalledApp(Base): db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - app_owner_tenant_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False, default=0) - is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + app_owner_tenant_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False, default=0) + is_pinned = mapped_column(db.Boolean, nullable=False, 
server_default=db.text("false")) + last_used_at = mapped_column(db.DateTime, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): - app = db.session.query(App).filter(App.id == self.app_id).first() + app = db.session.query(App).where(App.id == self.app_id).first() return app @property def tenant(self): - tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant -class ConversationSource(StrEnum): - """This enumeration is designed for use with `Conversation.from_source`.""" - - # NOTE(QuantumGhost): The enumeration members may not cover all possible cases. - API = "api" - CONSOLE = "console" - - class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( @@ -627,42 +618,42 @@ class Conversation(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - app_model_config_id = db.Column(StringUUID, nullable=True) - model_provider = db.Column(db.String(255), nullable=True) - override_model_configs = db.Column(db.Text) - model_id = db.Column(db.String(255), nullable=True) + app_id = mapped_column(StringUUID, nullable=False) + app_model_config_id = mapped_column(StringUUID, nullable=True) + model_provider = mapped_column(db.String(255), nullable=True) + override_model_configs = mapped_column(db.Text) + model_id = mapped_column(db.String(255), nullable=True) mode: Mapped[str] = mapped_column(db.String(255)) - name = db.Column(db.String(255), nullable=False) - summary = db.Column(db.Text) + name = mapped_column(db.String(255), nullable=False) + summary = mapped_column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - introduction = db.Column(db.Text) - system_instruction = db.Column(db.Text) - system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - status = db.Column(db.String(255), nullable=False) + introduction = mapped_column(db.Text) + system_instruction = mapped_column(db.Text) + system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + status = mapped_column(db.String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = db.Column(db.String(255), nullable=True) + invoke_from = mapped_column(db.String(255), nullable=True) # ref: ConversationSource. 
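
Note on the pattern applied throughout these model hunks: attributes declared with the legacy db.Column(...) are being moved to SQLAlchemy 2.0-style declarative mapping, where an optional Mapped[...] annotation carries the Python type and mapped_column(...) carries the column options (and Optional[...] can stand in for nullable=True). The following is a minimal, self-contained sketch of that style; the Note model and its columns are illustrative only and are not part of the Dify schema.

    from datetime import datetime
    from typing import Optional

    from sqlalchemy import DateTime, String, create_engine, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Note(Base):
        # Illustrative model, not taken from this diff.
        __tablename__ = "notes"

        # Mapped[...] carries the Python type; mapped_column() carries column options.
        id: Mapped[int] = mapped_column(primary_key=True)
        title: Mapped[str] = mapped_column(String(255))            # non-Optional -> NOT NULL
        body: Mapped[Optional[str]] = mapped_column(String(2000))  # Optional -> nullable
        created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Note(title="hello"))
        session.commit()
        # 2.0-style statements use select(...).where(...); the legacy Query API also
        # keeps working, which is what the .filter() -> .where() hunks rely on.
        note = session.scalars(select(Note).where(Note.title == "hello")).one()
        print(note.id, note.created_at)

Because nullability can be inferred from the Optional[...] annotation, several explicit nullable=True/False arguments disappear in the hunks above once a Mapped[...] annotation is added.
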
- from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) - read_at = db.Column(db.DateTime) - read_account_id = db.Column(StringUUID) + from_source = mapped_column(db.String(255), nullable=False) + from_end_user_id = mapped_column(StringUUID) + from_account_id = mapped_column(StringUUID) + read_at = mapped_column(db.DateTime) + read_account_id = mapped_column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def inputs(self): @@ -723,7 +714,7 @@ class Conversation(Base): model_config["configs"] = override_model_configs else: app_model_config = ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() ) if app_model_config: model_config = app_model_config.to_dict() @@ -746,21 +737,21 @@ class Conversation(Base): @property def annotated(self): - return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0 + return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0 @property def annotation(self): - return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first() + return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first() @property def message_count(self): - return db.session.query(Message).filter(Message.conversation_id == self.id).count() + return db.session.query(Message).where(Message.conversation_id == self.id).count() @property def user_feedback_stats(self): like = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", MessageFeedback.rating == "like", @@ -770,7 +761,7 @@ class Conversation(Base): dislike = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", MessageFeedback.rating == "dislike", @@ -784,7 +775,7 @@ class Conversation(Base): def admin_feedback_stats(self): like = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", MessageFeedback.rating == "like", @@ -794,7 +785,7 @@ class Conversation(Base): dislike = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", MessageFeedback.rating == "dislike", @@ -806,7 +797,7 @@ class Conversation(Base): @property def status_count(self): - messages = 
db.session.query(Message).filter(Message.conversation_id == self.id).all() + messages = db.session.query(Message).where(Message.conversation_id == self.id).all() status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -833,19 +824,19 @@ class Conversation(Base): def first_message(self): return ( db.session.query(Message) - .filter(Message.conversation_id == self.id) + .where(Message.conversation_id == self.id) .order_by(Message.created_at.asc()) .first() ) @property def app(self): - return db.session.query(App).filter(App.id == self.app_id).first() + return db.session.query(App).where(App.id == self.app_id).first() @property def from_end_user_session_id(self): if self.from_end_user_id: - end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first() + end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first() if end_user: return end_user.session_id @@ -854,7 +845,7 @@ class Conversation(Base): @property def from_account_name(self): if self.from_account_id: - account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: return account.name @@ -905,36 +896,36 @@ class Message(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=True) - model_id = db.Column(db.String(255), nullable=True) - override_model_configs = db.Column(db.Text) - conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + model_provider = mapped_column(db.String(255), nullable=True) + model_id = mapped_column(db.String(255), nullable=True) + override_model_configs = mapped_column(db.Text) + conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - query: Mapped[str] = db.Column(db.Text, nullable=False) - message = db.Column(db.JSON, nullable=False) - message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer: Mapped[str] = db.Column(db.Text, nullable=False) - answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - parent_message_id = db.Column(StringUUID, nullable=True) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_price = db.Column(db.Numeric(10, 7)) - currency = db.Column(db.String(255), nullable=False) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - error = db.Column(db.Text) - message_metadata = db.Column(db.Text) - invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) - from_source = db.Column(db.String(255), nullable=False) - from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) - from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) + query: Mapped[str] = mapped_column(db.Text, nullable=False) + 
message = mapped_column(db.JSON, nullable=False) + message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) + message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column + answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + parent_message_id = mapped_column(StringUUID, nullable=True) + provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) + total_price = mapped_column(db.Numeric(10, 7)) + currency = mapped_column(db.String(255), nullable=False) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + error = mapped_column(db.Text) + message_metadata = mapped_column(db.Text) + invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + from_source = mapped_column(db.String(255), nullable=False) + from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) + from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - workflow_run_id = db.Column(StringUUID) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property def inputs(self): @@ -1049,7 +1040,7 @@ class Message(Base): def user_feedback(self): feedback = ( db.session.query(MessageFeedback) - .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") .first() ) return feedback @@ -1058,30 +1049,30 @@ class Message(Base): def admin_feedback(self): feedback = ( db.session.query(MessageFeedback) - .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") + .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") .first() ) return feedback @property def feedbacks(self): - feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all() + feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() return feedbacks @property def annotation(self): - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first() return annotation @property def annotation_hit_history(self): annotation_history = ( - db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() + db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first() ) if annotation_history: annotation = ( db.session.query(MessageAnnotation) - 
.filter(MessageAnnotation.id == annotation_history.annotation_id) + .where(MessageAnnotation.id == annotation_history.annotation_id) .first() ) return annotation @@ -1089,11 +1080,9 @@ class Message(Base): @property def app_model_config(self): - conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() + conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first() if conversation: - return ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() - ) + return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first() return None @@ -1109,7 +1098,7 @@ class Message(Base): def agent_thoughts(self): return ( db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == self.id) + .where(MessageAgentThought.message_id == self.id) .order_by(MessageAgentThought.position.asc()) .all() ) @@ -1122,8 +1111,8 @@ class Message(Base): def message_files(self): from factories import file_factory - message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() - current_app = db.session.query(App).filter(App.id == self.app_id).first() + message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() + current_app = db.session.query(App).where(App.id == self.app_id).first() if not current_app: raise ValueError(f"App {self.app_id} not found") @@ -1187,7 +1176,7 @@ class Message(Base): if self.workflow_run_id: from .workflow import WorkflowRun - return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first() return None @@ -1248,21 +1237,21 @@ class MessageFeedback(Base): db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) - message_id = db.Column(StringUUID, nullable=False) - rating = db.Column(db.String(255), nullable=False) - content = db.Column(db.Text) - from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + conversation_id = mapped_column(StringUUID, nullable=False) + message_id = mapped_column(StringUUID, nullable=False) + rating = mapped_column(db.String(255), nullable=False) + content = mapped_column(db.Text) + from_source = mapped_column(db.String(255), nullable=False) + from_end_user_id = mapped_column(StringUUID) + from_account_id = mapped_column(StringUUID) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): - account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + account = db.session.query(Account).where(Account.id == self.from_account_id).first() return account def to_dict(self): @@ -1310,16 
+1299,16 @@ class MessageFile(Base): self.created_by_role = created_by_role.value self.created_by = created_by - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - message_id: Mapped[str] = db.Column(StringUUID, nullable=False) - type: Mapped[str] = db.Column(db.String(255), nullable=False) - transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) - url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) - upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) - created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) - created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + type: Mapped[str] = mapped_column(db.String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) + url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(Base): @@ -1331,25 +1320,25 @@ class MessageAnnotation(Base): db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) - message_id = db.Column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id: Mapped[str] = mapped_column(StringUUID) + conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id")) + message_id: Mapped[Optional[str]] = mapped_column(StringUUID) question = db.Column(db.Text, nullable=True) - content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + content = mapped_column(db.Text, nullable=False) + hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + account_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): - account = db.session.query(Account).filter(Account.id == self.account_id).first() + account = db.session.query(Account).where(Account.id == self.account_id).first() return account @property def annotation_create_account(self): - account = db.session.query(Account).filter(Account.id == 
self.account_id).first() + account = db.session.query(Account).where(Account.id == self.account_id).first() return account @@ -1363,31 +1352,31 @@ class AppAnnotationHitHistory(Base): db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False) - source = db.Column(db.Text, nullable=False) - question = db.Column(db.Text, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - score = db.Column(Float, nullable=False, server_default=db.text("0")) - message_id = db.Column(StringUUID, nullable=False) - annotation_question = db.Column(db.Text, nullable=False) - annotation_content = db.Column(db.Text, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + source = mapped_column(db.Text, nullable=False) + question = mapped_column(db.Text, nullable=False) + account_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + score = mapped_column(Float, nullable=False, server_default=db.text("0")) + message_id = mapped_column(StringUUID, nullable=False) + annotation_question = mapped_column(db.Text, nullable=False) + annotation_content = mapped_column(db.Text, nullable=False) @property def account(self): account = ( db.session.query(Account) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) - .filter(MessageAnnotation.id == self.annotation_id) + .where(MessageAnnotation.id == self.annotation_id) .first() ) return account @property def annotation_create_account(self): - account = db.session.query(Account).filter(Account.id == self.account_id).first() + account = db.session.query(Account).where(Account.id == self.account_id).first() return account @@ -1398,14 +1387,14 @@ class AppAnnotationSetting(Base): db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) - collection_binding_id = db.Column(StringUUID, nullable=False) - created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0")) + collection_binding_id = mapped_column(StringUUID, nullable=False) + created_user_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_user_id = mapped_column(StringUUID, nullable=False) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def collection_binding_detail(self): @@ -1413,7 +1402,7 @@ class AppAnnotationSetting(Base): 
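
These property hunks swap Query.filter(...) for Query.where(...) without changing the query logic: on the legacy Query API, where() has been available as a synonym for filter() since SQLAlchemy 1.4, so the emitted SQL is the same. A small standalone check of that equivalence, using an illustrative Item model that is not from this codebase:

    from sqlalchemy import Integer, String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Item(Base):
        # Illustrative model for the filter()/where() comparison only.
        __tablename__ = "items"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        name: Mapped[str] = mapped_column(String(50))


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Item(name="a"))
        session.commit()
        q1 = session.query(Item).filter(Item.name == "a")
        q2 = session.query(Item).where(Item.name == "a")
        # Both compile to the same SELECT ... WHERE items.name = :name_1
        assert str(q1) == str(q2)
        assert q1.count() == q2.count() == 1
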
collection_binding_detail = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .where(DatasetCollectionBinding.id == self.collection_binding_id) .first() ) return collection_binding_detail @@ -1426,14 +1415,14 @@ class OperationLog(Base): db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - action = db.Column(db.String(255), nullable=False) - content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + account_id = mapped_column(StringUUID, nullable=False) + action = mapped_column(db.String(255), nullable=False) + content = mapped_column(db.JSON) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_ip = mapped_column(db.String(255), nullable=False) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(Base, UserMixin): @@ -1444,16 +1433,49 @@ class EndUser(Base, UserMixin): db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=True) - type = db.Column(db.String(255), nullable=False) - external_user_id = db.Column(db.String(255), nullable=True) - name = db.Column(db.String(255)) - is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=True) + type = mapped_column(db.String(255), nullable=False) + external_user_id = mapped_column(db.String(255), nullable=True) + name = mapped_column(db.String(255)) + is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) session_id: Mapped[str] = mapped_column() - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class AppMCPServer(Base): + __tablename__ = "app_mcp_servers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), + db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), + ) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(db.String(255), nullable=False) + 
server_code = mapped_column(db.String(255), nullable=False) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + parameters = mapped_column(db.Text, nullable=False) + + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @staticmethod + def generate_server_code(n): + while True: + result = generate_string(n) + while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: + result = generate_string(n) + + return result + + @property + def parameters_dict(self) -> dict[str, Any]: + return cast(dict[str, Any], json.loads(self.parameters)) class Site(Base): @@ -1464,30 +1486,30 @@ class Site(Base): db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - title = db.Column(db.String(255), nullable=False) - icon_type = db.Column(db.String(255), nullable=True) - icon = db.Column(db.String(255)) - icon_background = db.Column(db.String(255)) - description = db.Column(db.Text) - default_language = db.Column(db.String(255), nullable=False) - chat_color_theme = db.Column(db.String(255)) - chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = db.Column(db.String(255)) - privacy_policy = db.Column(db.String(255)) - show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + title = mapped_column(db.String(255), nullable=False) + icon_type = mapped_column(db.String(255), nullable=True) + icon = mapped_column(db.String(255)) + icon_background = mapped_column(db.String(255)) + description = mapped_column(db.Text) + default_language = mapped_column(db.String(255), nullable=False) + chat_color_theme = mapped_column(db.String(255)) + chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + copyright = mapped_column(db.String(255)) + privacy_policy = mapped_column(db.String(255)) + show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - customize_domain = db.Column(db.String(255)) - customize_token_strategy = db.Column(db.String(255), nullable=False) - prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = db.Column(db.String(255)) + customize_domain = mapped_column(db.String(255)) + customize_token_strategy = mapped_column(db.String(255), nullable=False) + prompt_public = mapped_column(db.Boolean, nullable=False, 
server_default=db.text("false")) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + code = mapped_column(db.String(255)) @property def custom_disclaimer(self): @@ -1503,7 +1525,7 @@ class Site(Base): def generate_code(n): while True: result = generate_string(n) - while db.session.query(Site).filter(Site.code == result).count() > 0: + while db.session.query(Site).where(Site.code == result).count() > 0: result = generate_string(n) return result @@ -1522,19 +1544,19 @@ class ApiToken(Base): db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=True) - tenant_id = db.Column(StringUUID, nullable=True) - type = db.Column(db.String(16), nullable=False) - token = db.Column(db.String(255), nullable=False) - last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=True) + tenant_id = mapped_column(StringUUID, nullable=True) + type = mapped_column(db.String(16), nullable=False) + token = mapped_column(db.String(255), nullable=False) + last_used_at = mapped_column(db.DateTime, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): while True: result = prefix + generate_string(n) - if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: + if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0: continue return result @@ -1546,23 +1568,23 @@ class UploadFile(Base): db.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) - key: Mapped[str] = db.Column(db.String(255), nullable=False) - name: Mapped[str] = db.Column(db.String(255), nullable=False) - size: Mapped[int] = db.Column(db.Integer, nullable=False) - extension: Mapped[str] = db.Column(db.String(255), nullable=False) - mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) - created_by_role: Mapped[str] = db.Column( + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) + key: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + size: Mapped[int] = mapped_column(db.Integer, nullable=False) + extension: Mapped[str] = mapped_column(db.String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + created_by_role: Mapped[str] = mapped_column( db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) - created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - 
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) - hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) + hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1608,14 +1630,14 @@ class ApiRequest(Base): db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - api_token_id = db.Column(StringUUID, nullable=False) - path = db.Column(db.String(255), nullable=False) - request = db.Column(db.Text, nullable=True) - response = db.Column(db.Text, nullable=True) - ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + api_token_id = mapped_column(StringUUID, nullable=False) + path = mapped_column(db.String(255), nullable=False) + request = mapped_column(db.Text, nullable=True) + response = mapped_column(db.Text, nullable=True) + ip = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(Base): @@ -1625,12 +1647,12 @@ class MessageChain(Base): db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - input = db.Column(db.Text, nullable=True) - output = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = mapped_column(StringUUID, nullable=False) + type = mapped_column(db.String(255), nullable=False) + input = mapped_column(db.Text, nullable=True) + output = mapped_column(db.Text, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) class MessageAgentThought(Base): @@ -1641,34 +1663,34 @@ class MessageAgentThought(Base): db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - message_chain_id = db.Column(StringUUID, nullable=True) - position = db.Column(db.Integer, nullable=False) - thought = db.Column(db.Text, nullable=True) - tool = db.Column(db.Text, nullable=True) 
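
The generate_server_code, generate_code, and generate_api_key helpers in the hunks above all follow the same retry-until-unique pattern: generate a random string, check the table for a collision, and retry on a hit. A standalone sketch of that pattern is shown below; the Token model and the local generate_string stand-in are illustrative assumptions, not the helpers used in this codebase.

    import secrets
    import string

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Token(Base):
        # Illustrative table holding previously issued codes.
        __tablename__ = "tokens"
        value: Mapped[str] = mapped_column(String(64), primary_key=True)


    def generate_string(n: int) -> str:
        # Stand-in for the random-string helper used by the models above.
        return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(n))


    def generate_unique_token(session: Session, prefix: str, n: int) -> str:
        # Retry until the candidate is not already present in the table.
        while True:
            candidate = prefix + generate_string(n)
            if session.query(Token).where(Token.value == candidate).count() > 0:
                continue
            return candidate


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        print(generate_unique_token(session, "app-", 24))

As in the originals, the lookup only reduces the chance of a collision at generation time; the unique constraint on the column (for example unique_app_mcp_server_server_code above) is what actually enforces uniqueness.
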
- tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_input = db.Column(db.Text, nullable=True) - observation = db.Column(db.Text, nullable=True) - # plugin_id = db.Column(StringUUID, nullable=True) ## for future design - tool_process_data = db.Column(db.Text, nullable=True) - message = db.Column(db.Text, nullable=True) - message_token = db.Column(db.Integer, nullable=True) - message_unit_price = db.Column(db.Numeric, nullable=True) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - message_files = db.Column(db.Text, nullable=True) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = mapped_column(StringUUID, nullable=False) + message_chain_id = mapped_column(StringUUID, nullable=True) + position = mapped_column(db.Integer, nullable=False) + thought = mapped_column(db.Text, nullable=True) + tool = mapped_column(db.Text, nullable=True) + tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_input = mapped_column(db.Text, nullable=True) + observation = mapped_column(db.Text, nullable=True) + # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design + tool_process_data = mapped_column(db.Text, nullable=True) + message = mapped_column(db.Text, nullable=True) + message_token = mapped_column(db.Integer, nullable=True) + message_unit_price = mapped_column(db.Numeric, nullable=True) + message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + message_files = mapped_column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) - answer_token = db.Column(db.Integer, nullable=True) - answer_unit_price = db.Column(db.Numeric, nullable=True) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = db.Column(db.Integer, nullable=True) - total_price = db.Column(db.Numeric, nullable=True) - currency = db.Column(db.String, nullable=True) - latency = db.Column(db.Float, nullable=True) - created_by_role = db.Column(db.String, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + answer_token = mapped_column(db.Integer, nullable=True) + answer_unit_price = mapped_column(db.Numeric, nullable=True) + answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + tokens = mapped_column(db.Integer, nullable=True) + total_price = mapped_column(db.Numeric, nullable=True) + currency = mapped_column(db.String, nullable=True) + latency = mapped_column(db.Float, nullable=True) + created_by_role = mapped_column(db.String, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def files(self) -> list: @@ -1754,24 +1776,24 @@ class DatasetRetrieverResource(Base): db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False) - 
dataset_id = db.Column(StringUUID, nullable=False) - dataset_name = db.Column(db.Text, nullable=False) - document_id = db.Column(StringUUID, nullable=True) - document_name = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.Text, nullable=True) - segment_id = db.Column(StringUUID, nullable=True) - score = db.Column(db.Float, nullable=True) - content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=True) - word_count = db.Column(db.Integer, nullable=True) - segment_position = db.Column(db.Integer, nullable=True) - index_node_hash = db.Column(db.Text, nullable=True) - retriever_from = db.Column(db.Text, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + dataset_name = mapped_column(db.Text, nullable=False) + document_id = mapped_column(StringUUID, nullable=True) + document_name = mapped_column(db.Text, nullable=False) + data_source_type = mapped_column(db.Text, nullable=True) + segment_id = mapped_column(StringUUID, nullable=True) + score = mapped_column(db.Float, nullable=True) + content = mapped_column(db.Text, nullable=False) + hit_count = mapped_column(db.Integer, nullable=True) + word_count = mapped_column(db.Integer, nullable=True) + segment_position = mapped_column(db.Integer, nullable=True) + index_node_hash = mapped_column(db.Text, nullable=True) + retriever_from = mapped_column(db.Text, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) class Tag(Base): @@ -1784,12 +1806,12 @@ class Tag(Base): TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - type = db.Column(db.String(16), nullable=False) - name = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + type = mapped_column(db.String(16), nullable=False) + name = mapped_column(db.String(255), nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(Base): @@ -1800,12 +1822,12 @@ class TagBinding(Base): db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - tag_id = db.Column(StringUUID, nullable=True) - target_id = db.Column(StringUUID, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + tag_id = mapped_column(StringUUID, nullable=True) + target_id = mapped_column(StringUUID, nullable=True) + created_by = 
mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(Base): @@ -1815,15 +1837,15 @@ class TraceAppConfig(Base): db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - tracing_provider = db.Column(db.String(255), nullable=True) - tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column( + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + tracing_provider = mapped_column(db.String(255), nullable=True) + tracing_config = mapped_column(db.JSON, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/source.py b/api/models/source.py index f6e0900ae6..100e0d96ef 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -2,6 +2,7 @@ import json from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import mapped_column from models.base import Base @@ -17,14 +18,14 @@ class DataSourceOauthBinding(Base): db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - access_token = db.Column(db.String(255), nullable=False) - provider = db.Column(db.String(255), nullable=False) - source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + access_token = mapped_column(db.String(255), nullable=False) + provider = mapped_column(db.String(255), nullable=False) + source_info = mapped_column(JSONB, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -35,14 +36,14 @@ class DataSourceApiKeyAuthBinding(Base): db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - category = db.Column(db.String(255), nullable=False) - provider = db.Column(db.String(255), nullable=False) - credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, 
nullable=False, server_default=func.current_timestamp()) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + category = mapped_column(db.String(255), nullable=False) + provider = mapped_column(db.String(255), nullable=False) + credentials = mapped_column(db.Text, nullable=True) # JSON + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index d853c1dd9a..3e5ebd2099 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,7 +1,10 @@ -from datetime import UTC, datetime +from datetime import datetime +from typing import Optional from celery import states # type: ignore +from sqlalchemy.orm import Mapped, mapped_column +from libs.datetime_utils import naive_utc_now from models.base import Base from .engine import db @@ -12,23 +15,23 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = db.Column(db.String(155), unique=True) - status = db.Column(db.String(50), default=states.PENDING) - result = db.Column(db.PickleType, nullable=True) - date_done = db.Column( + id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) + task_id = mapped_column(db.String(155), unique=True) + status = mapped_column(db.String(50), default=states.PENDING) + result = mapped_column(db.PickleType, nullable=True) + date_done = mapped_column( db.DateTime, - default=lambda: datetime.now(UTC).replace(tzinfo=None), - onupdate=lambda: datetime.now(UTC).replace(tzinfo=None), + default=lambda: naive_utc_now(), + onupdate=lambda: naive_utc_now(), nullable=True, ) - traceback = db.Column(db.Text, nullable=True) - name = db.Column(db.String(155), nullable=True) - args = db.Column(db.LargeBinary, nullable=True) - kwargs = db.Column(db.LargeBinary, nullable=True) - worker = db.Column(db.String(155), nullable=True) - retries = db.Column(db.Integer, nullable=True) - queue = db.Column(db.String(155), nullable=True) + traceback = mapped_column(db.Text, nullable=True) + name = mapped_column(db.String(155), nullable=True) + args = mapped_column(db.LargeBinary, nullable=True) + kwargs = mapped_column(db.LargeBinary, nullable=True) + worker = mapped_column(db.String(155), nullable=True) + retries = mapped_column(db.Integer, nullable=True) + queue = mapped_column(db.String(155), nullable=True) class CeleryTaskSet(Base): @@ -36,7 +39,9 @@ class CeleryTaskSet(Base): __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) - taskset_id = db.Column(db.String(155), unique=True) - result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True) + id: Mapped[int] = mapped_column( + db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True + ) + taskset_id = mapped_column(db.String(155), unique=True) + result = mapped_column(db.PickleType, nullable=True) + date_done: Mapped[Optional[datetime]] = 
mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 03fbc3acb1..68f4211e59 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,12 +1,16 @@ import json from datetime import datetime from typing import Any, cast +from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column +from core.file import helpers as file_helpers +from core.helper import encrypter +from core.mcp.types import Tool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration @@ -17,6 +21,43 @@ from .model import Account, App, Tenant from .types import StringUUID +# system level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthSystemClient(Base): + __tablename__ = "tool_oauth_system_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + +# tenant level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthTenantClient(Base): + __tablename__ = "tool_oauth_tenant_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) + + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. 
@@ -25,12 +66,14 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - # one tenant can only have one tool provider with the same name - db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), + db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column( + db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -45,6 +88,12 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) + is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + # credential type, e.g., "api-key", "oauth2" + credential_type: Mapped[str] = mapped_column( + db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + ) + expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) @property def credentials(self) -> dict: @@ -62,26 +111,26 @@ class ApiToolProvider(Base): db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = db.Column(db.String(255), nullable=False) + name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon - icon = db.Column(db.String(255), nullable=False) + icon = mapped_column(db.String(255), nullable=False) # original schema - schema = db.Column(db.Text, nullable=False) - schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False) + schema = mapped_column(db.Text, nullable=False) + schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) # who created this tool - user_id = db.Column(StringUUID, nullable=False) + user_id = mapped_column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) # description of the provider - description = db.Column(db.Text, nullable=False) + description = mapped_column(db.Text, nullable=False) # json format tools - tools_str = db.Column(db.Text, nullable=False) + tools_str = mapped_column(db.Text, nullable=False) # json format credentials - credentials_str = db.Column(db.Text, nullable=False) + credentials_str = mapped_column(db.Text, nullable=False) # privacy policy - privacy_policy = db.Column(db.String(255), nullable=True) + privacy_policy = mapped_column(db.String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") @@ -104,11 +153,11 @@ class ApiToolProvider(Base): def user(self) -> Account | None: if not self.user_id: return None - return db.session.query(Account).filter(Account.id == self.user_id).first() + return db.session.query(Account).where(Account.id == self.user_id).first() @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return 
db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() class ToolLabelBinding(Base): @@ -174,11 +223,11 @@ class WorkflowToolProvider(Base): @property def user(self) -> Account | None: - return db.session.query(Account).filter(Account.id == self.user_id).first() + return db.session.query(Account).where(Account.id == self.user_id).first() @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: @@ -186,7 +235,110 @@ class WorkflowToolProvider(Base): @property def app(self) -> App | None: - return db.session.query(App).filter(App.id == self.app_id).first() + return db.session.query(App).where(App.id == self.app_id).first() + + +class MCPToolProvider(Base): + """ + The table stores the mcp providers. + """ + + __tablename__ = "tool_mcp_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), + db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), + db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), + db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # name of the mcp provider + name: Mapped[str] = mapped_column(db.String(40), nullable=False) + # server identifier of the mcp provider + server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) + # encrypted url of the mcp provider + server_url: Mapped[str] = mapped_column(db.Text, nullable=False) + # hash of server_url for uniqueness check + server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) + # icon of the mcp provider + icon: Mapped[str] = mapped_column(db.String(255), nullable=True) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # who created this tool + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # encrypted credentials + encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + # authed + authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False) + # tools + tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + + def load_user(self) -> Account | None: + return db.session.query(Account).where(Account.id == self.user_id).first() + + @property + def tenant(self) -> Tenant | None: + return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + + @property + def credentials(self) -> dict: + try: + return cast(dict, json.loads(self.encrypted_credentials)) or {} + except Exception: + return {} + + @property + def mcp_tools(self) -> list[Tool]: + return [Tool(**tool) for tool in json.loads(self.tools)] + + @property + def provider_icon(self) -> dict[str, str] | str: + try: + return cast(dict[str, str], json.loads(self.icon)) + except json.JSONDecodeError: + return file_helpers.get_signed_file_url(self.icon) + + @property + def decrypted_server_url(self) -> str: + return cast(str, encrypter.decrypt_token(self.tenant_id, 
self.server_url)) + + @property + def masked_server_url(self) -> str: + def mask_url(url: str, mask_char: str = "*") -> str: + """ + mask the url to a simple string + """ + parsed = urlparse(url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + if parsed.path and parsed.path != "/": + return f"{base_url}/{mask_char * 6}" + else: + return base_url + + return mask_url(self.decrypted_server_url) + + @property + def decrypted_credentials(self) -> dict: + from core.helper.provider_cache import NoOpProviderCredentialCache + from core.tools.mcp_tool.provider import MCPToolProviderController + from core.tools.utils.encryption import create_provider_encrypter + + provider_controller = MCPToolProviderController._from_db(self) + + encrypter, _ = create_provider_encrypter( + tenant_id=self.tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + cache=NoOpProviderCredentialCache(), + ) + + return encrypter.decrypt(self.credentials) # type: ignore class ToolModelInvoke(Base): @@ -197,33 +349,33 @@ class ToolModelInvoke(Base): __tablename__ = "tool_model_invokes" __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool - user_id = db.Column(StringUUID, nullable=False) + user_id = mapped_column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = db.Column(db.String(255), nullable=False) + provider = mapped_column(db.String(255), nullable=False) # type - tool_type = db.Column(db.String(40), nullable=False) + tool_type = mapped_column(db.String(40), nullable=False) # tool name - tool_name = db.Column(db.String(40), nullable=False) + tool_name = mapped_column(db.String(128), nullable=False) # invoke parameters - model_parameters = db.Column(db.Text, nullable=False) + model_parameters = mapped_column(db.Text, nullable=False) # prompt messages - prompt_messages = db.Column(db.Text, nullable=False) + prompt_messages = mapped_column(db.Text, nullable=False) # invoke response - model_response = db.Column(db.Text, nullable=False) + model_response = mapped_column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_price = db.Column(db.Numeric(10, 7)) - currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) + 
total_price = mapped_column(db.Numeric(10, 7)) + currency = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -240,18 +392,18 @@ class ToolConversationVariables(Base): db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id - user_id = db.Column(StringUUID, nullable=False) + user_id = mapped_column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) # conversation id - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id = mapped_column(StringUUID, nullable=False) # variables pool - variables_str = db.Column(db.Text, nullable=False) + variables_str = mapped_column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -300,26 +452,26 @@ class DeprecatedPublishedAppTool(Base): db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the app - app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) + app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False) - user_id: Mapped[str] = db.Column(StringUUID, nullable=False) + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who published this tool - description = db.Column(db.Text, nullable=False) + description = mapped_column(db.Text, nullable=False) # llm_description of the tool, for LLM - llm_description = db.Column(db.Text, nullable=False) + llm_description = mapped_column(db.Text, nullable=False) # query description, query will be seem as a parameter of the tool, # to describe this parameter to llm, we need this field - query_description = db.Column(db.Text, nullable=False) + query_description = mapped_column(db.Text, nullable=False) # query name, the name of the query parameter - query_name = db.Column(db.String(40), nullable=False) + query_name = mapped_column(db.String(40), nullable=False) # name of the tool provider - tool_name = db.Column(db.String(40), nullable=False) + tool_name = mapped_column(db.String(40), nullable=False) # author - author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(db.String(40), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git 
a/api/models/web.py b/api/models/web.py index fe2f0c47f8..ce00f4010f 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -15,16 +15,18 @@ class SavedMessage(Base): db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - message_id = db.Column(StringUUID, nullable=False) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + message_id = mapped_column(StringUUID, nullable=False) + created_by_role = mapped_column( + db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + ) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): - return db.session.query(Message).filter(Message.id == self.message_id).first() + return db.session.query(Message).where(Message.id == self.message_id).first() class PinnedConversation(Base): @@ -34,9 +36,11 @@ class PinnedConversation(Base): db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by_role = mapped_column( + db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + ) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 7f01135af3..79d96e42dd 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Mapping, Sequence -from datetime import UTC, datetime +from datetime import datetime from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 @@ -12,9 +12,12 @@ from sqlalchemy import orm from core.file.constants import maybe_file_object from core.file.models import File from core.variables import utils as variable_utils +from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type +from libs.datetime_utils import naive_utc_now +from libs.helper import extract_tenant_id from ._workflow_exc import NodeNotFoundError, WorkflowDataError @@ -136,7 +139,7 @@ class Workflow(Base): updated_at: 
Mapped[datetime] = mapped_column( db.DateTime, nullable=False, - default=datetime.now(UTC).replace(tzinfo=None), + default=naive_utc_now(), server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( @@ -177,7 +180,7 @@ class Workflow(Base): workflow.conversation_variables = conversation_variables or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment - workflow.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow.created_at = naive_utc_now() workflow.updated_at = workflow.created_at return workflow @@ -340,24 +343,19 @@ class Workflow(Base): return ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) + .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) .count() > 0 ) @property - def environment_variables(self) -> Sequence[Variable]: + def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: self._environment_variables = "{}" # Get tenant_id from current_user (Account or EndUser) - if isinstance(current_user, Account): - # Account user - tenant_id = current_user.current_tenant_id - else: - # EndUser - tenant_id = current_user.tenant_id + tenant_id = extract_tenant_id(current_user) if not tenant_id: return [] @@ -371,11 +369,15 @@ class Workflow(Base): def decrypt_func(var): if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - else: + elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): return var + else: + raise AssertionError("this statement should be unreachable.") - results = list(map(decrypt_func, results)) - return results + decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( + map(decrypt_func, results) + ) + return decrypted_results @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): @@ -384,12 +386,7 @@ class Workflow(Base): return # Get tenant_id from current_user (Account or EndUser) - if isinstance(current_user, Account): - # Account user - tenant_id = current_user.current_tenant_id - else: - # EndUser - tenant_id = current_user.tenant_id + tenant_id = extract_tenant_id(current_user) if not tenant_id: self._environment_variables = "{}" @@ -552,12 +549,12 @@ class WorkflowRun(Base): from models.model import Message return ( - db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() ) @property def workflow(self): - return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() def to_dict(self): return { @@ -911,7 +908,7 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) def _naive_utc_datetime(): - return datetime.now(UTC).replace(tzinfo=None) + return naive_utc_now() class WorkflowDraftVariable(Base): diff --git a/api/pyproject.toml b/api/pyproject.toml index d33806d0ae..7ec8a91198 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.5.1" +version = "1.7.0" requires-python = 
">=3.11,<3.13" dependencies = [ @@ -82,6 +82,8 @@ dependencies = [ "weave~=0.51.0", "yarl~=1.18.3", "webvtt-py~=0.5.1", + "sseclient-py>=1.8.0", + "httpx-sse>=0.4.0", "sendgrid~=6.12.3", ] # Before adding new dependency, consider place it in @@ -106,7 +108,7 @@ dev = [ "faker~=32.1.0", "lxml-stubs~=0.5.1", "mypy~=1.16.0", - "ruff~=0.11.5", + "ruff~=0.12.3", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", "pytest-cov~=4.1.0", diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..00a2d1f87d --- /dev/null +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -0,0 +1,197 @@ +""" +Service-layer repository protocol for WorkflowNodeExecutionModel operations. + +This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel +that abstracts database queries currently done directly in service classes. This repository is +specifically designed for service-layer needs and is separate from the core domain repository. + +The service repository handles operations that require access to database-specific fields like +tenant_id, app_id, triggered_from, etc., which are not part of the core domain model. +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models.workflow import WorkflowNodeExecutionModel + + +class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): + """ + Protocol for service-layer operations on WorkflowNodeExecutionModel. + + This repository provides database access patterns specifically needed by service classes, + handling queries that involve database-specific fields and multi-tenancy concerns. + + Key responsibilities: + - Manages database operations for workflow node executions + - Handles multi-tenant data isolation + - Provides batch processing capabilities + - Supports execution lifecycle management + + Implementation notes: + - Returns database models directly (WorkflowNodeExecutionModel) + - Handles tenant/app filtering automatically + - Provides service-specific query patterns + - Focuses on database operations without domain logic + - Supports cleanup and maintenance operations + """ + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method finds the latest execution of a specific node within a workflow, + ordered by creation time. Used primarily for debugging and inspection purposes. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + ... + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method retrieves all node executions that belong to a specific workflow run, + ordered by index in descending order for proper trace visualization. 
+ + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + ... + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method retrieves a specific execution by its unique identifier. + Tenant filtering is optional for cases where the execution ID is globally unique. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + ... + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + This method is used for cleanup operations to remove expired executions + in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + ... + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. + + This method is used when removing an app and all its related data. + Executions are deleted in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + ... + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + This method retrieves expired executions without deleting them, + allowing the caller to backup the data before deletion. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + ... + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. + + This method deletes specific executions by their IDs, + typically used after backing up the data. + + This method does not perform tenant isolation checks. The caller is responsible for ensuring proper + data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests), + additional tenant validation should be implemented to prevent unauthorized access. + + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + ... 
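The protocol above is designed so that expired data can be backed up before it is removed (`get_expired_executions_batch` followed by `delete_executions_by_ids`). The following is a hedged sketch, not part of this PR, of how a cleanup task might combine those two calls; it instantiates the concrete SQLAlchemy implementation directly, assumes the usual `db` engine from `extensions.ext_database`, and `backup_to_storage` is a hypothetical helper standing in for whatever OSS backup step the caller uses:

```python
from datetime import timedelta

from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)


def purge_expired_node_executions(tenant_id: str, retention_days: int = 30) -> int:
    """Back up and delete node executions older than the retention window."""
    session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
    repo = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=session_maker)
    cutoff = naive_utc_now() - timedelta(days=retention_days)

    deleted = 0
    while True:
        # fetch a batch of expired rows without deleting them yet
        batch = repo.get_expired_executions_batch(tenant_id=tenant_id, before_date=cutoff, batch_size=1000)
        if not batch:
            break
        backup_to_storage(batch)  # hypothetical: persist the rows before deletion
        deleted += repo.delete_executions_by_ids([execution.id for execution in batch])
    return deleted
```

Because `delete_executions_by_ids` performs no tenant check itself, the IDs here come only from the tenant-scoped batch query, keeping the isolation responsibility with the caller as the protocol's docstring requires.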
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py new file mode 100644 index 0000000000..59e7baeb79 --- /dev/null +++ b/api/repositories/api_workflow_run_repository.py @@ -0,0 +1,181 @@ +""" +API WorkflowRun Repository Protocol + +This module defines the protocol for service-layer WorkflowRun operations. +The repository provides an abstraction layer for WorkflowRun database operations +used by service classes, separating service-layer concerns from core domain logic. + +Key Features: +- Paginated workflow run queries with filtering +- Bulk deletion operations with OSS backup support +- Multi-tenant data isolation +- Expired record cleanup with data retention +- Service-layer specific query patterns + +Usage: + This protocol should be used by service classes that need to perform + WorkflowRun database operations. It provides a clean interface that + hides implementation details and supports dependency injection. + +Example: + ```python + from repositories.dify_api_repository_factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + # Get paginated workflow runs + runs = repo.get_paginated_workflow_runs( + tenant_id="tenant-123", + app_id="app-456", + triggered_from="debugging", + limit=20 + ) + ``` +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun + + +class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): + """ + Protocol for service-layer WorkflowRun repository operations. + + This protocol defines the interface for WorkflowRun database operations + that are specific to service-layer needs, including pagination, filtering, + and bulk operations with data backup support. + """ + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Retrieves workflow runs for a specific app and trigger source with + cursor-based pagination support. Used primarily for debugging and + workflow run listing in the UI. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source (e.g., "debugging", "app-run") + limit: Maximum number of records to return (default: 20) + last_id: Cursor for pagination - ID of the last record from previous page + + Returns: + InfiniteScrollPagination object containing: + - data: List of WorkflowRun objects + - limit: Applied limit + - has_more: Boolean indicating if more records exist + + Raises: + ValueError: If last_id is provided but the corresponding record doesn't exist + """ + ... + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID. + + Retrieves a single workflow run with tenant and app isolation. + Used for workflow run detail views and execution tracking. 
+ + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + run_id: Workflow run identifier + + Returns: + WorkflowRun object if found, None otherwise + """ + ... + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup. + + Retrieves workflow runs created before the specified date for + cleanup operations. Used by scheduled tasks to remove old data + while maintaining data retention policies. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + before_date: Only return runs created before this date + batch_size: Maximum number of records to return + + Returns: + Sequence of WorkflowRun objects to be processed for cleanup + """ + ... + + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs. + + Performs bulk deletion of workflow runs by ID. This method should + be used after backing up the data to OSS storage for retention. + + Args: + run_ids: Sequence of workflow run IDs to delete + + Returns: + Number of records actually deleted + + Note: + This method performs hard deletion. Ensure data is backed up + to OSS storage before calling this method for compliance with + data retention policies. + """ + ... + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app. + + Performs bulk deletion of all workflow runs associated with an app. + Used during app cleanup operations. Processes records in batches + to avoid memory issues and long-running transactions. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + batch_size: Number of records to process in each batch + + Returns: + Total number of records deleted across all batches + + Note: + This method performs hard deletion without backup. Use with caution + and ensure proper data retention policies are followed. + """ + ... diff --git a/api/repositories/factory.py b/api/repositories/factory.py new file mode 100644 index 0000000000..0a0adbf2c2 --- /dev/null +++ b/api/repositories/factory.py @@ -0,0 +1,103 @@ +""" +DifyAPI Repository Factory for creating repository instances. + +This factory is specifically designed for DifyAPI repositories that handle +service-layer operations with dependency injection patterns. +""" + +import logging + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): + """ + Factory for creating DifyAPI repository instances based on configuration. + + This factory handles the creation of repositories that are specifically designed + for service-layer operations and use dependency injection with sessionmaker + for better testability and separation of concerns. + """ + + @classmethod + def create_api_workflow_node_execution_repository( + cls, session_maker: sessionmaker + ) -> DifyAPIWorkflowNodeExecutionRepository: + """ + Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration. 
+ + This repository is designed for service-layer operations and uses dependency injection + with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes, handling queries + that involve database-specific fields and multi-tenancy concerns. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. + + Returns: + Configured DifyAPIWorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e + + @classmethod + def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository: + """ + Create an APIWorkflowRunRepository instance based on configuration. + + This repository is designed for service-layer WorkflowRun operations and uses dependency + injection with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes for workflow run management, + including pagination, filtering, and bulk operations. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. + + Returns: + Configured APIWorkflowRunRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY + logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create APIWorkflowRunRepository") + raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..e6a23ddf9f --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,290 @@ +""" +SQLAlchemy implementation of WorkflowNodeExecutionServiceRepository. + +This module provides a concrete implementation of the service repository protocol +using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. 
+""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional + +from sqlalchemy import delete, desc, select +from sqlalchemy.orm import Session, sessionmaker + +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + + +class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): + """ + SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + + This repository provides service-layer database operations for WorkflowNodeExecutionModel + using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository + protocol with the following features: + + - Multi-tenancy data isolation through tenant_id filtering + - Direct database model operations without domain conversion + - Batch processing for efficient large-scale operations + - Optimized query patterns for common access patterns + - Dependency injection for better testability and maintainability + - Session management and transaction handling with proper cleanup + - Maintenance operations for data lifecycle management + - Thread-safe database operations using session-per-request pattern + """ + + def __init__(self, session_maker: sessionmaker[Session]): + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for creating database sessions + """ + self._session_maker = session_maker + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method replicates the query pattern from WorkflowService.get_node_last_run() + using SQLAlchemy 2.0 style syntax. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_id == workflow_id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.created_at)) + .limit(1) + ) + + with self._session_maker() as session: + return session.scalar(stmt) + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions() + using SQLAlchemy 2.0 style syntax. 
+ + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.index)) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method replicates the query pattern from WorkflowDraftVariableService + and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id) + + # Add tenant filtering if provided + if tenant_id is not None: + stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id) + + with self._session_maker() as session: + return session.scalar(stmt) + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. 
+ + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. + + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + if not execution_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(stmt) + session.commit() + return result.rowcount diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py new file mode 100644 index 0000000000..ebd1d74b20 --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -0,0 +1,203 @@ +""" +SQLAlchemy API WorkflowRun Repository Implementation + +This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository +protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0 +style queries with proper session management and multi-tenant data isolation. 
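`get_expired_executions_batch` and `delete_executions_by_ids` are intended to be combined by maintenance jobs that archive rows before removing them. A rough sketch of such a loop, assuming a caller-supplied `archive` callable and a 30-day retention window (both are assumptions for illustration):

```python
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta

from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)


def archive_then_delete(
    repo: DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
    tenant_id: str,
    archive: Callable[[Sequence[WorkflowNodeExecutionModel]], None],
    days: int = 30,
    batch_size: int = 1000,
) -> int:
    """Back up expired node executions in batches, then delete them by id."""
    cutoff = datetime.utcnow() - timedelta(days=days)
    total_deleted = 0
    while True:
        batch = repo.get_expired_executions_batch(tenant_id, cutoff, batch_size)
        if not batch:
            break
        archive(batch)  # e.g. serialize the rows to object storage before deletion
        total_deleted += repo.delete_executions_by_ids([row.id for row in batch])
        if len(batch) < batch_size:
            break
    return total_deleted
```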
+ +Key Features: +- SQLAlchemy 2.0 style queries for modern database operations +- Cursor-based pagination for efficient large dataset handling +- Bulk operations with batch processing for performance +- Multi-tenant data isolation and security +- Proper session management with dependency injection + +Implementation Notes: +- Uses sessionmaker for consistent session management +- Implements cursor-based pagination using created_at timestamps +- Provides efficient bulk deletion with batch processing +- Maintains data consistency with proper transaction handling +""" + +import logging +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, cast + +from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): + """ + SQLAlchemy implementation of APIWorkflowRunRepository. + + Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0 + style queries. Supports dependency injection through sessionmaker and + maintains proper multi-tenant data isolation. + + Args: + session_maker: SQLAlchemy sessionmaker instance for database connections + """ + + def __init__(self, session_maker: sessionmaker[Session]) -> None: + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for database connections + """ + self._session_maker = session_maker + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Implements cursor-based pagination using created_at timestamps for + efficient handling of large datasets. Filters by tenant, app, and + trigger source for proper data isolation. + """ + with self._session_maker() as session: + # Build base query with filters + base_stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.triggered_from == triggered_from, + ) + + if last_id: + # Get the last workflow run for cursor-based pagination + last_run_stmt = base_stmt.where(WorkflowRun.id == last_id) + last_workflow_run = session.scalar(last_run_stmt) + + if not last_workflow_run: + raise ValueError("Last workflow run not exists") + + # Get records created before the last run's timestamp + base_stmt = base_stmt.where( + WorkflowRun.created_at < last_workflow_run.created_at, + WorkflowRun.id != last_workflow_run.id, + ) + + # First page - get most recent records + workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all() + + # Check if there are more records for pagination + has_more = len(workflow_runs) > limit + if has_more: + workflow_runs = workflow_runs[:-1] + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID with tenant and app isolation. 
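On the consumer side, the cursor contract is simply: feed the id of the last returned run back in as `last_id` and continue while `has_more` is set. A sketch of a caller draining all pages, assuming the pagination object exposes its constructor arguments as attributes and with a placeholder `triggered_from` value:

```python
from collections.abc import Iterator

from models.workflow import WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository


def iterate_all_runs(
    repo: DifyAPISQLAlchemyWorkflowRunRepository,
    tenant_id: str,
    app_id: str,
    triggered_from: str = "app-run",  # placeholder trigger source
) -> Iterator[WorkflowRun]:
    last_id = None
    while True:
        page = repo.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=triggered_from,
            limit=20,
            last_id=last_id,
        )
        yield from page.data
        if not page.has_more:
            break
        last_id = page.data[-1].id
```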
+ """ + with self._session_maker() as session: + stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.id == run_id, + ) + return cast(Optional[WorkflowRun], session.scalar(stmt)) + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup operations. + """ + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.created_at < before_date, + ) + .limit(batch_size) + ) + return cast(Sequence[WorkflowRun], session.scalars(stmt).all()) + + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs using bulk deletion. + """ + if not run_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(stmt) + session.commit() + + deleted_count = cast(int, result.rowcount) + logger.info(f"Deleted {deleted_count} workflow runs by IDs") + return deleted_count + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app in batches. + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Get a batch of run IDs to delete + stmt = ( + select(WorkflowRun.id) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + ) + .limit(batch_size) + ) + run_ids = session.scalars(stmt).all() + + if not run_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(delete_stmt) + session.commit() + + batch_deleted = result.rowcount + total_deleted += batch_deleted + + logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}") + + # If we deleted fewer records than the batch size, we're done + if batch_deleted < batch_size: + break + + logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}") + return total_deleted diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py new file mode 100644 index 0000000000..c1d6018827 --- /dev/null +++ b/api/schedule/check_upgradable_plugin_task.py @@ -0,0 +1,49 @@ +import time + +import click + +import app +from extensions.ext_database import db +from models.account import TenantPluginAutoUpgradeStrategy +from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task + +AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes + + +@app.celery.task(queue="plugin") +def check_upgradable_plugin_task(): + click.echo(click.style("Start check upgradable plugin.", fg="green")) + start_at = time.perf_counter() + + now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC + click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green")) + + strategies = ( + db.session.query(TenantPluginAutoUpgradeStrategy) + .filter( + TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day, + TenantPluginAutoUpgradeStrategy.upgrade_time_of_day + < now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL, + TenantPluginAutoUpgradeStrategy.strategy_setting + != TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED, + ) + .all() + ) + + for strategy in strategies: + 
process_tenant_plugin_autoupgrade_check_task.delay( + strategy.tenant_id, + strategy.strategy_setting, + strategy.upgrade_time_of_day, + strategy.upgrade_mode, + strategy.exclude_plugins, + strategy.include_plugins, + ) + + end_at = time.perf_counter() + click.echo( + click.style( + "Checked upgradable plugin success latency: {}".format(end_at - start_at), + fg="green", + ) + ) diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 9efe120b7a..024e3d6f50 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -21,7 +21,7 @@ def clean_embedding_cache_task(): try: embedding_ids = ( db.session.query(Embedding.id) - .filter(Embedding.created_at < thirty_days_ago) + .where(Embedding.created_at < thirty_days_ago) .order_by(Embedding.created_at.desc()) .limit(100) .all() diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index d02bc81f33..a6851e36e5 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -36,7 +36,7 @@ def clean_messages(): # Main query with join and filter messages = ( db.session.query(Message) - .filter(Message.created_at < plan_sandbox_clean_message_day) + .where(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) .all() @@ -66,25 +66,25 @@ def clean_messages(): plan = plan_cache.decode() if plan == "sandbox": # clean related message - db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message.id).delete( + db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message.id).delete( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageChain).filter(MessageChain.message_id == message.id).delete( + db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).delete( + db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageFile).filter(MessageFile.message_id == message.id).delete( + db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( synchronize_session=False ) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message.id).delete( + db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( synchronize_session=False ) - db.session.query(Message).filter(Message.id == message.id).delete() + db.session.query(Message).where(Message.id == message.id).delete() db.session.commit() end_at = time.perf_counter() click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index c0cd42a226..72e2e73e65 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -27,7 +27,7 @@ def clean_unused_datasets_task(): # Subquery for counting new documents document_subquery_new = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == 
True, Document.archived == False, @@ -40,7 +40,7 @@ def clean_unused_datasets_task(): # Subquery for counting old documents document_subquery_old = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -55,7 +55,7 @@ def clean_unused_datasets_task(): select(Dataset) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) - .filter( + .where( Dataset.created_at < plan_sandbox_clean_day, func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0, @@ -72,7 +72,7 @@ def clean_unused_datasets_task(): for dataset in datasets: dataset_query = ( db.session.query(DatasetQuery) - .filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) + .where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) .all() ) if not dataset_query or len(dataset_query) == 0: @@ -80,7 +80,7 @@ def clean_unused_datasets_task(): # add auto disable log documents = ( db.session.query(Document) - .filter( + .where( Document.dataset_id == dataset.id, Document.enabled == True, Document.archived == False, @@ -99,9 +99,7 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = {Document.enabled: False} - - db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False}) db.session.commit() click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: @@ -113,7 +111,7 @@ def clean_unused_datasets_task(): # Subquery for counting new documents document_subquery_new = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -126,7 +124,7 @@ def clean_unused_datasets_task(): # Subquery for counting old documents document_subquery_old = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -141,7 +139,7 @@ def clean_unused_datasets_task(): select(Dataset) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) - .filter( + .where( Dataset.created_at < plan_pro_clean_day, func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0, @@ -157,7 +155,7 @@ def clean_unused_datasets_task(): for dataset in datasets: dataset_query = ( db.session.query(DatasetQuery) - .filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) + .where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) .all() ) if not dataset_query or len(dataset_query) == 0: @@ -176,9 +174,7 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = {Document.enabled: False} - - db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) + 
db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False}) db.session.commit() click.echo( click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8a02278de8..91953354e6 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -20,7 +20,7 @@ def create_tidb_serverless_task(): try: # check the number of idle tidb serverless idle_tidb_serverless_number = ( - db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count() + db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count() ) if idle_tidb_serverless_number >= tidb_serverless_number: break diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 5ee813e1de..5911c98b0a 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -3,12 +3,12 @@ import time from collections import defaultdict import click -from flask import render_template # type: ignore import app from configs import dify_config from extensions.ext_database import db from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service from models.account import Account, Tenant, TenantAccountJoin from models.dataset import Dataset, DatasetAutoDisableLog from services.feature_service import FeatureService @@ -30,7 +30,7 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all() + db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() ) # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) @@ -45,7 +45,7 @@ def mail_clean_document_notify_task(): if plan != "sandbox": knowledge_details = [] # check tenant - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first() if not tenant: continue # check current owner @@ -54,7 +54,7 @@ def mail_clean_document_notify_task(): ) if not current_owner_join: continue - account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first() + account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first() if not account: continue @@ -67,19 +67,21 @@ def mail_clean_document_notify_task(): ) for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset: document_count = len(document_ids) knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") if knowledge_details: - html_content = render_template( - "clean_document_job_mail_template-US.html", - userName=account.email, - knowledge_details=knowledge_details, - url=url, - ) - mail.send( - to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + 
"knowledge_details": knowledge_details, + "url": url, + }, ) # update notified to True diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index e3a7021b9d..a05e1358ed 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -3,13 +3,12 @@ from datetime import datetime from urllib.parse import urlparse import click -from flask import render_template from redis import Redis import app from configs import dify_config from extensions.ext_database import db -from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service # Create a dedicated Redis connection (using the same configuration as Celery) celery_broker_url = dify_config.CELERY_BROKER_URL @@ -39,18 +38,20 @@ def queue_monitor_task(): alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS if alter_emails: to_list = alter_emails.split(",") + email_service = get_email_i18n_service() for to in to_list: try: current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - html_content = render_template( - "queue_monitor_alert_email_template_en-US.html", - queue_name=queue_name, - queue_length=queue_length, - threshold=threshold, - alert_time=current_time, - ) - mail.send( - to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content + email_service.send_email( + email_type=EmailType.QUEUE_MONITOR_ALERT, + language_code="en-US", + to=to, + template_context={ + "queue_name": queue_name, + "queue_length": queue_length, + "threshold": threshold, + "alert_time": current_time, + }, ) except Exception as e: logging.exception(click.style("Exception occurred during sending email", fg="red")) diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index ce4ecb6e7c..4d6c1f1877 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -17,7 +17,7 @@ def update_tidb_serverless_status_task(): # check the number of idle tidb serverless tidb_serverless_list = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") .all() ) if len(tidb_serverless_list) == 0: diff --git a/api/services/account_service.py b/api/services/account_service.py index 3fdbda48a6..59bffa873c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -16,7 +16,8 @@ from configs import dify_config from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_database import db -from extensions.ext_redis import redis_client +from extensions.ext_redis import redis_client, redis_fallback +from libs.datetime_utils import naive_utc_now from libs.helper import RateLimiter, TokenManager from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password @@ -28,6 +29,7 @@ from models.account import ( Tenant, TenantAccountJoin, TenantAccountRole, + TenantPluginAutoUpgradeStrategy, TenantStatus, ) from models.model import DifySetup @@ -52,8 +54,14 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces from services.feature_service import FeatureService from tasks.delete_account_task import delete_account_task from tasks.mail_account_deletion_task import send_account_deletion_verification_code +from tasks.mail_change_mail_task 
import send_change_mail_task from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task +from tasks.mail_owner_transfer_task import ( + send_new_owner_transfer_notify_email_task, + send_old_owner_transfer_notify_email_task, + send_owner_transfer_confirm_task, +) from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -75,8 +83,13 @@ class AccountService: email_code_account_deletion_rate_limiter = RateLimiter( prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 ) + change_email_rate_limiter = RateLimiter(prefix="change_email_rate_limit", max_attempts=1, time_window=60 * 1) + owner_transfer_rate_limiter = RateLimiter(prefix="owner_transfer_rate_limit", max_attempts=1, time_window=60 * 1) + LOGIN_MAX_ERROR_LIMITS = 5 FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 + CHANGE_EMAIL_MAX_ERROR_LIMITS = 5 + OWNER_TRANSFER_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -124,8 +137,8 @@ class AccountService: available_ta.current = True db.session.commit() - if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): - account.last_active_at = datetime.now(UTC).replace(tzinfo=None) + if naive_utc_now() - account.last_active_at > timedelta(minutes=10): + account.last_active_at = naive_utc_now() db.session.commit() return cast(Account, account) @@ -169,7 +182,7 @@ class AccountService: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() @@ -307,7 +320,7 @@ class AccountService: # If it exists, update the record account_integrate.open_id = open_id account_integrate.encrypted_token = "" # todo - account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None) + account_integrate.updated_at = naive_utc_now() else: # If it does not exist, create a new record account_integrate = AccountIntegrate( @@ -342,7 +355,7 @@ class AccountService: @staticmethod def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" - account.last_login_at = datetime.now(UTC).replace(tzinfo=None) + account.last_login_at = naive_utc_now() account.last_login_ip = ip_address db.session.add(account) db.session.commit() @@ -419,6 +432,101 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def send_change_email_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + old_email: Optional[str] = None, + language: Optional[str] = "en-US", + phase: Optional[str] = None, + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.change_email_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import EmailChangeRateLimitExceededError + + raise EmailChangeRateLimitExceededError() + + code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + + send_change_mail_task.delay( + language=language, + to=account_email, + code=code, + phase=phase, + ) + cls.change_email_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def send_owner_transfer_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", 
+ workspace_name: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.owner_transfer_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import OwnerTransferRateLimitExceededError + + raise OwnerTransferRateLimitExceededError() + + code, token = cls.generate_owner_transfer_token(account_email, account) + + send_owner_transfer_confirm_task.delay( + language=language, + to=account_email, + code=code, + workspace=workspace_name, + ) + cls.owner_transfer_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def send_old_owner_transfer_notify_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + new_owner_email: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + send_old_owner_transfer_notify_email_task.delay( + language=language, + to=account_email, + workspace=workspace_name, + new_owner_email=new_owner_email, + ) + + @classmethod + def send_new_owner_transfer_notify_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + send_new_owner_transfer_notify_email_task.delay( + language=language, + to=account_email, + workspace=workspace_name, + ) + @classmethod def generate_reset_password_token( cls, @@ -435,14 +543,64 @@ class AccountService: ) return code, token + @classmethod + def generate_change_email_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + old_email: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + additional_data["old_email"] = old_email + token = TokenManager.generate_token( + account=account, email=email, token_type="change_email", additional_data=additional_data + ) + return code, token + + @classmethod + def generate_owner_transfer_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="owner_transfer", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") + @classmethod + def revoke_change_email_token(cls, token: str): + TokenManager.revoke_token(token, "change_email") + + @classmethod + def revoke_owner_transfer_token(cls, token: str): + TokenManager.revoke_token(token, "owner_transfer") + @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") + @classmethod + def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "change_email") + + @classmethod + def get_owner_transfer_data(cls, token: str) -> 
Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "owner_transfer") + @classmethod def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" @@ -485,7 +643,7 @@ class AccountService: ) ) - account = db.session.query(Account).filter(Account.email == email).first() + account = db.session.query(Account).where(Account.email == email).first() if not account: return None @@ -495,6 +653,7 @@ class AccountService: return account @staticmethod + @redis_fallback(default_return=None) def add_login_error_rate_limit(email: str) -> None: key = f"login_error_rate_limit:{email}" count = redis_client.get(key) @@ -504,6 +663,7 @@ class AccountService: redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count) @staticmethod + @redis_fallback(default_return=False) def is_login_error_rate_limit(email: str) -> bool: key = f"login_error_rate_limit:{email}" count = redis_client.get(key) @@ -516,11 +676,13 @@ class AccountService: return False @staticmethod + @redis_fallback(default_return=None) def reset_login_error_rate_limit(email: str): key = f"login_error_rate_limit:{email}" redis_client.delete(key) @staticmethod + @redis_fallback(default_return=None) def add_forgot_password_error_rate_limit(email: str) -> None: key = f"forgot_password_error_rate_limit:{email}" count = redis_client.get(key) @@ -530,6 +692,7 @@ class AccountService: redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count) @staticmethod + @redis_fallback(default_return=False) def is_forgot_password_error_rate_limit(email: str) -> bool: key = f"forgot_password_error_rate_limit:{email}" count = redis_client.get(key) @@ -542,11 +705,69 @@ class AccountService: return False @staticmethod + @redis_fallback(default_return=None) def reset_forgot_password_error_rate_limit(email: str): key = f"forgot_password_error_rate_limit:{email}" redis_client.delete(key) @staticmethod + @redis_fallback(default_return=None) + def add_change_email_error_rate_limit(email: str) -> None: + key = f"change_email_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.CHANGE_EMAIL_LOCKOUT_DURATION, count) + + @staticmethod + @redis_fallback(default_return=False) + def is_change_email_error_rate_limit(email: str) -> bool: + key = f"change_email_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.CHANGE_EMAIL_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_change_email_error_rate_limit(email: str): + key = f"change_email_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + @redis_fallback(default_return=None) + def add_owner_transfer_error_rate_limit(email: str) -> None: + key = f"owner_transfer_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.OWNER_TRANSFER_LOCKOUT_DURATION, count) + + @staticmethod + @redis_fallback(default_return=False) + def is_owner_transfer_error_rate_limit(email: str) -> bool: + key = f"owner_transfer_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.OWNER_TRANSFER_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + 
@redis_fallback(default_return=None) + def reset_owner_transfer_error_rate_limit(email: str): + key = f"owner_transfer_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + @redis_fallback(default_return=False) def is_email_send_ip_limit(ip_address: str): minute_key = f"email_send_ip_limit_minute:{ip_address}" freeze_key = f"email_send_ip_limit_freeze:{ip_address}" @@ -586,6 +807,10 @@ class AccountService: return False + @staticmethod + def check_email_unique(email: str) -> bool: + return db.session.query(Account).filter_by(email=email).first() is None + class TenantService: @staticmethod @@ -604,6 +829,17 @@ class TenantService: db.session.add(tenant) db.session.commit() + plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy( + tenant_id=tenant.id, + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=0, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=[], + include_plugins=[], + ) + db.session.add(plugin_upgrade_strategy) + db.session.commit() + tenant.encrypt_public_key = generate_key_pair(tenant.id) db.session.commit() return tenant @@ -664,7 +900,7 @@ class TenantService: return ( db.session.query(Tenant) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) - .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) .all() ) @@ -693,7 +929,7 @@ class TenantService: tenant_account_join = ( db.session.query(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) - .filter( + .where( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id == tenant_id, Tenant.status == TenantStatus.NORMAL, @@ -704,7 +940,7 @@ class TenantService: if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - db.session.query(TenantAccountJoin).filter( + db.session.query(TenantAccountJoin).where( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id ).update({"current": False}) tenant_account_join.current = True @@ -719,7 +955,7 @@ class TenantService: db.session.query(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) - .filter(TenantAccountJoin.tenant_id == tenant.id) + .where(TenantAccountJoin.tenant_id == tenant.id) ) # Initialize an empty list to store the updated accounts @@ -738,8 +974,8 @@ class TenantService: db.session.query(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) - .filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == "dataset_operator") + .where(TenantAccountJoin.tenant_id == tenant.id) + .where(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -759,9 +995,7 @@ class TenantService: return ( db.session.query(TenantAccountJoin) - .filter( - TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) - ) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])) .first() is not None ) @@ -771,10 +1005,10 @@ class TenantService: """Get the role of the current account for a given tenant""" join = ( db.session.query(TenantAccountJoin) - .filter(TenantAccountJoin.tenant_id == tenant.id, 
TenantAccountJoin.account_id == account.id) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .first() ) - return join.role if join else None + return TenantAccountRole(join.role) if join else None @staticmethod def get_tenant_count() -> int: @@ -843,21 +1077,21 @@ class TenantService: target_member_join.role = new_role db.session.commit() - @staticmethod - def dissolve_tenant(tenant: Tenant, operator: Account) -> None: - """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, "remove"): - raise NoPermissionError("No permission to dissolve tenant.") - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() - db.session.delete(tenant) - db.session.commit() - @staticmethod def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) return cast(dict, tenant.custom_config_dict) + @staticmethod + def is_owner(account: Account, tenant: Tenant) -> bool: + return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER + + @staticmethod + def is_member(account: Account, tenant: Tenant) -> bool: + """Check if the account is a member of the tenant""" + return TenantService.get_user_role(account, tenant) is not None + class RegisterService: @classmethod @@ -885,7 +1119,7 @@ class RegisterService: ) account.last_login_ip = ip_address - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) @@ -926,7 +1160,7 @@ class RegisterService: is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) @@ -1038,7 +1272,7 @@ class RegisterService: tenant = ( db.session.query(Tenant) - .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") .first() ) @@ -1048,7 +1282,7 @@ class RegisterService: tenant_account = ( db.session.query(Account, TenantAccountJoin.role) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) - .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) .first() ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 503b31ede2..7c6df2428f 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -25,7 +25,7 @@ class AgentService: conversation: Conversation | None = ( db.session.query(Conversation) - .filter( + .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, ) @@ -37,7 +37,7 @@ class AgentService: message: Optional[Message] = ( db.session.query(Message) - .filter( + .where( Message.id == message_id, Message.conversation_id == conversation_id, ) @@ -52,12 +52,10 @@ class AgentService: if conversation.from_end_user_id: # only select name field executor = ( - db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first() ) else: - executor = ( - db.session.query(Account, 
Account.name).filter(Account.id == conversation.from_account_id).first() - ) + executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first() if executor: executor = executor.name diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 8c950abc24..7cb0b46517 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -26,7 +26,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -35,7 +35,7 @@ class AppAnnotationService: if args.get("message_id"): message_id = str(args["message_id"]) # get message info - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() + message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not Exists.") @@ -61,9 +61,7 @@ class AppAnnotationService: db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -117,7 +115,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -126,8 +124,8 @@ class AppAnnotationService: if keyword: stmt = ( select(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( + .where(MessageAnnotation.app_id == app_id) + .where( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), MessageAnnotation.content.ilike("%{}%".format(keyword)), @@ -138,7 +136,7 @@ class AppAnnotationService: else: stmt = ( select(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + .where(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -149,7 +147,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -157,7 +155,7 @@ class AppAnnotationService: raise NotFound("App not found") annotations = ( db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + .where(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc()) .all() ) @@ -168,7 +166,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -181,9 +179,7 @@ class AppAnnotationService: db.session.add(annotation) 
db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -199,14 +195,14 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: raise NotFound("Annotation not found") @@ -217,7 +213,7 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , add annotation to index app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) if app_annotation_setting: @@ -236,14 +232,14 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: raise NotFound("Annotation not found") @@ -252,7 +248,7 @@ class AppAnnotationService: annotation_hit_histories = ( db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .where(AppAnnotationHitHistory.annotation_id == annotation_id) .all() ) if annotation_hit_histories: @@ -262,7 +258,7 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , delete annotation index app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) if app_annotation_setting: @@ -275,7 +271,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -314,21 +310,21 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: raise 
NotFound("Annotation not found") stmt = ( select(AppAnnotationHitHistory) - .filter( + .where( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) @@ -341,7 +337,7 @@ class AppAnnotationService: @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: return None @@ -361,7 +357,7 @@ class AppAnnotationService: score: float, ): # add hit count to annotation - db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update( {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) @@ -384,16 +380,14 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -412,7 +406,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -421,7 +415,7 @@ class AppAnnotationService: annotation_setting = ( db.session.query(AppAnnotationSetting) - .filter( + .where( AppAnnotationSetting.app_id == app_id, AppAnnotationSetting.id == annotation_setting_id, ) diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 601d67d2fb..457c91e5c0 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -73,7 +73,7 @@ class APIBasedExtensionService: db.session.query(APIBasedExtension) .filter_by(tenant_id=extension_data.tenant_id) .filter_by(name=extension_data.name) - .filter(APIBasedExtension.id != extension_data.id) + .where(APIBasedExtension.id != extension_data.id) .first() ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa345..fe0efd061d 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -41,7 +41,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.3.0" +CURRENT_DSL_VERSION = "0.3.1" class ImportMode(StrEnum): @@ -575,13 +575,26 @@ class AppDslService: raise ValueError("Missing draft workflow configuration, please check.") workflow_dict = workflow.to_dict(include_secret=include_secret) + # TODO: refactor: we need a better way to filter workspace related data from nodes for node in workflow_dict.get("graph", {}).get("nodes", []): - if node.get("data", {}).get("type", 
"") == NodeType.KNOWLEDGE_RETRIEVAL.value: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node_data.get("dataset_ids", []) + node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] + # filter credential id from tool node + if not include_secret and data_type == NodeType.TOOL.value: + node_data.pop("credential_id", None) + # filter credential id from agent node + if not include_secret and data_type == NodeType.AGENT.value: + for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): + tool.pop("credential_id", None) + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ @@ -602,7 +615,15 @@ class AppDslService: if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data["model_config"] = app_model_config.to_dict() + model_config = app_model_config.to_dict() + + # TODO: refactor: we need a better way to filter workspace related data from model config + # filter credential id from model config + for tool in model_config.get("agent_mode", {}).get("tools", []): + tool.pop("credential_id", None) + + export_data["model_config"] = model_config + dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) export_data["dependencies"] = [ jsonable_encoder(d.model_dump()) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 245c123a04..6f7e705b52 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -129,11 +129,25 @@ class AppGenerateService: rate_limit.exit(request_id) @staticmethod - def _get_max_active_requests(app_model: App) -> int: - max_active_requests = app_model.max_active_requests - if max_active_requests is None: - max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) - return max_active_requests + def _get_max_active_requests(app: App) -> int: + """ + Get the maximum number of active requests allowed for an app. + + Returns the smaller value between app's custom limit and global config limit. + A value of 0 means infinite (no limit). 
+ + Args: + app: The App model instance + + Returns: + The maximum number of active requests allowed + """ + app_limit = app.max_active_requests or 0 + config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS + + # Filter out infinite (0) values and return the minimum, or 0 if both are infinite + limits = [limit for limit in [app_limit, config_limit] if limit > 0] + return min(limits) if limits else 0 @classmethod def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): diff --git a/api/services/app_service.py b/api/services/app_service.py index d08462d001..0b6b85bcb2 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,7 +1,6 @@ import json import logging -from datetime import UTC, datetime -from typing import Optional, cast +from typing import Optional, TypedDict, cast from flask_login import current_user from flask_sqlalchemy.pagination import Pagination @@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider @@ -47,8 +47,6 @@ class AppService: filters.append(App.mode == AppMode.ADVANCED_CHAT.value) elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args["mode"] == "channel": - filters.append(App.mode == AppMode.CHANNEL.value) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) @@ -222,21 +220,31 @@ class AppService: return app - def update_app(self, app: App, args: dict) -> App: + class ArgsDict(TypedDict): + name: str + description: str + icon_type: str + icon: str + icon_background: str + use_icon_as_answer_icon: bool + max_active_requests: int + + def update_app(self, app: App, args: ArgsDict) -> App: """ Update app :param app: App instance :param args: request args :return: App instance """ - app.name = args.get("name") - app.description = args.get("description", "") - app.icon_type = args.get("icon_type", "emoji") - app.icon = args.get("icon") - app.icon_background = args.get("icon_background") + app.name = args["name"] + app.description = args["description"] + app.icon_type = args["icon_type"] + app.icon = args["icon"] + app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) + app.max_active_requests = args.get("max_active_requests") app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -250,7 +258,7 @@ class AppService: """ app.name = name app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -266,7 +274,7 @@ class AppService: app.icon = icon app.icon_background = icon_background app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -283,7 +291,7 @@ class AppService: app.enable_site = enable_site app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -300,7 +308,7 @@ class AppService: app.enable_api = 
enable_api app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -374,7 +382,7 @@ class AppService: elif provider_type == "api": try: provider: Optional[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first() ) if provider is None: raise ValueError(f"provider not found for tool {tool_name}") @@ -391,7 +399,7 @@ class AppService: :param app_id: app id :return: app code """ - site = db.session.query(Site).filter(Site.app_id == app_id).first() + site = db.session.query(Site).where(Site.app_id == app_id).first() if not site: raise ValueError(f"App with id {app_id} not found") return str(site.code) @@ -403,7 +411,7 @@ class AppService: :param app_code: app code :return: app id """ - site = db.session.query(Site).filter(Site.code == app_code).first() + site = db.session.query(Site).where(Site.code == app_code).first() if not site: raise ValueError(f"App with code {app_code} not found") return str(site.app_id) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index e8923eb51b..0084eebb32 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -135,7 +135,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.query(Message).filter(Message.id == message_id).first() + message = db.session.query(Message).where(Message.id == message_id).first() if message is None: return None if message.answer == "" and message.status == MessageStatus.NORMAL: diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index e5f4a3ef6e..996e9187f3 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -11,7 +11,7 @@ class ApiKeyAuthService: def get_provider_auth_list(tenant_id: str) -> list: data_source_api_key_bindings = ( db.session.query(DataSourceApiKeyAuthBinding) - .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) .all() ) return data_source_api_key_bindings @@ -36,7 +36,7 @@ class ApiKeyAuthService: def get_auth_credentials(tenant_id: str, category: str, provider: str): data_source_api_key_bindings = ( db.session.query(DataSourceApiKeyAuthBinding) - .filter( + .where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, DataSourceApiKeyAuthBinding.provider == provider, @@ -53,7 +53,7 @@ class ApiKeyAuthService: def delete_provider_auth(tenant_id: str, binding_id: str): data_source_api_key_binding = ( db.session.query(DataSourceApiKeyAuthBinding) - .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) .first() ) if data_source_api_key_binding: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index d44483ad89..5a12aa2e54 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -75,14 +75,14 @@ class BillingService: join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) - .filter(TenantAccountJoin.tenant_id == tenant_id, 
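With the ArgsDict TypedDict added to AppService above, update_app indexes required keys directly instead of papering over missing ones with .get() defaults, so a type checker can flag incomplete callers. A reduced sketch of the idea, with names shortened for illustration:

from typing import TypedDict


class AppUpdateArgs(TypedDict):
    name: str
    description: str
    icon: str


def update_app(app, args: AppUpdateArgs) -> None:
    # Required keys can be indexed directly; mypy/pyright reject callers
    # that build the dict without one of them.
    app.name = args["name"]
    app.description = args["description"]
    app.icon = args["icon"]


# update_app(app, {"name": "demo", "description": "..."})  # type error: "icon" is missing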
TenantAccountJoin.account_id == current_user.id) + .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) if not join: raise ValueError("Tenant account join not found") - if not TenantAccountRole.is_privileged_role(join.role): + if not TenantAccountRole.is_privileged_role(TenantAccountRole(join.role)): raise ValueError("Only team owner or team admin can perform this action") @classmethod diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1fd560d581..ad9b750d40 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder @@ -14,7 +14,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message -from models.workflow import WorkflowNodeExecutionModel, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -24,13 +24,13 @@ class ClearFreePlanTenantExpiredLogs: @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): - apps = db.session.query(App).filter(App.tenant_id == tenant_id).all() + apps = db.session.query(App).where(App.tenant_id == tenant_id).all() app_ids = [app.id for app in apps] while True: with Session(db.engine).no_autoflush as session: messages = ( session.query(Message) - .filter( + .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) @@ -54,7 +54,7 @@ class ClearFreePlanTenantExpiredLogs: message_ids = [message.id for message in messages] # delete messages - session.query(Message).filter( + session.query(Message).where( Message.id.in_(message_ids), ).delete(synchronize_session=False) @@ -70,7 +70,7 @@ class ClearFreePlanTenantExpiredLogs: with Session(db.engine).no_autoflush as session: conversations = ( session.query(Conversation) - .filter( + .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) @@ -93,7 +93,7 @@ class ClearFreePlanTenantExpiredLogs: ) conversation_ids = [conversation.id for conversation in conversations] - session.query(Conversation).filter( + session.query(Conversation).where( Conversation.id.in_(conversation_ids), ).delete(synchronize_session=False) session.commit() @@ -105,84 +105,99 @@ class ClearFreePlanTenantExpiredLogs: ) ) - while True: - with Session(db.engine).no_autoflush as session: - workflow_node_executions = ( - session.query(WorkflowNodeExecutionModel) - .filter( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.created_at - < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() - ) - - if len(workflow_node_executions) == 0: - break - - # save workflow node executions - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - 
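The bulk of the query churn in these service files is the switch from Query.filter() to Query.where(). Query.where() has been a synonym for Query.filter() since SQLAlchemy 1.4, so behaviour should be unchanged; the rename just lines the legacy calls up with the 2.0-style select().where() used elsewhere in the patch. A self-contained comparison on a throwaway model and an in-memory SQLite database:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Message(Base):
    __tablename__ = "messages"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy Query API: .where() is an alias of .filter().
    legacy = session.query(Message).where(Message.app_id == "app-1").all()

    # 2.0-style equivalent built on select(), which the new code leans towards.
    stmt = select(Message).where(Message.app_id == "app-1")
    modern = session.execute(stmt).scalars().all()

    assert legacy == modern == []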
jsonable_encoder(workflow_node_executions), - ).encode("utf-8"), - ) - - workflow_node_execution_ids = [ - workflow_node_execution.id for workflow_node_execution in workflow_node_executions - ] - - # delete workflow node executions - session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), - ).delete(synchronize_session=False) - session.commit() - - click.echo( - click.style( - f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" - f" workflow node executions for tenant {tenant_id}" - ) - ) + # Process expired workflow node executions with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 while True: - with Session(db.engine).no_autoflush as session: - workflow_runs = ( - session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == tenant_id, - WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() + # Get a batch of expired executions for backup + workflow_node_executions = node_execution_repo.get_expired_executions_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_node_executions) == 0: + break + + # Save workflow node executions to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(workflow_node_executions), + ).encode("utf-8"), + ) + + # Extract IDs for deletion + workflow_node_execution_ids = [ + workflow_node_execution.id for workflow_node_execution in workflow_node_executions + ] + + # Delete the backed up executions + deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" + f" workflow node executions for tenant {tenant_id}" ) + ) - if len(workflow_runs) == 0: - break + # If we got fewer than the batch size, we're done + if len(workflow_node_executions) < batch: + break - # save workflow runs + # Process expired workflow runs with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder( - [workflow_run.to_dict() for workflow_run in workflow_runs], - ), - ).encode("utf-8"), + while True: + # Get a batch of expired workflow runs for backup + workflow_runs = workflow_run_repo.get_expired_runs_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_runs) == 0: + break + + # Save workflow runs to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_run.to_dict() for workflow_run in workflow_runs], + ), + ).encode("utf-8"), + ) + + # Extract IDs for 
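The reworked cleanup above moves the expired workflow data behind a repository and follows a backup-then-delete loop: fetch one batch, serialise it to storage, delete exactly the rows that were archived, and stop once a batch comes back shorter than the batch size. A schematic version of that loop, with the repository and storage objects treated as injected dependencies and only the row ids serialised for brevity:

import json
from datetime import datetime


def archive_then_delete(repo, storage, tenant_id: str, before_date: datetime, batch_size: int) -> int:
    """Back up expired rows in batches, delete them, and return how many were removed."""
    total_deleted = 0
    while True:
        rows = repo.get_expired_executions_batch(
            tenant_id=tenant_id, before_date=before_date, batch_size=batch_size
        )
        if not rows:
            break

        # Persist the batch before touching the database, so a failed delete
        # never drops data that has not been archived yet.
        key = f"free_plan_tenant_expired_logs/{tenant_id}/{datetime.now().isoformat()}.json"
        storage.save(key, json.dumps([row.id for row in rows]).encode("utf-8"))

        total_deleted += repo.delete_executions_by_ids([row.id for row in rows])

        if len(rows) < batch_size:
            break  # the last page was short, nothing left to fetch
    return total_deleted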
deletion + workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] + + # Delete the backed up workflow runs + deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}" + f" workflow runs for tenant {tenant_id}" ) + ) - workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] - - # delete workflow runs - session.query(WorkflowRun).filter( - WorkflowRun.id.in_(workflow_run_ids), - ).delete(synchronize_session=False) - session.commit() + # If we got fewer than the batch size, we're done + if len(workflow_runs) < batch: + break @classmethod def process(cls, days: int, batch: int, tenant_ids: list[str]): @@ -261,7 +276,7 @@ class ClearFreePlanTenantExpiredLogs: for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, current_time + test_interval)) + .where(Tenant.created_at.between(current_time, current_time + test_interval)) .count() ) if tenant_count <= 100: @@ -286,7 +301,7 @@ class ClearFreePlanTenantExpiredLogs: rs = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, batch_end)) + .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index afdaa49465..525c87fe4a 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,5 +1,4 @@ from collections.abc import Callable, Sequence -from datetime import UTC, datetime from typing import Optional, Union from sqlalchemy import asc, desc, func, or_, select @@ -8,6 +7,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ConversationVariable from models.account import Account @@ -113,7 +113,7 @@ class ConversationService: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() return conversation @@ -123,7 +123,7 @@ class ConversationService: # get conversation first message message = ( db.session.query(Message) - .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id) .order_by(Message.created_at.asc()) .first() ) @@ -148,7 +148,7 @@ class ConversationService: def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): conversation = ( db.session.query(Conversation) - .filter( + .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -169,7 +169,7 @@ class ConversationService: conversation = cls.get_conversation(app_model, conversation_id, user) conversation.is_deleted = True - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() @classmethod diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e42b5ace75..4872702a76 
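Across the patch, the repeated datetime.now(UTC).replace(tzinfo=None) expression gives way to a shared libs.datetime_utils.naive_utc_now() helper. Its body is not shown in this diff, but judging by the call sites it replaces it presumably amounts to:

from datetime import UTC, datetime


def naive_utc_now() -> datetime:
    """Current UTC time as a naive datetime, matching the naive timestamp columns."""
    return datetime.now(UTC).replace(tzinfo=None)


assert naive_utc_now().tzinfo is None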
100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -26,6 +26,7 @@ from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper +from libs.datetime_utils import naive_utc_now from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -79,7 +80,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): - query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids @@ -91,14 +92,14 @@ class DatasetService: if user.current_role == TenantAccountRole.DATASET_OPERATOR: # only show datasets that the user has permission to access if permitted_dataset_ids: - query = query.filter(Dataset.id.in_(permitted_dataset_ids)) + query = query.where(Dataset.id.in_(permitted_dataset_ids)) else: return [], 0 else: if user.current_role != TenantAccountRole.OWNER or not include_all: # show all datasets that the user has permission to access if permitted_dataset_ids: - query = query.filter( + query = query.where( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, db.and_( @@ -111,7 +112,7 @@ class DatasetService: ) ) else: - query = query.filter( + query = query.where( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, db.and_( @@ -121,15 +122,15 @@ class DatasetService: ) else: # if no user, only show datasets that are shared with all team members - query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) + query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f"%{search}%")) + query = query.where(Dataset.name.ilike(f"%{search}%")) if tag_ids: target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: - query = query.filter(Dataset.id.in_(target_ids)) + query = query.where(Dataset.id.in_(target_ids)) else: return [], 0 @@ -142,7 +143,7 @@ class DatasetService: # get the latest process rule dataset_process_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.dataset_id == dataset_id) + .where(DatasetProcessRule.dataset_id == dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) .one_or_none() @@ -157,7 +158,7 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): - stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) + stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) @@ -214,9 +215,9 @@ class DatasetService: dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id - dataset.embedding_model_provider = embedding_model.provider if embedding_model else None - dataset.embedding_model = embedding_model.model if embedding_model else None - dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None + dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore + dataset.embedding_model = embedding_model.model if embedding_model 
else None # type: ignore + dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider db.session.add(dataset) @@ -428,7 +429,7 @@ class DatasetService: # Add metadata fields filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + filtered_data["updated_at"] = naive_utc_now() # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] @@ -696,7 +697,7 @@ class DatasetService: def get_related_apps(dataset_id: str): return ( db.session.query(AppDatasetJoin) - .filter(AppDatasetJoin.dataset_id == dataset_id) + .where(AppDatasetJoin.dataset_id == dataset_id) .order_by(db.desc(AppDatasetJoin.created_at)) .all() ) @@ -713,7 +714,7 @@ class DatasetService: start_date = datetime.datetime.now() - datetime.timedelta(days=30) dataset_auto_disable_logs = ( db.session.query(DatasetAutoDisableLog) - .filter( + .where( DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.created_at >= start_date, ) @@ -842,7 +843,7 @@ class DocumentService: def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: if document_id: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) return document else: @@ -850,7 +851,7 @@ class DocumentService: @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() return document @@ -858,7 +859,7 @@ class DocumentService: def get_document_by_ids(document_ids: list[str]) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.id.in_(document_ids), Document.enabled == True, Document.indexing_status == "completed", @@ -872,7 +873,7 @@ class DocumentService: def get_document_by_dataset_id(dataset_id: str) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.dataset_id == dataset_id, Document.enabled == True, ) @@ -885,7 +886,7 @@ class DocumentService: def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.dataset_id == dataset_id, Document.enabled == True, Document.indexing_status == "completed", @@ -900,7 +901,7 @@ class DocumentService: def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: documents = ( db.session.query(Document) - .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) .all() ) return documents @@ -909,7 +910,7 @@ class DocumentService: def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.batch == batch, Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id, @@ -921,7 +922,7 @@ class DocumentService: @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() + file_detail = 
db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -949,7 +950,7 @@ class DocumentService: @staticmethod def delete_documents(dataset: Dataset, document_ids: list[str]): - documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() file_ids = [ document.data_source_info_dict["upload_file_id"] for document in documents @@ -994,7 +995,7 @@ class DocumentService: # update document to be paused document.is_paused = True document.paused_by = current_user.id - document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.paused_at = naive_utc_now() db.session.add(document) db.session.commit() @@ -1188,7 +1189,7 @@ class DocumentService: for file_id in upload_file_list: file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) @@ -1269,7 +1270,7 @@ class DocumentService: workspace_id = notion_info.workspace_id data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", @@ -1412,7 +1413,7 @@ class DocumentService: def get_tenant_documents_count(): documents_count = ( db.session.query(Document) - .filter( + .where( Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, @@ -1468,7 +1469,7 @@ class DocumentService: for file_id in upload_file_list: file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) @@ -1488,7 +1489,7 @@ class DocumentService: workspace_id = notion_info.workspace_id data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", @@ -1539,8 +1540,10 @@ class DocumentService: db.session.add(document) db.session.commit() # update document segment - update_params = {DocumentSegment.status: "re_segment"} - db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params) + + db.session.query(DocumentSegment).filter_by(document_id=document.id).update( + {DocumentSegment.status: "re_segment"} + ) # type: ignore db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -2002,7 +2005,7 @@ class SegmentService: with redis_client.lock(lock_name, timeout=600): max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == document.id) + .where(DocumentSegment.document_id == document.id) .scalar() ) segment_document = DocumentSegment( @@ -2040,7 +2043,7 @@ class SegmentService: segment_document.status = "error" segment_document.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() return segment @classmethod @@ -2059,7 +2062,7 @@ class SegmentService: ) max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == document.id) + 
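The re-segment hunk above keeps the bulk-UPDATE form, issuing a single statement through the Query API instead of loading every segment and mutating it in Python; bulk deletes elsewhere in the patch pass synchronize_session=False for a similar reason, skipping reconciliation with objects already loaded in the session. Schematically (the model class is passed in here only to keep the sketch self-contained):

def bulk_mark_for_resegment(session, segment_model, document_id: str) -> int:
    """Flip the status of every segment of a document with one UPDATE statement."""
    updated = (
        session.query(segment_model)
        .filter_by(document_id=document_id)
        .update({segment_model.status: "re_segment"}, synchronize_session=False)
    )
    session.commit()
    return updated  # row count reported by the database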
.where(DocumentSegment.document_id == document.id) .scalar() ) pre_segment_data_list = [] @@ -2198,7 +2201,7 @@ class SegmentService: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -2225,7 +2228,7 @@ class SegmentService: # calc embedding use tokens if document.doc_form == "qa_model": segment.answer = args.answer - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] segment.content = content @@ -2273,7 +2276,7 @@ class SegmentService: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -2292,7 +2295,7 @@ class SegmentService: segment.status = "error" segment.error = str(e) db.session.commit() - new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() return new_segment @classmethod @@ -2318,7 +2321,7 @@ class SegmentService: index_node_ids = ( db.session.query(DocumentSegment) .with_entities(DocumentSegment.index_node_id) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2329,7 +2332,7 @@ class SegmentService: index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) - db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() db.session.commit() @classmethod @@ -2337,7 +2340,7 @@ class SegmentService: if action == "enable": segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2364,7 +2367,7 @@ class SegmentService: elif action == "disable": segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2401,7 +2404,7 @@ class SegmentService: index_node_hash = helper.generate_text_hash(content) child_chunk_count = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, @@ -2411,7 +2414,7 @@ class SegmentService: ) max_position = ( db.session.query(func.max(ChildChunk.position)) - .filter( + .where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, @@ -2454,7 +2457,7 @@ class SegmentService: ) -> list[ChildChunk]: child_chunks = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, ChildChunk.segment_id == segment.id, @@ -2575,7 +2578,7 @@ class SegmentService: """Get a child 
chunk by its ID.""" result = ( db.session.query(ChildChunk) - .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) + .where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) .first() ) return result if isinstance(result, ChildChunk) else None @@ -2591,15 +2594,15 @@ class SegmentService: limit: int = 20, ): """Get segments for a document with optional filtering.""" - query = select(DocumentSegment).filter( + query = select(DocumentSegment).where( DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) if status_list: - query = query.filter(DocumentSegment.status.in_(status_list)) + query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.order_by(DocumentSegment.position.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -2612,7 +2615,7 @@ class SegmentService: ) -> tuple[DocumentSegment, Document]: """Update a segment by its ID with validation and checks.""" # check dataset - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -2644,7 +2647,7 @@ class SegmentService: # check segment segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) + .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) .first() ) if not segment: @@ -2661,7 +2664,7 @@ class SegmentService: """Get a segment by its ID.""" result = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) + .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .first() ) return result if isinstance(result, DocumentSegment) else None @@ -2674,7 +2677,7 @@ class DatasetCollectionBindingService: ) -> DatasetCollectionBinding: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter( + .where( DatasetCollectionBinding.provider_name == provider_name, DatasetCollectionBinding.model_name == model_name, DatasetCollectionBinding.type == collection_type, @@ -2700,7 +2703,7 @@ class DatasetCollectionBindingService: ) -> DatasetCollectionBinding: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter( + .where( DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type ) .order_by(DatasetCollectionBinding.created_at) @@ -2719,7 +2722,7 @@ class DatasetPermissionService: db.session.query( DatasetPermission.account_id, ) - .filter(DatasetPermission.dataset_id == dataset_id) + .where(DatasetPermission.dataset_id == dataset_id) .all() ) @@ -2732,7 +2735,7 @@ class DatasetPermissionService: @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): try: - db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete() + db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() permissions = [] for user in user_list: permission = DatasetPermission( @@ -2768,7 +2771,7 @@ class DatasetPermissionService: @classmethod def clear_partial_member_list(cls, dataset_id): 
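get_segments above builds a 2.0-style select() incrementally, adding .where() clauses only when a status list or keyword is actually supplied, then hands the statement to Flask-SQLAlchemy's db.paginate(). A cut-down sketch of that shape, reusing the model and field names from the hunk; the returned Pagination object is what exposes .items, .total and .pages to the callers:

from sqlalchemy import select


def list_segments(db, document_id: str, keyword: str | None = None, page: int = 1, limit: int = 20):
    query = select(DocumentSegment).where(DocumentSegment.document_id == document_id)
    if keyword:
        # Narrow the statement only when a search term was provided.
        query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
    query = query.order_by(DocumentSegment.position.asc())
    return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)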
try: - db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete() + db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() db.session.commit() except Exception as e: db.session.rollback() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 8c06ee9386..54d45f45ea 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -29,7 +29,7 @@ class EnterpriseService: raise ValueError("No data found.") try: # parse the UTC timestamp from the response - return datetime.fromisoformat(data.replace("Z", "+00:00")) + return datetime.fromisoformat(data) except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e @@ -40,7 +40,7 @@ class EnterpriseService: raise ValueError("No data found.") try: # parse the UTC timestamp from the response - return datetime.fromisoformat(data.replace("Z", "+00:00")) + return datetime.fromisoformat(data) except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 603064ca07..344c67885e 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -4,13 +4,6 @@ from typing import Literal, Optional from pydantic import BaseModel -class SegmentUpdateEntity(BaseModel): - content: str - answer: Optional[str] = None - keywords: Optional[list[str]] = None - enabled: Optional[bool] = None - - class ParentMode(StrEnum): FULL_DOC = "full-doc" PARAGRAPH = "paragraph" @@ -95,7 +88,7 @@ class WeightKeywordSetting(BaseModel): class WeightModel(BaseModel): - weight_type: Optional[str] = None + weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None vector_setting: Optional[WeightVectorSetting] = None keyword_setting: Optional[WeightKeywordSetting] = None @@ -153,10 +146,6 @@ class MetadataUpdateArgs(BaseModel): value: Optional[str | int | float] = None -class MetadataValueUpdateArgs(BaseModel): - fields: list[MetadataUpdateArgs] - - class MetadataDetail(BaseModel): id: str name: str diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index eb50d79494..b7af03e91f 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,5 @@ import json from copy import deepcopy -from datetime import UTC, datetime from typing import Any, Optional, Union, cast from urllib.parse import urlparse @@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, ExternalKnowledgeApis, @@ -30,11 +30,11 @@ class ExternalDatasetService: ) -> tuple[list[ExternalKnowledgeApis], int | None]: query = ( select(ExternalKnowledgeApis) - .filter(ExternalKnowledgeApis.tenant_id == tenant_id) + .where(ExternalKnowledgeApis.tenant_id == tenant_id) .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) external_knowledge_apis = db.paginate( select=query, page=page, 
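Dropping the data.replace("Z", "+00:00") shim in enterprise_service relies on datetime.fromisoformat() accepting a trailing Z directly, which it does from Python 3.11 onward (the same baseline implied by the from datetime import UTC usage elsewhere in the patch). For instance:

from datetime import datetime

# Python 3.11+ parses the trailing "Z" as UTC without any preprocessing.
expires_at = datetime.fromisoformat("2025-01-31T12:00:00Z")
assert expires_at.tzinfo is not None

# On Python 3.10 and earlier the same call raises ValueError, which is why the
# old code rewrote "Z" to "+00:00" first.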
per_page=per_page, max_per_page=100, error_out=False @@ -120,7 +120,7 @@ class ExternalDatasetService: external_knowledge_api.description = args.get("description", "") external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) external_knowledge_api.updated_by = user_id - external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) + external_knowledge_api.updated_at = naive_utc_now() db.session.commit() return external_knowledge_api diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 188caf3505..1441e6ce16 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -123,7 +123,7 @@ class FeatureModel(BaseModel): dataset_operator_enabled: bool = False webapp_copyright_enabled: bool = False workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) - + is_allow_transfer_workspace: bool = True # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -149,6 +149,7 @@ class SystemFeatureModel(BaseModel): branding: BrandingModel = BrandingModel() webapp_auth: WebAppAuthModel = WebAppAuthModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() + enable_change_email: bool = True class FeatureService: @@ -186,6 +187,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: system_features.branding.enabled = True system_features.webapp_auth.enabled = True + system_features.enable_change_email = False cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: @@ -228,6 +230,8 @@ class FeatureService: if features.billing.subscription.plan != "sandbox": features.webapp_copyright_enabled = True + else: + features.is_allow_transfer_workspace = False if "members" in billing_info: features.members.size = billing_info["members"]["size"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 2d68f30c5a..e234c2f325 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -18,6 +18,7 @@ from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage +from libs.helper import extract_tenant_id from models.account import Account from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -61,11 +62,7 @@ class FileService: # generate file key file_uuid = str(uuid.uuid4()) - if isinstance(user, Account): - current_tenant_id = user.current_tenant_id - else: - # end_user - current_tenant_id = user.tenant_id + current_tenant_id = extract_tenant_id(user) file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." 
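The inline isinstance branch in upload_file is replaced by libs.helper.extract_tenant_id. The helper's body is not part of this diff, but from the code it supersedes it presumably looks roughly like:

from models.account import Account
from models.model import EndUser


def extract_tenant_id(user: Account | EndUser) -> str | None:
    """Return the tenant id for either a console Account or a web-app EndUser."""
    if isinstance(user, Account):
        return user.current_tenant_id
    # End users carry their tenant id directly.
    return user.tenant_id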
+ extension @@ -147,7 +144,7 @@ class FileService: @staticmethod def get_file_preview(file_id: str): - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -170,7 +167,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -190,7 +187,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -201,7 +198,7 @@ class FileService: @staticmethod def get_public_image_preview(file_id: str): - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/message_service.py b/api/services/message_service.py index 51b070ece7..283b7b9b4b 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -50,7 +50,7 @@ class MessageService: if first_id: first_message = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .where(Message.conversation_id == conversation.id, Message.id == first_id) .first() ) @@ -59,7 +59,7 @@ class MessageService: history_messages = ( db.session.query(Message) - .filter( + .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, Message.id != first_message.id, @@ -71,7 +71,7 @@ class MessageService: else: history_messages = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id) + .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) .all() @@ -109,19 +109,19 @@ class MessageService: app_model=app_model, user=user, conversation_id=conversation_id ) - base_query = base_query.filter(Message.conversation_id == conversation.id) + base_query = base_query.where(Message.conversation_id == conversation.id) if include_ids is not None: - base_query = base_query.filter(Message.id.in_(include_ids)) + base_query = base_query.where(Message.id.in_(include_ids)) if last_id: - last_message = base_query.filter(Message.id == last_id).first() + last_message = base_query.where(Message.id == last_id).first() if not last_message: raise LastMessageNotExistsError() history_messages = ( - base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) .all() @@ -183,7 +183,7 @@ class MessageService: offset = (page - 1) * limit feedbacks = ( db.session.query(MessageFeedback) - .filter(MessageFeedback.app_id == app_model.id) + .where(MessageFeedback.app_id == app_model.id) .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc()) .limit(limit) .offset(offset) @@ 
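The message history queries above page backwards from an anchor row rather than by offset: each request keeps only rows strictly older than the anchor's created_at and excludes the anchor's own id, so rows inserted between requests cannot shift the window. Rendered as a standalone select()-based sketch (the original keeps the Query API):

from sqlalchemy import select


def older_messages(session, conversation_id: str, anchor, limit: int = 20):
    """Fetch the page of messages created before `anchor`, newest first."""
    stmt = (
        select(Message)
        .where(
            Message.conversation_id == conversation_id,
            Message.created_at < anchor.created_at,
            Message.id != anchor.id,
        )
        .order_by(Message.created_at.desc())
        .limit(limit)
    )
    return session.execute(stmt).scalars().all()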
-196,7 +196,7 @@ class MessageService: def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): message = ( db.session.query(Message) - .filter( + .where( Message.id == message_id, Message.app_id == app_model.id, Message.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -248,9 +248,7 @@ class MessageService: if not conversation.override_model_configs: app_model_config = ( db.session.query(AppModelConfig) - .filter( - AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id - ) + .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .first() ) else: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 26311a6377..a200cfa146 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -103,7 +103,7 @@ class ModelLoadBalancingService: # Get load balancing configurations load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), @@ -219,7 +219,7 @@ class ModelLoadBalancingService: # Get load balancing configurations load_balancing_model_config = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), @@ -307,7 +307,7 @@ class ModelLoadBalancingService: current_load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), @@ -457,7 +457,7 @@ class ModelLoadBalancingService: # Get load balancing config load_balancing_model_config = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), diff --git a/api/services/ops_service.py b/api/services/ops_service.py index c88accb9a5..62f37c1588 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -17,7 +17,7 @@ class OpsService: """ trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -25,7 +25,7 @@ class OpsService: return None # decrypt_token and obfuscated_token - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: return None tenant_id = app.tenant_id @@ -94,6 +94,16 @@ class OpsService: new_decrypt_tracing_config.update({"project_url": project_url}) except Exception: new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"}) + + if tracing_provider == "aliyun" and ( + "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") + ): + try: + 
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) + except Exception: + new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"}) + trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @@ -138,7 +148,7 @@ class OpsService: # check if trace config already exists trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -146,7 +156,7 @@ class OpsService: return None # get tenant id - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: return None tenant_id = app.tenant_id @@ -180,7 +190,7 @@ class OpsService: # check if trace config already exists current_trace_config = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -188,7 +198,7 @@ class OpsService: return None # get tenant id - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: return None tenant_id = app.tenant_id @@ -217,7 +227,7 @@ class OpsService: """ trace_config = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py new file mode 100644 index 0000000000..3774050445 --- /dev/null +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -0,0 +1,87 @@ +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.account import TenantPluginAutoUpgradeStrategy + + +class PluginAutoUpgradeService: + @staticmethod + def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: + with Session(db.engine) as session: + return ( + session.query(TenantPluginAutoUpgradeStrategy) + .filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + .first() + ) + + @staticmethod + def change_strategy( + tenant_id: str, + strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting, + upgrade_time_of_day: int, + upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode, + exclude_plugins: list[str], + include_plugins: list[str], + ) -> bool: + with Session(db.engine) as session: + exist_strategy = ( + session.query(TenantPluginAutoUpgradeStrategy) + .filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + .first() + ) + if not exist_strategy: + strategy = TenantPluginAutoUpgradeStrategy( + tenant_id=tenant_id, + strategy_setting=strategy_setting, + upgrade_time_of_day=upgrade_time_of_day, + upgrade_mode=upgrade_mode, + exclude_plugins=exclude_plugins, + include_plugins=include_plugins, + ) + session.add(strategy) + else: + exist_strategy.strategy_setting = strategy_setting + exist_strategy.upgrade_time_of_day = upgrade_time_of_day + exist_strategy.upgrade_mode = upgrade_mode + exist_strategy.exclude_plugins = exclude_plugins + 
exist_strategy.include_plugins = include_plugins + + session.commit() + return True + + @staticmethod + def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: + with Session(db.engine) as session: + exist_strategy = ( + session.query(TenantPluginAutoUpgradeStrategy) + .filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + .first() + ) + if not exist_strategy: + # create for this tenant + PluginAutoUpgradeService.change_strategy( + tenant_id, + TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + 0, + TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + [plugin_id], + [], + ) + return True + else: + if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE: + if plugin_id not in exist_strategy.exclude_plugins: + new_exclude_plugins = exist_strategy.exclude_plugins.copy() + new_exclude_plugins.append(plugin_id) + exist_strategy.exclude_plugins = new_exclude_plugins + elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL: + if plugin_id in exist_strategy.include_plugins: + new_include_plugins = exist_strategy.include_plugins.copy() + new_include_plugins.remove(plugin_id) + exist_strategy.include_plugins = new_include_plugins + elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL: + exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE + exist_strategy.exclude_plugins = [plugin_id] + + session.commit() + return True diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index dbaaa7160e..1806fbcfd6 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -101,7 +101,7 @@ class PluginMigration: for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, current_time + test_interval)) + .where(Tenant.created_at.between(current_time, current_time + test_interval)) .count() ) if tenant_count <= 100: @@ -126,7 +126,7 @@ class PluginMigration: rs = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, batch_end)) + .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) @@ -212,7 +212,7 @@ class PluginMigration: Extract tool tables. """ with Session(db.engine) as session: - rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all() result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) @@ -226,7 +226,7 @@ class PluginMigration: """ with Session(db.engine) as session: - rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all() + rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all() result = [] for row in rs: graph = row.graph_dict @@ -249,7 +249,7 @@ class PluginMigration: Extract app tables. 
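exclude_plugin above branches on the tenant's current upgrade mode, and it rebuilds the plugin lists with .copy() before reassigning them, presumably so SQLAlchemy's change tracking sees a new value rather than an in-place mutation of the column. The decision logic, reduced to a pure function with plain strings standing in for the UpgradeMode enum:

def apply_exclusion(mode: str, exclude: list[str], include: list[str], plugin_id: str) -> tuple[str, list[str], list[str]]:
    """Return the (mode, exclude, include) triple after opting one plugin out of auto-upgrade."""
    if mode == "exclude":
        if plugin_id not in exclude:
            exclude = [*exclude, plugin_id]                   # opt-out list grows
    elif mode == "partial":
        include = [p for p in include if p != plugin_id]      # opt-in list shrinks
    else:  # "all": switch to exclusion mode with only this plugin excluded
        mode, exclude = "exclude", [plugin_id]
    return mode, exclude, include


assert apply_exclusion("all", [], [], "example/plugin") == ("exclude", ["example/plugin"], [])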
""" with Session(db.engine) as session: - apps = session.query(App).filter(App.tenant_id == tenant_id).all() + apps = session.query(App).where(App.tenant_id == tenant_id).all() if not apps: return [] @@ -257,7 +257,7 @@ class PluginMigration: app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value ] - rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all() + rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() result = [] for row in rs: agent_config = row.agent_mode_dict diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 393213c0e2..00b59dacb3 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider @@ -38,11 +38,9 @@ class PluginParameterService: case "tool": provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # check if credentials are required @@ -53,7 +51,7 @@ class PluginParameterService: with Session(db.engine) as session: db_record = ( session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider, ) @@ -63,7 +61,7 @@ class PluginParameterService: if db_record is None: raise ValueError(f"Builtin provider {provider} not found when fetching credentials") - credentials = tool_configuration.decrypt(db_record.credentials) + credentials = encrypter.decrypt(db_record.credentials) case _: raise ValueError(f"Invalid provider type: {provider_type}") diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 275e496037..60fa269640 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -8,7 +8,7 @@ class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: with Session(db.engine) as session: - return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first() + return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() @staticmethod def change_permission( @@ -18,7 +18,7 @@ class PluginPermissionService: ): with Session(db.engine) as session: permission = ( - session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first() + session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() ) if not permission: permission = TenantPluginPermission( diff --git 
a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index d7fb4a7c1b..9005f0669b 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -38,6 +38,9 @@ class PluginService: plugin_id: str version: str unique_identifier: str + status: str + deprecated_reason: str + alternative_plugin_id: str REDIS_KEY_PREFIX = "plugin_service:latest_plugin:" REDIS_TTL = 60 * 5 # 5 minutes @@ -71,6 +74,9 @@ class PluginService: plugin_id=plugin_id, version=manifest.latest_version, unique_identifier=manifest.latest_package_identifier, + status=manifest.status, + deprecated_reason=manifest.deprecated_reason, + alternative_plugin_id=manifest.alternative_plugin_id, ) # Store in Redis @@ -196,6 +202,17 @@ class PluginService: manager = PluginInstaller() return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) + @staticmethod + def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool: + """ + Check if the plugin is verified + """ + manager = PluginInstaller() + try: + return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified + except Exception: + return False + @staticmethod def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]: """ @@ -427,6 +444,9 @@ class PluginService: manager = PluginInstaller() + # collect actual plugin_unique_identifiers + actual_plugin_unique_identifiers = [] + metas = [] features = FeatureService.get_system_features() # check if already downloaded @@ -437,6 +457,8 @@ class PluginService: # check if the plugin is available to install PluginService._check_plugin_installation_scope(plugin_decode_response.verification) # already downloaded, skip + actual_plugin_unique_identifiers.append(plugin_unique_identifier) + metas.append({"plugin_unique_identifier": plugin_unique_identifier}) except Exception: # plugin not installed, download and upload pkg pkg = download_plugin_pkg(plugin_unique_identifier) @@ -447,17 +469,15 @@ class PluginService: ) # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) + # use response plugin_unique_identifier + actual_plugin_unique_identifiers.append(response.unique_identifier) + metas.append({"plugin_unique_identifier": response.unique_identifier}) return manager.install_from_identifiers( tenant_id, - plugin_unique_identifiers, + actual_plugin_unique_identifiers, PluginInstallationSource.Marketplace, - [ - { - "plugin_unique_identifier": plugin_unique_identifier, - } - for plugin_unique_identifier in plugin_unique_identifiers - ], + metas, ) @staticmethod diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 3295516cce..b97d13d012 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -33,14 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): """ recommended_apps = ( db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .where(RecommendedApp.is_listed == True, RecommendedApp.language == language) .all() ) if len(recommended_apps) == 0: recommended_apps = ( db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) .all() ) @@ -83,7 +83,7 @@ 
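The marketplace lookup above is cached in Redis under plugin_service:latest_plugin:* for five minutes and now carries the new status, deprecated_reason and alternative_plugin_id fields. The cache-aside shape, stripped to its essentials with the marketplace call injected as a plain callable:

import json

REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5  # five minutes


def latest_plugin_info(redis_client, fetch_from_marketplace, plugin_id: str) -> dict:
    key = f"{REDIS_KEY_PREFIX}{plugin_id}"
    cached = redis_client.get(key)
    if cached:
        return json.loads(cached)

    info = fetch_from_marketplace(plugin_id)  # e.g. version, status, deprecated_reason
    redis_client.setex(key, REDIS_TTL, json.dumps(info))
    return info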
class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): # is in public recommended list recommended_app = ( db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) .first() ) @@ -91,7 +91,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return None # get app detail - app_model = db.session.query(App).filter(App.id == app_id).first() + app_model = db.session.query(App).where(App.id == app_id).first() if not app_model or not app_model.is_public: return None diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4cb8700117..641e03c3cf 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -17,7 +17,7 @@ class SavedMessageService: raise ValueError("User is required") saved_messages = ( db.session.query(SavedMessage) - .filter( + .where( SavedMessage.app_id == app_model.id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, @@ -37,7 +37,7 @@ class SavedMessageService: return saved_message = ( db.session.query(SavedMessage) - .filter( + .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), @@ -67,7 +67,7 @@ class SavedMessageService: return saved_message = ( db.session.query(SavedMessage) - .filter( + .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 74c6150b44..75fa52a75c 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -16,10 +16,10 @@ class TagService: query = ( db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) - .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) + .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results @@ -28,7 +28,7 @@ class TagService: def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: tags = ( db.session.query(Tag) - .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .all() ) if not tags: @@ -36,7 +36,7 @@ class TagService: tag_ids = [tag.id for tag in tags] tag_bindings = ( db.session.query(TagBinding.target_id) - .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) .all() ) if not tag_bindings: @@ -50,7 +50,7 @@ class TagService: return [] tags = ( db.session.query(Tag) - .filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .all() ) if not tags: @@ -62,7 +62,7 @@ class TagService: tags = ( db.session.query(Tag) 
.join(TagBinding, Tag.id == TagBinding.tag_id) - .filter( + .where( TagBinding.target_id == target_id, TagBinding.tenant_id == current_tenant_id, Tag.tenant_id == current_tenant_id, @@ -92,7 +92,7 @@ class TagService: def update_tags(args: dict, tag_id: str) -> Tag: if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): raise ValueError("Tag name already exists") - tag = db.session.query(Tag).filter(Tag.id == tag_id).first() + tag = db.session.query(Tag).where(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") tag.name = args["name"] @@ -101,17 +101,17 @@ class TagService: @staticmethod def get_tag_binding_count(tag_id: str) -> int: - count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count() + count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count() return count @staticmethod def delete_tag(tag_id: str): - tag = db.session.query(Tag).filter(Tag.id == tag_id).first() + tag = db.session.query(Tag).where(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") db.session.delete(tag) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all() + tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all() if tag_bindings: for tag_binding in tag_bindings: db.session.delete(tag_binding) @@ -125,7 +125,7 @@ class TagService: for tag_id in args["tag_ids"]: tag_binding = ( db.session.query(TagBinding) - .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) .first() ) if tag_binding: @@ -146,7 +146,7 @@ class TagService: # delete tag binding tag_bindings = ( db.session.query(TagBinding) - .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) .first() ) if tag_bindings: @@ -158,7 +158,7 @@ class TagService: if type == "knowledge": dataset = ( db.session.query(Dataset) - .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) .first() ) if not dataset: @@ -166,7 +166,7 @@ class TagService: elif type == "app": app = ( db.session.query(App) - .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .where(App.tenant_id == current_user.current_tenant_id, App.id == target_id) .first() ) if not app: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 6f848d49c4..78e587abee 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -119,7 +119,7 @@ class ApiToolManageService: # check if the provider exists provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -164,15 +164,11 @@ 
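# --- Illustrative sketch (assumed internals, not the real implementation) ---
# A hedged example of the (encrypter, cache) factory this patch switches to: call sites
# stop assembling ProviderConfigEncrypter by hand (config list, provider_type,
# provider_identity) and instead unpack a pair built from the provider controller,
# calling cache.delete() after stored credentials change. The real helper lives in
# core.tools.utils.encryption; the classes below only mimic its call shape.
class SketchCache:
    def __init__(self) -> None:
        self._entries: dict[str, dict] = {}

    def delete(self) -> None:
        # drop any cached decrypted credentials so the next read re-decrypts
        self._entries.clear()


class SketchEncrypter:
    def __init__(self, secret_fields: set[str]) -> None:
        self._secret_fields = secret_fields

    def encrypt(self, data: dict) -> dict:
        return {k: (f"enc:{v}" if k in self._secret_fields else v) for k, v in data.items()}

    def decrypt(self, data: dict) -> dict:
        return {k: (str(v).removeprefix("enc:") if k in self._secret_fields else v) for k, v in data.items()}


def create_tool_provider_encrypter_sketch(tenant_id: str, controller) -> tuple[SketchEncrypter, SketchCache]:
    # assumption: the real helper derives the secret field names from the controller's
    # credentials schema and uses tenant_id for per-tenant key handling
    secret_fields = {f.name for f in controller.get_credentials_schema() if getattr(f, "secret", False)}
    return SketchEncrypter(secret_fields), SketchCache()


class _Field:
    def __init__(self, name: str, secret: bool) -> None:
        self.name = name
        self.secret = secret


class _Controller:
    def get_credentials_schema(self):
        return [_Field("api_key", True), _Field("endpoint", False)]


encrypter, cache = create_tool_provider_encrypter_sketch("tenant-1", _Controller())
stored = encrypter.encrypt({"api_key": "sk-123", "endpoint": "https://example.com"})
assert encrypter.decrypt(stored) == {"api_key": "sk-123", "endpoint": "https://example.com"}
cache.delete()  # call sites invalidate the cache after updating stored credentials
# --- end sketch ---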
class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - - encrypted_credentials = tool_configuration.encrypt(credentials) - db_provider.credentials_str = json.dumps(encrypted_credentials) + db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) db.session.add(db_provider) db.session.commit() @@ -214,7 +210,7 @@ class ApiToolManageService: """ provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -261,7 +257,7 @@ class ApiToolManageService: # check if the provider exists provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == original_provider, ) @@ -297,28 +293,26 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ProviderConfigEncrypter( + encrypter, cache = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] - credentials = tool_configuration.encrypt(credentials) + credentials = encrypter.encrypt(credentials) provider.credentials_str = json.dumps(credentials) db.session.add(provider) db.session.commit() # delete cache - tool_configuration.delete_tool_credentials_cache() + cache.delete() # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) @@ -332,7 +326,7 @@ class ApiToolManageService: """ provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -382,7 +376,7 @@ class ApiToolManageService: db_provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -416,15 +410,13 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential - 
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials) for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] @@ -446,13 +438,13 @@ class ApiToolManageService: return {"result": result or "empty response"} @staticmethod - def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: + def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ list api tools """ # get all api providers db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or [] ) result: list[ToolProviderApiEntity] = [] @@ -474,7 +466,7 @@ class ApiToolManageService: for tool in tools or []: user_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + tenant_id=tenant_id, tool=tool, labels=labels ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 58a4b2f179..65f05d2986 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,28 +1,85 @@ import json import logging +import re +from collections.abc import Mapping from pathlib import Path +from typing import Any, Optional from sqlalchemy.orm import Session from configs import dify_config +from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder +from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID -from core.plugin.impl.exc import PluginDaemonClientSideError +from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.entities.api_entities import ( + ToolApiEntity, + ToolProviderApiEntity, + ToolProviderCredentialApiEntity, + ToolProviderCredentialInfoApiEntity, +) +from core.tools.entities.tool_entities import CredentialType +from core.tools.errors import ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter +from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from extensions.ext_redis import redis_client +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient +from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) class BuiltinToolManageService: + 
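# --- Illustrative sketch (plain dicts stand in for the encrypter calls) ---
# A minimal example of the credential round-trip used when updating a provider above:
# the console sends masked secrets back unchanged, so any submitted value that still
# equals its masked form is swapped for the previously stored (decrypted) value before
# re-encrypting, while genuinely edited fields keep the new value.
def merge_submitted_credentials(submitted: dict, original: dict, masked: dict) -> dict:
    merged = dict(submitted)
    for name, value in submitted.items():
        if name in masked and value == masked[name]:
            merged[name] = original[name]  # untouched secret: keep the stored one
    return merged


original = {"api_key": "sk-real-key", "endpoint": "https://example.com"}
masked = {"api_key": "******", "endpoint": "https://example.com"}
submitted = {"api_key": "******", "endpoint": "https://new.example.com"}

assert merge_submitted_credentials(submitted, original, masked) == {
    "api_key": "sk-real-key",
    "endpoint": "https://new.example.com",
}
# --- end sketch ---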
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + __DEFAULT_EXPIRES_AT__ = 2147483647 + + @staticmethod + def delete_custom_oauth_client_params(tenant_id: str, provider: str): + """ + delete custom oauth client params + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine) as session: + session.query(ToolOAuthTenantClient).filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + ).delete() + session.commit() + return {"result": "success"} + + @staticmethod + def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str): + """ + get builtin tool provider oauth client schema + """ + provider = ToolManager.get_builtin_provider(provider_name, tenant_id) + verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( + tenant_id, provider.plugin_unique_identifier + ) + + is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ) + is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists( + provider_name + ) + result = { + "schema": provider.get_oauth_client_schema(), + "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled, + "is_system_oauth_params_exists": is_system_oauth_params_exists, + "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name), + "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback", + } + return result + @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ @@ -36,27 +93,11 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) - result: list[ToolApiEntity] = [] for tool in tools or []: result.append( ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, - credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller), ) @@ -65,25 +106,15 @@ class BuiltinToolManageService: return result @staticmethod - def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_info(tenant_id: str, provider: str): """ get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is 
not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) + if builtin_provider is None: + raise ValueError(f"you have not added provider {provider}") entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -92,128 +123,411 @@ class BuiltinToolManageService: ) entity.original_credentials = {} - return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): + def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str): """ list builtin provider credentials schema + :param credential_type: credential type :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema()) + return provider.get_credentials_schema_by_type(credential_type) @staticmethod def update_builtin_tool_provider( - session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + user_id: str, + tenant_id: str, + provider: str, + credential_id: str, + credentials: dict | None = None, + name: str | None = None, ): """ update builtin tool provider """ - # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) - - try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + with Session(db.engine) as session: + # get if the provider exists + db_provider = ( + session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() ) + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(user_id, credentials) - # encrypt credentials - credentials = tool_configuration.encrypt(credentials) - except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, - ) as e: - raise ValueError(str(e)) + try: + if CredentialType.of(db_provider.credential_type).is_editable() and credentials: + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") - if provider is None: - # create 
provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), - ) + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) - db.session.add(provider) - else: - provider.encrypted_credentials = json.dumps(credentials) + original_credentials = encrypter.decrypt(db_provider.credentials) + new_credentials: dict = { + key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } - # delete cache - tool_configuration.delete_tool_credentials_cache() + if CredentialType.of(db_provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, new_credentials) - db.session.commit() + # encrypt credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials)) + + cache.delete() + + # update name if provided + if name and name != db_provider.name: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + db_provider.name = name + + session.commit() + except Exception as e: + session.rollback() + raise ValueError(str(e)) return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): + def add_builtin_tool_provider( + user_id: str, + api_type: CredentialType, + tenant_id: str, + provider: str, + credentials: dict, + expires_at: int = -1, + name: str | None = None, + ): + """ + add builtin tool provider + """ + try: + with Session(db.engine) as session: + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" + with redis_client.lock(lock, timeout=20): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") + + provider_count = ( + session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + ) + + # check if the provider count is reached the limit + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") + + # validate credentials if allowed + if CredentialType.of(api_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) + + # generate name if not provided + if name is None or name == "": + name = BuiltinToolManageService.generate_builtin_tool_provider_name( + session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type + ) + else: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + # create encrypter + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(api_type) + ], + cache=NoOpProviderCredentialCache(), + ) + + db_provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=api_type.value, + 
name=name, + expires_at=expires_at + if expires_at is not None + else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__, + ) + + session.add(db_provider) + session.commit() + except Exception as e: + session.rollback() + raise ValueError(str(e)) + return {"result": "success"} + + @staticmethod + def create_tool_encrypter( + tenant_id: str, + db_provider: BuiltinToolProvider, + provider: str, + provider_controller: BuiltinToolProviderController, + ): + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type) + ], + cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id), + ) + return encrypter, cache + + @staticmethod + def generate_builtin_tool_provider_name( + session: Session, tenant_id: str, provider: str, credential_type: CredentialType + ) -> str: + try: + db_providers = ( + session.query(BuiltinToolProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider, + credential_type=credential_type.value, + ) + .order_by(BuiltinToolProvider.created_at.desc()) + .all() + ) + + # Get the default name pattern + default_pattern = f"{credential_type.get_name()}" + + # Find all names that match the default pattern: "{default_pattern} {number}" + pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" + numbers = [] + + for db_provider in db_providers: + if db_provider.name: + match = re.match(pattern, db_provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + # If no default pattern names found, start with 1 + if not numbers: + return f"{default_pattern} 1" + + # Find the next number + max_number = max(numbers) + return f"{default_pattern} {max_number + 1}" + except Exception as e: + logger.warning(f"Error generating next provider name for {provider}: {str(e)}") + # fallback + return f"{credential_type.get_name()} 1" + + @staticmethod + def get_builtin_tool_provider_credentials( + tenant_id: str, provider_name: str + ) -> list[ToolProviderCredentialApiEntity]: """ get builtin tool provider credentials """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with db.session.no_autoflush: + providers = ( + db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider_name) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .all() + ) - if provider_obj is None: - return {} + if len(providers) == 0: + return [] - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - credentials = tool_configuration.decrypt(provider_obj.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) - return credentials + default_provider = providers[0] + default_provider.is_default = True + provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) + + credentials: list[ToolProviderCredentialApiEntity] = [] + encrypters = {} + for provider in providers: + credential_type = provider.credential_type + if credential_type not in encrypters: + encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter( + 
tenant_id, provider, provider.provider, provider_controller + )[0] + encrypter = encrypters[credential_type] + decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, + ) + credentials.append(credential_entity) + return credentials @staticmethod - def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): + def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity: + """ + get builtin tool provider credential info + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + supported_credential_types = provider_controller.get_supported_credential_types() + credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) + credential_info = ToolProviderCredentialInfoApiEntity( + supported_credential_types=supported_credential_types, + is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), + credentials=credentials, + ) + + return credential_info + + @staticmethod + def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with Session(db.engine) as session: + db_provider = ( + session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) - if provider_obj is None: - raise ValueError(f"you have not added provider {provider_name}") + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - db.session.delete(provider_obj) - db.session.commit() + session.delete(db_provider) + session.commit() - # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - tool_configuration.delete_tool_credentials_cache() + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + _, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) + cache.delete() return {"result": "success"} + @staticmethod + def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): + """ + set default provider + """ + with Session(db.engine) as session: + # get provider + target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True + ).update({"is_default": False}) + + # set new default provider + target_provider.is_default = True + session.commit() + return {"result": "success"} + + @staticmethod + def is_oauth_system_client_exists(provider_name: str) -> bool: + """ + check if oauth system client exists + """ + tool_provider = ToolProviderID(provider_name) + with Session(db.engine).no_autoflush as session: + system_client: 
ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + return system_client is not None + + @staticmethod + def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: + """ + check if oauth custom client is enabled + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + return user_client is not None and user_client.enabled + + @staticmethod + def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None: + """ + get builtin tool provider + """ + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + oauth_params: Mapping[str, Any] | None = None + if user_client: + oauth_params = encrypter.decrypt(user_client.oauth_params) + return oauth_params + + # only verified provider can use custom oauth client + is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( + tenant_id, provider.plugin_unique_identifier + ) + if not is_verified: + return oauth_params + + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + if system_client: + try: + oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + except Exception as e: + raise ValueError(f"Error decrypting system oauth params: {e}") + + return oauth_params + @staticmethod def get_builtin_tool_provider_icon(provider: str): """ @@ -234,9 +548,7 @@ class BuiltinToolManageService: with db.session.no_autoflush: # get all user added providers - db_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) # rewrite db_providers for db_provider in db_providers: @@ -275,7 +587,6 @@ class BuiltinToolManageService: ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, - credentials=user_builtin_provider.original_credentials, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) @@ -287,43 +598,153 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: - try: - full_provider_name = provider_name - provider_id_entity = ToolProviderID(provider_name) - provider_name = provider_id_entity.provider_name - if provider_id_entity.organization != "langgenius": - provider_obj = ( - db.session.query(BuiltinToolProvider) - 
.filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == full_provider_name, + def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + """ + This method is used to fetch the builtin provider from the database + 1.if the default provider exists, return the default provider + 2.if the default provider does not exist, return the oldest provider + """ + with Session(db.engine) as session: + try: + full_provider_name = provider_name + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name + + if provider_id_entity.organization != "langgenius": + provider = ( + session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == full_provider_name, + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() ) - .first() - ) - else: - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name), + else: + provider = ( + session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name), + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + + if provider is None: + return None + + provider.provider = ToolProviderID(provider.provider).to_string() + return provider + except Exception: + # it's an old provider without organization + return ( + session.query(BuiltinToolProvider) + .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) - if provider_obj is None: - return None + @staticmethod + def save_custom_oauth_client_params( + tenant_id: str, + provider: str, + client_params: Optional[dict] = None, + enable_oauth_custom_client: Optional[bool] = None, + ): + """ + setup oauth custom client + """ + if client_params is None and enable_oauth_custom_client is None: + return {"result": "success"} - provider_obj.provider = ToolProviderID(provider_obj.provider).to_string() - return provider_obj - except Exception: - # it's an old provider without organization - return ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name), + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + with Session(db.engine) as session: + custom_client_params = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, ) .first() ) + + # if the record does not exist, create a basic record + if custom_client_params is None: + custom_client_params = ToolOAuthTenantClient( 
+ tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + session.add(custom_client_params) + + if client_params is not None: + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + original_params = encrypter.decrypt(custom_client_params.oauth_params) + new_params: dict = { + key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) + for key, value in client_params.items() + } + custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params)) + + if enable_oauth_custom_client is not None: + custom_client_params.enabled = enable_oauth_custom_client + + session.commit() + return {"result": "success"} + + @staticmethod + def get_custom_oauth_client_params(tenant_id: str, provider: str): + """ + get custom oauth client params + """ + with Session(db.engine) as session: + tool_provider = ToolProviderID(provider) + custom_oauth_client_params: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + if custom_oauth_client_params is None: + return {} + + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, BuiltinToolProviderController): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + + return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py new file mode 100644 index 0000000000..23be449a5a --- /dev/null +++ b/api/services/tools/mcp_tools_manage_service.py @@ -0,0 +1,252 @@ +import hashlib +import json +from datetime import datetime +from typing import Any + +from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError + +from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache +from core.mcp.error import MCPAuthError, MCPError +from core.mcp.mcp_client import MCPClient +from core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.utils.encryption import ProviderConfigEncrypter +from extensions.ext_database import db +from models.tools import MCPToolProvider +from services.tools.tools_transform_service import ToolTransformService + +UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" + + +class MCPToolManageService: + """ + Service class for managing mcp tools. 
+ """ + + @staticmethod + def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: + res = ( + db.session.query(MCPToolProvider) + .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) + .first() + ) + if not res: + raise ValueError("MCP tool not found") + return res + + @staticmethod + def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider: + res = ( + db.session.query(MCPToolProvider) + .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) + .first() + ) + if not res: + raise ValueError("MCP tool not found") + return res + + @staticmethod + def create_mcp_provider( + tenant_id: str, + name: str, + server_url: str, + user_id: str, + icon: str, + icon_type: str, + icon_background: str, + server_identifier: str, + ) -> ToolProviderApiEntity: + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() + existing_provider = ( + db.session.query(MCPToolProvider) + .where( + MCPToolProvider.tenant_id == tenant_id, + or_( + MCPToolProvider.name == name, + MCPToolProvider.server_url_hash == server_url_hash, + MCPToolProvider.server_identifier == server_identifier, + ), + ) + .first() + ) + if existing_provider: + if existing_provider.name == name: + raise ValueError(f"MCP tool {name} already exists") + if existing_provider.server_url_hash == server_url_hash: + raise ValueError(f"MCP tool {server_url} already exists") + if existing_provider.server_identifier == server_identifier: + raise ValueError(f"MCP tool {server_identifier} already exists") + encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) + mcp_tool = MCPToolProvider( + tenant_id=tenant_id, + name=name, + server_url=encrypted_server_url, + server_url_hash=server_url_hash, + user_id=user_id, + authed=False, + tools="[]", + icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, + server_identifier=server_identifier, + ) + db.session.add(mcp_tool) + db.session.commit() + return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) + + @staticmethod + def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: + mcp_providers = ( + db.session.query(MCPToolProvider) + .where(MCPToolProvider.tenant_id == tenant_id) + .order_by(MCPToolProvider.name) + .all() + ) + return [ + ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list) + for mcp_provider in mcp_providers + ] + + @classmethod + def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: + mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + server_url = mcp_provider.decrypted_server_url + authed = mcp_provider.authed + + try: + with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client: + tools = mcp_client.list_tools() + except MCPAuthError: + raise ValueError("Please auth the tool first") + except MCPError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") + + try: + mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + mcp_provider.authed = True + mcp_provider.updated_at = datetime.now() + db.session.commit() + except Exception: + db.session.rollback() + raise + + user = mcp_provider.load_user() + return ToolProviderApiEntity( + id=mcp_provider.id, + 
name=mcp_provider.name, + tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools), + type=ToolProviderType.MCP, + icon=mcp_provider.icon, + author=user.name if user else "Anonymous", + server_url=mcp_provider.masked_server_url, + updated_at=int(mcp_provider.updated_at.timestamp()), + description=I18nObject(en_US="", zh_Hans=""), + label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), + plugin_unique_identifier=mcp_provider.server_identifier, + ) + + @classmethod + def delete_mcp_tool(cls, tenant_id: str, provider_id: str): + mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + + db.session.delete(mcp_tool) + db.session.commit() + + @classmethod + def update_mcp_provider( + cls, + tenant_id: str, + provider_id: str, + name: str, + server_url: str, + icon: str, + icon_type: str, + icon_background: str, + server_identifier: str, + ): + mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + + reconnect_result = None + encrypted_server_url = None + server_url_hash = None + + if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: + encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() + + if server_url_hash != mcp_provider.server_url_hash: + reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id) + + try: + mcp_provider.updated_at = datetime.now() + mcp_provider.name = name + mcp_provider.icon = ( + json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon + ) + mcp_provider.server_identifier = server_identifier + + if encrypted_server_url is not None and server_url_hash is not None: + mcp_provider.server_url = encrypted_server_url + mcp_provider.server_url_hash = server_url_hash + + if reconnect_result: + mcp_provider.authed = reconnect_result["authed"] + mcp_provider.tools = reconnect_result["tools"] + mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + + db.session.commit() + except IntegrityError as e: + db.session.rollback() + error_msg = str(e.orig) + if "unique_mcp_provider_name" in error_msg: + raise ValueError(f"MCP tool {name} already exists") + if "unique_mcp_provider_server_url" in error_msg: + raise ValueError(f"MCP tool {server_url} already exists") + if "unique_mcp_provider_server_identifier" in error_msg: + raise ValueError(f"MCP tool {server_identifier} already exists") + raise + except Exception: + db.session.rollback() + raise + + @classmethod + def update_mcp_provider_credentials( + cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False + ): + provider_controller = MCPToolProviderController._from_db(mcp_provider) + tool_configuration = ProviderConfigEncrypter( + tenant_id=mcp_provider.tenant_id, + config=list(provider_controller.get_credentials_schema()), + provider_config_cache=NoOpProviderCredentialCache(), + ) + credentials = tool_configuration.encrypt(credentials) + mcp_provider.updated_at = datetime.now() + mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials}) + mcp_provider.authed = authed + if not authed: + mcp_provider.tools = "[]" + db.session.commit() + + @classmethod + def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str): + try: + with MCPClient( + server_url, + provider_id, + tenant_id, + authed=False, + for_list=True, + ) as mcp_client: + tools = mcp_client.list_tools() + return { + "authed": True, + "tools": 
json.dumps([tool.model_dump() for tool in tools]), + "encrypted_credentials": "{}", + } + except MCPAuthError: + return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"} + except MCPError as e: + raise ValueError(f"Failed to re-connect MCP server: {e}") from e diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 367121125b..2d192e6f7f 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,27 +1,30 @@ import json import logging -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from yarl import URL from configs import dify_config +from core.helper.provider_cache import ToolProviderCredentialsCache +from core.mcp.types import Tool as MCPTool from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolParameter, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider logger = logging.getLogger(__name__) @@ -52,7 +55,8 @@ class ToolTransformService: return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - + elif provider_type == ToolProviderType.MCP.value: + return icon return "" @staticmethod @@ -73,10 +77,18 @@ class ToolTransformService: provider.icon = ToolTransformService.get_plugin_icon_url( tenant_id=tenant_id, filename=provider.icon ) + if isinstance(provider.icon_dark, str) and provider.icon_dark: + provider.icon_dark = ToolTransformService.get_plugin_icon_url( + tenant_id=tenant_id, filename=provider.icon_dark + ) else: provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) + if provider.icon_dark: + provider.icon_dark = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark + ) @classmethod def builtin_provider_to_user_provider( @@ -94,6 +106,7 @@ class ToolTransformService: name=provider_controller.entity.identity.name, description=provider_controller.entity.identity.description, icon=provider_controller.entity.identity.icon, + icon_dark=provider_controller.entity.identity.icon_dark, label=provider_controller.entity.identity.label, type=ToolProviderType.BUILT_IN, masked_credentials={}, @@ -108,7 +121,12 @@ class ToolTransformService: result.plugin_unique_identifier = 
provider_controller.plugin_unique_identifier # get credentials schema - schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + schema = { + x.to_basic_provider_config().name: x + for x in provider_controller.get_credentials_schema_by_type( + CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY + ) + } for name, value in schema.items(): if result.masked_credentials: @@ -125,15 +143,23 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type( + CredentialType.of(db_provider.credential_type) + ) + ], + cache=ToolProviderCredentialsCache( + tenant_id=db_provider.tenant_id, + provider=db_provider.provider, + credential_id=db_provider.id, + ), ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials @@ -148,11 +174,16 @@ class ToolTransformService: convert provider controller to user provider """ # package tool provider controller + auth_type = ApiProviderAuthType.NONE + credentials_auth_type = db_provider.credentials.get("auth_type") + if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility + auth_type = ApiProviderAuthType.API_KEY_HEADER + elif credentials_auth_type == "api_key_query": + auth_type = ApiProviderAuthType.API_KEY_QUERY + controller = ApiToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY - if db_provider.credentials["auth_type"] == "api_key" - else ApiProviderAuthType.NONE, + auth_type=auth_type, ) return controller @@ -177,6 +208,7 @@ class ToolTransformService: name=provider_controller.entity.identity.name, description=provider_controller.entity.identity.description, icon=provider_controller.entity.identity.icon, + icon_dark=provider_controller.entity.identity.icon_dark, label=provider_controller.entity.identity.label, type=ToolProviderType.WORKFLOW, masked_credentials={}, @@ -187,6 +219,41 @@ class ToolTransformService: labels=labels or [], ) + @staticmethod + def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity: + user = db_provider.load_user() + return ToolProviderApiEntity( + id=db_provider.server_identifier if not for_list else db_provider.id, + author=user.name if user else "Anonymous", + name=db_provider.name, + icon=db_provider.provider_icon, + type=ToolProviderType.MCP, + is_team_authorization=db_provider.authed, + server_url=db_provider.masked_server_url, + tools=ToolTransformService.mcp_tool_to_user_tool( + db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] + ), + updated_at=int(db_provider.updated_at.timestamp()), + label=I18nObject(en_US=db_provider.name, 
zh_Hans=db_provider.name), + description=I18nObject(en_US="", zh_Hans=""), + server_identifier=db_provider.server_identifier, + ) + + @staticmethod + def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]: + user = mcp_provider.load_user() + return [ + ToolApiEntity( + author=user.name if user else "Anonymous", + name=tool.name, + label=I18nObject(en_US=tool.name, zh_Hans=tool.name), + description=I18nObject(en_US=tool.description, zh_Hans=tool.description), + parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema), + labels=[], + ) + for tool in tools + ] + @classmethod def api_provider_to_user_provider( cls, @@ -235,16 +302,14 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials @@ -254,7 +319,6 @@ class ToolTransformService: def convert_tool_entity_to_api_entity( tool: Union[ApiToolBundle, WorkflowTool, Tool], tenant_id: str, - credentials: dict | None = None, labels: list[str] | None = None, ) -> ToolApiEntity: """ @@ -264,27 +328,39 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials=credentials or {}, + credentials={}, tenant_id=tenant_id, ) ) # get tool parameters - parameters = tool.entity.parameters or [] + base_parameters = tool.entity.parameters or [] # get tool runtime parameters runtime_parameters = tool.get_runtime_parameters() - # override parameters - current_parameters = parameters.copy() - for runtime_parameter in runtime_parameters: - found = False - for index, parameter in enumerate(current_parameters): - if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: - current_parameters[index] = runtime_parameter - found = True - break - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: - current_parameters.append(runtime_parameter) + # merge parameters using a functional approach to avoid type issues + merged_parameters: list[ToolParameter] = [] + + # create a mapping of runtime parameters for quick lookup + runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters} + + # process base parameters, replacing with runtime versions if they exist + for base_param in base_parameters: + key = (base_param.name, base_param.form) + if key in runtime_param_map: + merged_parameters.append(runtime_param_map[key]) + else: + merged_parameters.append(base_param) + + # add any runtime parameters that weren't in base parameters + for runtime_parameter in runtime_parameters: + if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + # check if this parameter is already in merged_parameters + already_exists = any( + p.name == runtime_parameter.name and p.form == 
runtime_parameter.form for p in merged_parameters + ) + if not already_exists: + merged_parameters.append(runtime_parameter) return ToolApiEntity( author=tool.entity.identity.author, @@ -292,10 +368,10 @@ class ToolTransformService: label=tool.entity.identity.label, description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""), output_schema=tool.entity.output_schema, - parameters=current_parameters, + parameters=merged_parameters, labels=labels or [], ) - if isinstance(tool, ApiToolBundle): + elif isinstance(tool, ApiToolBundle): return ToolApiEntity( author=tool.author, name=tool.operation_id or "", @@ -304,3 +380,69 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + else: + # Handle WorkflowTool case + raise ValueError(f"Unsupported tool type: {type(tool)}") + + @staticmethod + def convert_builtin_provider_to_credential_entity( + provider: BuiltinToolProvider, credentials: dict + ) -> ToolProviderCredentialApiEntity: + return ToolProviderCredentialApiEntity( + id=provider.id, + name=provider.name, + provider=provider.provider, + credential_type=CredentialType.of(provider.credential_type), + is_default=provider.is_default, + credentials=credentials, + ) + + @staticmethod + def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]: + """ + Convert MCP JSON schema to tool parameters + + :param schema: JSON schema dictionary + :return: list of ToolParameter instances + """ + + def create_parameter( + name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None + ) -> ToolParameter: + """Create a ToolParameter instance with given attributes""" + input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {} + return ToolParameter( + name=name, + llm_description=description, + label=I18nObject(en_US=name), + form=ToolParameter.ToolParameterForm.LLM, + required=required, + type=ToolParameter.ToolParameterType(param_type), + human_description=I18nObject(en_US=description), + **input_schema_dict, + ) + + def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]: + """Process properties recursively""" + TYPE_MAPPING = {"integer": "number", "float": "number"} + COMPLEX_TYPES = ["array", "object"] + + parameters = [] + for name, prop in props.items(): + current_description = prop.get("description", "") + prop_type = prop.get("type", "string") + + if isinstance(prop_type, list): + prop_type = prop_type[0] + if prop_type in TYPE_MAPPING: + prop_type = TYPE_MAPPING[prop_type] + input_schema = prop if prop_type in COMPLEX_TYPES else None + parameters.append( + create_parameter(name, current_description, prop_type, name in required, input_schema) + ) + + return parameters + + if schema.get("type") == "object" and "properties" in schema: + return process_properties(schema["properties"], schema.get("required", [])) + return [] diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index c6b205557a..75da5e5eaa 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -43,7 +43,7 @@ class WorkflowToolManageService: # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == tenant_id, # name or app_id or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), @@ -54,7 +54,7 
@@ class WorkflowToolManageService: if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() + app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: raise ValueError(f"App {workflow_app_id} not found") @@ -123,7 +123,7 @@ class WorkflowToolManageService: # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.name == name, WorkflowToolProvider.id != workflow_tool_id, @@ -136,7 +136,7 @@ class WorkflowToolManageService: workflow_tool_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) @@ -144,7 +144,7 @@ class WorkflowToolManageService: raise ValueError(f"Tool {workflow_tool_id} not found") app: App | None = ( - db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) if app is None: @@ -186,7 +186,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() tools: list[WorkflowToolProviderController] = [] for provider in db_tools: @@ -224,7 +224,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id """ - db.session.query(WorkflowToolProvider).filter( + db.session.query(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() @@ -243,7 +243,7 @@ class WorkflowToolManageService: """ db_tool: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) return cls._get_workflow_tool(tenant_id, db_tool) @@ -259,7 +259,7 @@ class WorkflowToolManageService: """ db_tool: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() ) return cls._get_workflow_tool(tenant_id, db_tool) @@ -275,7 +275,7 @@ class WorkflowToolManageService: raise ValueError("Tool not found") workflow_app: App | None = ( - db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() + db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() ) if workflow_app is None: @@ -318,7 +318,7 @@ class WorkflowToolManageService: """ db_tool: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == 
tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 9165139193..f9ec054593 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -36,7 +36,7 @@ class VectorService: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index f698ed3084..c48e24f244 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -65,7 +65,7 @@ class WebConversationService: return pinned_conversation = ( db.session.query(PinnedConversation) - .filter( + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), @@ -97,7 +97,7 @@ class WebConversationService: return pinned_conversation = ( db.session.query(PinnedConversation) - .filter( + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 8f92b3f070..a9df8d0d73 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -52,7 +52,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = db.session.query(Account).filter(Account.email == email).first() + account = db.session.query(Account).where(Account.email == email).first() if not account: return None @@ -91,10 +91,10 @@ class WebAppAuthService: @classmethod def create_end_user(cls, app_code, email) -> EndUser: - site = db.session.query(Site).filter(Site.code == app_code).first() + site = db.session.query(Site).where(Site.code == app_code).first() if not site: raise NotFound("Site not found.") - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.query(App).where(App.id == site.app_id).first() if not app_model: raise NotFound("App not found.") end_user = EndUser( diff --git a/api/services/website_service.py b/api/services/website_service.py index 6720932a3a..991b669737 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,6 +1,7 @@ import datetime import json -from typing import Any +from dataclasses import dataclass +from typing import Any, Optional import requests from flask_login import current_user @@ -13,241 +14,392 @@ from extensions.ext_storage import storage from services.auth.api_key_auth_service import ApiKeyAuthService -class WebsiteService: - @classmethod - def document_create_args_validate(cls, args: dict): - if "url" not in args or not args["url"]: - raise ValueError("url is required") - if "options" not in args or not args["options"]: - raise ValueError("options is required") - if "limit" not in args["options"] or not args["options"]["limit"]: - raise ValueError("limit is required") +@dataclass +class CrawlOptions: + """Options for crawling operations.""" + + limit: int = 1 + crawl_sub_pages: 
bool = False + only_main_content: bool = False + includes: Optional[str] = None + excludes: Optional[str] = None + max_depth: Optional[int] = None + use_sitemap: bool = True + + def get_include_paths(self) -> list[str]: + """Get list of include paths from comma-separated string.""" + return self.includes.split(",") if self.includes else [] + + def get_exclude_paths(self) -> list[str]: + """Get list of exclude paths from comma-separated string.""" + return self.excludes.split(",") if self.excludes else [] + + +@dataclass +class CrawlRequest: + """Request container for crawling operations.""" + + url: str + provider: str + options: CrawlOptions + + +@dataclass +class ScrapeRequest: + """Request container for scraping operations.""" + + provider: str + url: str + tenant_id: str + only_main_content: bool + + +@dataclass +class WebsiteCrawlApiRequest: + """Request container for website crawl API arguments.""" + + provider: str + url: str + options: dict[str, Any] + + def to_crawl_request(self) -> CrawlRequest: + """Convert API request to internal CrawlRequest.""" + options = CrawlOptions( + limit=self.options.get("limit", 1), + crawl_sub_pages=self.options.get("crawl_sub_pages", False), + only_main_content=self.options.get("only_main_content", False), + includes=self.options.get("includes"), + excludes=self.options.get("excludes"), + max_depth=self.options.get("max_depth"), + use_sitemap=self.options.get("use_sitemap", True), + ) + return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider", "") + def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + """Create from Flask-RESTful parsed arguments.""" + provider = args.get("provider") url = args.get("url") - options = args.get("options", "") - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - crawl_sub_pages = options.get("crawl_sub_pages", False) - only_main_content = options.get("only_main_content", False) - if not crawl_sub_pages: - params = { - "includePaths": [], - "excludePaths": [], - "limit": 1, - "scrapeOptions": {"onlyMainContent": only_main_content}, - } - else: - includes = options.get("includes").split(",") if options.get("includes") else [] - excludes = options.get("excludes").split(",") if options.get("excludes") else [] - params = { - "includePaths": includes, - "excludePaths": excludes, - "limit": options.get("limit", 1), - "scrapeOptions": {"onlyMainContent": only_main_content}, - } - if options.get("max_depth"): - params["maxDepth"] = options.get("max_depth") - job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f"website_crawl_{job_id}" - time = str(datetime.datetime.now().timestamp()) - redis_client.setex(website_crawl_time_cache_key, 3600, time) - return {"status": "active", "job_id": job_id} - elif provider == "watercrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options) + options = args.get("options", {}) - elif provider == "jinareader": 
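
The hunk above replaces `WebsiteService`'s loose `args`/`options` dicts with small typed containers. Below is a self-contained sketch of how those containers compose, using trimmed copies of the `CrawlOptions` and `CrawlRequest` dataclasses introduced in this diff; the sample option values are hypothetical, and the real classes live in `api/services/website_service.py`.

```python
# Standalone sketch (not the Dify module itself): trimmed copies of the
# CrawlOptions / CrawlRequest dataclasses added above, showing how the typed
# request objects compose. The sample option values are hypothetical.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CrawlOptions:
    limit: int = 1
    crawl_sub_pages: bool = False
    only_main_content: bool = False
    includes: Optional[str] = None
    excludes: Optional[str] = None
    max_depth: Optional[int] = None
    use_sitemap: bool = True

    def get_include_paths(self) -> list[str]:
        # Comma-separated string -> list, mirroring the helper in the diff.
        return self.includes.split(",") if self.includes else []


@dataclass
class CrawlRequest:
    url: str
    provider: str
    options: CrawlOptions


# Raw options as they might arrive from the API layer (hypothetical payload).
raw = {"limit": 5, "crawl_sub_pages": True, "includes": "docs/*,blog/*"}
request = CrawlRequest(
    url="https://example.com",
    provider="firecrawl",
    options=CrawlOptions(
        limit=raw.get("limit", 1),
        crawl_sub_pages=raw.get("crawl_sub_pages", False),
        includes=raw.get("includes"),
    ),
)
print(request.options.get_include_paths())  # ['docs/*', 'blog/*']
```

Keeping the comma-splitting inside `get_include_paths()` means every provider branch consumes the same normalized list instead of re-parsing the raw string.
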
- api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - crawl_sub_pages = options.get("crawl_sub_pages", False) - if not crawl_sub_pages: - response = requests.get( - f"https://r.jina.ai/{url}", - headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return {"status": "active", "data": response.json().get("data")} - else: - response = requests.post( - "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", - json={ - "url": url, - "maxPages": options.get("limit", 1), - "useSitemap": options.get("use_sitemap", True), - }, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - }, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + if not provider: + raise ValueError("Provider is required") + if not url: + raise ValueError("URL is required") + if not options: + raise ValueError("Options are required") + + return cls(provider=provider, url=url, options=options) + + +@dataclass +class WebsiteCrawlStatusApiRequest: + """Request container for website crawl status API arguments.""" + + provider: str + job_id: str + + @classmethod + def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + """Create from Flask-RESTful parsed arguments.""" + provider = args.get("provider") + + if not provider: + raise ValueError("Provider is required") + if not job_id: + raise ValueError("Job ID is required") + + return cls(provider=provider, job_id=job_id) + + +class WebsiteService: + """Service class for website crawling operations using different providers.""" + + @classmethod + def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]: + """Get and validate credentials for a provider.""" + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if not credentials or "config" not in credentials: + raise ValueError("No valid credentials found for the provider") + return credentials, credentials["config"] + + @classmethod + def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str: + """Decrypt and return the API key from config.""" + api_key = config.get("api_key") + if not api_key: + raise ValueError("API key not found in configuration") + return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key) + + @classmethod + def document_create_args_validate(cls, args: dict) -> None: + """Validate arguments for document creation.""" + try: + WebsiteCrawlApiRequest.from_args(args) + except ValueError as e: + raise ValueError(f"Invalid arguments: {e}") + + @classmethod + def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]: + """Crawl a URL using the specified provider with typed request.""" + request = api_request.to_crawl_request() + + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider) + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + + if request.provider == "firecrawl": + return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config) + elif request.provider == "watercrawl": + return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config) + elif request.provider == "jinareader": + return cls._crawl_with_jinareader(request=request, api_key=api_key) else: 
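
The refactored `crawl_url` shown in this hunk funnels every provider through the same two helpers (`_get_credentials_and_config` and `_get_decrypted_api_key`) before branching. A minimal stand-in of that shape follows, with an in-memory credential store and a placeholder decrypt step instead of Dify's `ApiKeyAuthService` and `encrypter`.

```python
# Minimal stand-in for the shared credential helpers plus the provider
# dispatch used by the refactored crawl_url. The in-memory store and the
# fake decryption are placeholders, not Dify's real services.
from typing import Any

_STORE = {("tenant-1", "firecrawl"): {"config": {"api_key": "enc:fc-test", "base_url": None}}}


def _get_credentials_and_config(tenant_id: str, provider: str) -> tuple[dict, dict]:
    credentials = _STORE.get((tenant_id, provider))
    if not credentials or "config" not in credentials:
        raise ValueError("No valid credentials found for the provider")
    return credentials, credentials["config"]


def _get_decrypted_api_key(tenant_id: str, config: dict) -> str:
    api_key = config.get("api_key")
    if not api_key:
        raise ValueError("API key not found in configuration")
    return api_key.removeprefix("enc:")  # placeholder for encrypter.decrypt_token


def crawl_url(tenant_id: str, provider: str, url: str) -> dict[str, Any]:
    _, config = _get_credentials_and_config(tenant_id, provider)
    api_key = _get_decrypted_api_key(tenant_id, config)
    if provider == "firecrawl":
        return {"status": "active", "job_id": "fc-job", "url": url, "key_used": api_key}
    elif provider in ("watercrawl", "jinareader"):
        return {"status": "active", "job_id": f"{provider}-job"}
    else:
        raise ValueError("Invalid provider")


print(crawl_url("tenant-1", "firecrawl", "https://example.com"))
```

The point of the helpers is that missing credentials or a missing API key fail fast with a uniform error, rather than each provider branch re-implementing the lookup and decryption.
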
raise ValueError("Invalid provider") @classmethod - def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - result = firecrawl_app.check_crawl_status(job_id) - crawl_status_data = { - "status": result.get("status", "active"), - "job_id": job_id, - "total": result.get("total", 0), - "current": result.get("current", 0), - "data": result.get("data", []), + def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + + if not request.options.crawl_sub_pages: + params = { + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, } - if crawl_status_data["status"] == "completed": - website_crawl_time_cache_key = f"website_crawl_{job_id}" - start_time = redis_client.get(website_crawl_time_cache_key) - if start_time: - end_time = datetime.datetime.now().timestamp() - time_consuming = abs(end_time - float(start_time)) - crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" - redis_client.delete(website_crawl_time_cache_key) - elif provider == "watercrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + else: + params = { + "includePaths": request.options.get_include_paths(), + "excludePaths": request.options.get_exclude_paths(), + "limit": request.options.limit, + "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, + } + if request.options.max_depth: + params["maxDepth"] = request.options.max_depth + + job_id = firecrawl_app.crawl_url(request.url, params) + website_crawl_time_cache_key = f"website_crawl_{job_id}" + time = str(datetime.datetime.now().timestamp()) + redis_client.setex(website_crawl_time_cache_key, 3600, time) + return {"status": "active", "job_id": job_id} + + @classmethod + def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + # Convert CrawlOptions back to dict format for WaterCrawlProvider + options = { + "limit": request.options.limit, + "crawl_sub_pages": request.options.crawl_sub_pages, + "only_main_content": request.options.only_main_content, + "includes": request.options.includes, + "excludes": request.options.excludes, + "max_depth": request.options.max_depth, + "use_sitemap": request.options.use_sitemap, + } + return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( + url=request.url, options=options + ) + + @classmethod + def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: + if not request.options.crawl_sub_pages: + response = requests.get( + f"https://r.jina.ai/{request.url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) - crawl_status_data = WaterCrawlProvider( - api_key, credentials.get("config").get("base_url", None) - ).get_crawl_status(job_id) - elif provider == "jinareader": - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, 
token=credentials.get("config").get("api_key") + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "data": response.json().get("data")} + else: + response = requests.post( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={ + "url": request.url, + "maxPages": request.options.limit, + "useSitemap": request.options.use_sitemap, + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + + @classmethod + def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]: + """Get crawl status using string parameters.""" + api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id) + return cls.get_crawl_status_typed(api_request) + + @classmethod + def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]: + """Get crawl status using typed request.""" + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider) + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + + if api_request.provider == "firecrawl": + return cls._get_firecrawl_status(api_request.job_id, api_key, config) + elif api_request.provider == "watercrawl": + return cls._get_watercrawl_status(api_request.job_id, api_key, config) + elif api_request.provider == "jinareader": + return cls._get_jinareader_status(api_request.job_id, api_key) + else: + raise ValueError("Invalid provider") + + @classmethod + def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + result = firecrawl_app.check_crawl_status(job_id) + crawl_status_data = { + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), + } + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" + start_time = redis_client.get(website_crawl_time_cache_key) + if start_time: + end_time = datetime.datetime.now().timestamp() + time_consuming = abs(end_time - float(start_time)) + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" + redis_client.delete(website_crawl_time_cache_key) + return crawl_status_data + + @classmethod + def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id) + + @classmethod + def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + crawl_status_data = { + "status": data.get("status", "active"), + "job_id": job_id, + "total": len(data.get("urls", [])), + "current": len(data.get("processed", [])) + len(data.get("failed", [])), + "data": [], + "time_consuming": data.get("duration", 0) / 1000, + } + + if crawl_status_data["status"] == "completed": response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer 
{api_key}"}, - json={"taskId": job_id}, + json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, ) data = response.json().get("data", {}) - crawl_status_data = { - "status": data.get("status", "active"), - "job_id": job_id, - "total": len(data.get("urls", [])), - "current": len(data.get("processed", [])) + len(data.get("failed", [])), - "data": [], - "time_consuming": data.get("duration", 0) / 1000, - } - - if crawl_status_data["status"] == "completed": - response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, - ) - data = response.json().get("data", {}) - formatted_data = [ - { - "title": item.get("data", {}).get("title"), - "source_url": item.get("data", {}).get("url"), - "description": item.get("data", {}).get("description"), - "markdown": item.get("data", {}).get("content"), - } - for item in data.get("processed", {}).values() - ] - crawl_status_data["data"] = formatted_data - else: - raise ValueError("Invalid provider") + formatted_data = [ + { + "title": item.get("data", {}).get("title"), + "source_url": item.get("data", {}).get("url"), + "description": item.get("data", {}).get("description"), + "markdown": item.get("data", {}).get("content"), + } + for item in data.get("processed", {}).values() + ] + crawl_status_data["data"] = formatted_data return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) - # decrypt api_key - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + _, config = cls._get_credentials_and_config(tenant_id, provider) + api_key = cls._get_decrypted_api_key(tenant_id, config) if provider == "firecrawl": - crawl_data: list[dict[str, Any]] | None = None - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - stored_data = storage.load_once(file_key) - if stored_data: - crawl_data = json.loads(stored_data.decode("utf-8")) - else: - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - result = firecrawl_app.check_crawl_status(job_id) - if result.get("status") != "completed": - raise ValueError("Crawl job is not completed") - crawl_data = result.get("data") - - if crawl_data: - for item in crawl_data: - if item.get("source_url") == url: - return dict(item) - return None + return cls._get_firecrawl_url_data(job_id, url, api_key, config) elif provider == "watercrawl": - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data( - job_id, url - ) + return cls._get_watercrawl_url_data(job_id, url, api_key, config) elif provider == "jinareader": - if not job_id: - response = requests.get( - f"https://r.jina.ai/{url}", - headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return dict(response.json().get("data", {})) - else: - # Get crawl status first - status_response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer 
{api_key}"}, - json={"taskId": job_id}, - ) - status_data = status_response.json().get("data", {}) - if status_data.get("status") != "completed": - raise ValueError("Crawl job is not completed") - - # Get processed data - data_response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, - ) - processed_data = data_response.json().get("data", {}) - for item in processed_data.get("processed", {}).values(): - if item.get("data", {}).get("url") == url: - return dict(item.get("data", {})) - return None + return cls._get_jinareader_url_data(job_id, url, api_key) else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - params = {"onlyMainContent": only_main_content} - result = firecrawl_app.scrape_url(url, params) - return result - elif provider == "watercrawl": - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url) + def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + crawl_data: list[dict[str, Any]] | None = None + file_key = "website_files/" + job_id + ".txt" + if storage.exists(file_key): + stored_data = storage.load_once(file_key) + if stored_data: + crawl_data = json.loads(stored_data.decode("utf-8")) + else: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + result = firecrawl_app.check_crawl_status(job_id) + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + crawl_data = result.get("data") + + if crawl_data: + for item in crawl_data: + if item.get("source_url") == url: + return dict(item) + return None + + @classmethod + def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + + @classmethod + def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: + if not job_id: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return dict(response.json().get("data", {})) + else: + # Get crawl status first + status_response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + status_data = status_response.json().get("data", {}) + if status_data.get("status") != "completed": + raise ValueError("Crawl job is not completed") + + # Get processed data + data_response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": 
f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, + ) + processed_data = data_response.json().get("data", {}) + for item in processed_data.get("processed", {}).values(): + if item.get("data", {}).get("url") == url: + return dict(item.get("data", {})) + return None + + @classmethod + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]: + request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content) + + _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider) + api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config) + + if request.provider == "firecrawl": + return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config) + elif request.provider == "watercrawl": + return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config) else: raise ValueError("Invalid provider") + + @classmethod + def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + params = {"onlyMainContent": request.only_main_content} + return firecrawl_app.scrape_url(url=request.url, params=params) + + @classmethod + def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 2b0d57bdfd..abf6824d73 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -620,7 +620,7 @@ class WorkflowConverter: """ api_based_extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 44fd72b5e4..3164e010b4 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any, ClassVar -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom @@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation from models.enums import DraftVariableType -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable +from repositories.factory import DifyAPIRepositoryFactory _logger = logging.getLogger(__name__) @@ -117,10 +118,27 @@ class WorkflowDraftVariableService: _session: Session def __init__(self, session: Session) -> None: + """ + Initialize the 
WorkflowDraftVariableService with a SQLAlchemy session. + + Args: + session (Session): The SQLAlchemy session used to execute database queries. + The provided session must be bound to an `Engine` object, not a specific `Connection`. + + Raises: + AssertionError: If the provided session is not bound to an `Engine` object. + """ self._session = session + engine = session.get_bind() + # Ensure the session is bound to a engine. + assert isinstance(engine, Engine) + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: - return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + return self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable_id).first() def get_draft_variables_by_selectors( self, @@ -148,7 +166,7 @@ class WorkflowDraftVariableService: def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: criteria = WorkflowDraftVariable.app_id == app_id total = None - query = self._session.query(WorkflowDraftVariable).filter(criteria) + query = self._session.query(WorkflowDraftVariable).where(criteria) if page == 1: total = query.count() variables = ( @@ -167,7 +185,7 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, ) - query = self._session.query(WorkflowDraftVariable).filter(*criteria) + query = self._session.query(WorkflowDraftVariable).where(*criteria) variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all() return WorkflowDraftVariableList(variables=variables) @@ -248,8 +266,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) - node_exec = self._session.scalars(query).first() + node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", @@ -298,6 +315,8 @@ class WorkflowDraftVariableService: def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name): + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") if variable_type == DraftVariableType.CONVERSATION: return self._reset_conv_var(workflow, variable) else: @@ -309,7 +328,7 @@ class WorkflowDraftVariableService: def delete_workflow_variables(self, app_id: str): ( self._session.query(WorkflowDraftVariable) - .filter(WorkflowDraftVariable.app_id == app_id) + .where(WorkflowDraftVariable.app_id == app_id) .delete(synchronize_session=False) ) @@ -360,7 +379,7 @@ class WorkflowDraftVariableService: if conv_id is not None: conversation = ( self._session.query(Conversation) - .filter( + .where( Conversation.id == conv_id, Conversation.app_id == workflow.app_id, ) diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 483c0d3086..e43999a8c9 100644 --- a/api/services/workflow_run_service.py +++ 
b/api/services/workflow_run_service.py @@ -2,9 +2,9 @@ import threading from collections.abc import Sequence from typing import Optional +from sqlalchemy.orm import sessionmaker + import contexts -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ( @@ -15,10 +15,18 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) -from models.workflow import WorkflowNodeExecutionTriggeredFrom +from repositories.factory import DifyAPIRepositoryFactory class WorkflowRunService: + def __init__(self): + """Initialize WorkflowRunService with repository dependencies.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -62,45 +70,16 @@ class WorkflowRunService: :param args: request args """ limit = int(args.get("limit", 20)) + last_id = args.get("last_id") - base_query = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + return self._workflow_run_repo.get_paginated_workflow_runs( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + limit=limit, + last_id=last_id, ) - if args.get("last_id"): - last_workflow_run = base_query.filter( - WorkflowRun.id == args.get("last_id"), - ).first() - - if not last_workflow_run: - raise ValueError("Last workflow run not exists") - - workflow_runs = ( - base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id - ) - .order_by(WorkflowRun.created_at.desc()) - .limit(limit) - .all() - ) - else: - workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() - - has_more = False - if len(workflow_runs) == limit: - current_page_first_workflow_run = workflow_runs[-1] - rest_count = base_query.filter( - WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id, - ).count() - - if rest_count > 0: - has_more = True - - return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail @@ -108,18 +87,12 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = ( - db.session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ) - .first() + return self._workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=run_id, ) - return workflow_run - def get_workflow_run_node_executions( self, app_model: App, @@ -137,17 +110,13 @@ class WorkflowRunService: if not workflow_run: return [] - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, - 
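
For context on what `get_paginated_workflow_runs` is expected to cover: the removed `-` lines above implemented `last_id` cursor pagination by hand. Below is a compact, in-memory restatement of that idea; the repository's actual internals are not shown in this diff, so this is only an assumed equivalent.

```python
# In-memory restatement of the "last_id" cursor pagination that the removed
# inline query implemented and that get_paginated_workflow_runs now wraps
# (assumed; the repository internals are not part of this diff).
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class WorkflowRun:
    id: str
    created_at: datetime


def paginate(runs: list[WorkflowRun], limit: int, last_id: str | None) -> tuple[list[WorkflowRun], bool]:
    ordered = sorted(runs, key=lambda r: r.created_at, reverse=True)
    if last_id is not None:
        last = next((r for r in ordered if r.id == last_id), None)
        if last is None:
            raise ValueError("Last workflow run not exists")
        # Keep only runs strictly older than the cursor row.
        ordered = [r for r in ordered if r.created_at < last.created_at and r.id != last.id]
    page = ordered[:limit]
    has_more = len(ordered) > limit  # anything left beyond this page?
    return page, has_more


now = datetime.now()
runs = [WorkflowRun(id=f"run-{i}", created_at=now - timedelta(minutes=i)) for i in range(5)]
page, has_more = paginate(runs, limit=2, last_id="run-1")
print([r.id for r in page], has_more)  # ['run-2', 'run-3'] True
```

Cursor pagination keyed on `created_at`/`id` keeps page boundaries stable while new runs are being inserted, which is why the service exposes `last_id` rather than a page number.
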
user=user, + # Get tenant_id from user + tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if tenant_id is None: + raise ValueError("User tenant_id cannot be None") + + return self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=tenant_id, app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=run_id, ) - - # Use the repository to get the database models directly - order_config = OrderConfig(order_by=["index"], order_direction="desc") - workflow_node_executions = repository.get_db_models_by_workflow_run( - workflow_run_id=run_id, order_config=order_config - ) - - return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2be57fd51c..e9f21fc5f1 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,23 +2,22 @@ import json import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable +from core.variables.variables import VariableUnion from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -28,10 +27,12 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings +from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -41,6 +42,7 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) +from repositories.factory import DifyAPIRepositoryFactory from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter @@ -57,26 +59,37 @@ class WorkflowService: Workflow Service """ + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize WorkflowService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, 
expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: - # TODO(QuantumGhost): This query is not fully covered by index. - criteria = ( - WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_id == workflow.id, - WorkflowNodeExecutionModel.node_id == node_id, + """ + Get the most recent execution for a specific node. + + Args: + app_model: The application model + workflow: The workflow model + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + return self._node_execution_service_repo.get_node_last_execution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=workflow.id, + node_id=node_id, ) - node_exec = ( - db.session.query(WorkflowNodeExecutionModel) - .filter(*criteria) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .first() - ) - return node_exec def is_workflow_exist(self, app_model: App) -> bool: return ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, @@ -91,7 +104,7 @@ class WorkflowService: # fetch draft workflow by app_model workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" ) .first() @@ -104,7 +117,7 @@ class WorkflowService: # fetch published workflow by workflow_id workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id, @@ -128,7 +141,7 @@ class WorkflowService: # fetch published workflow by workflow_id workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == app_model.workflow_id, @@ -219,7 +232,7 @@ class WorkflowService: workflow.graph = json.dumps(graph) workflow.features = json.dumps(features) workflow.updated_by = account.id - workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = naive_utc_now() workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables @@ -255,7 +268,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), + version=Workflow.version_from_datetime(naive_utc_now()), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, @@ -357,7 +370,7 @@ class WorkflowService: else: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -396,7 +409,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = SQLAlchemyWorkflowNodeExecutionRepository( + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -404,8 +417,9 @@ class WorkflowService: ) 
repository.save(node_execution) - # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution = repository.to_db_model(node_execution) + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + if workflow_node_execution is None: + raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( @@ -418,6 +432,7 @@ class WorkflowService: ) draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) session.commit() + return workflow_node_execution def run_free_workflow_node( @@ -429,7 +444,7 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( + node_execution = self._handle_node_run_result( invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, @@ -441,7 +456,7 @@ class WorkflowService: node_id=node_id, ) - return workflow_node_execution + return node_execution def _handle_node_run_result( self, @@ -450,10 +465,10 @@ class WorkflowService: node_id: str, ) -> WorkflowNodeExecution: try: - node_instance, generator = invoke_node_fn() + node, node_events = invoke_node_fn() node_run_result: NodeRunResult | None = None - for event in generator: + for event in node_events: if isinstance(event, RunCompletedEvent): node_run_result = event.run_result @@ -464,18 +479,18 @@ class WorkflowService: if not node_run_result: raise ValueError("Node run failed with no run result") # single step debug mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error: node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + "metadata": {"error_strategy": node.error_strategy}, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: node_run_result = NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node.default_value_dict, "error_message": node_run_result.error, "error_type": node_run_result.error_type, }, @@ -494,10 +509,10 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node_instance = e.node_instance + node = e._node run_succeeded = False node_run_result = None - error = e.error + error = e._error # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( @@ -505,11 +520,11 @@ class WorkflowService: workflow_id="", # This is a single-step execution, so no workflow ID index=1, node_id=node_id, - node_type=node_instance.node_type, - title=node_instance.node_data.title, + node_type=node.type_, + title=node.title, elapsed_time=time.perf_counter() - start_at, - created_at=datetime.now(UTC).replace(tzinfo=None), - finished_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), + finished_at=naive_utc_now(), ) if run_succeeded and node_run_result: @@ -606,7 +621,7 @@ class WorkflowService: setattr(workflow, field, value) workflow.updated_by = account_id - workflow.updated_at = 
datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = naive_utc_now() return workflow @@ -643,7 +658,7 @@ class WorkflowService: # Check if there's a tool provider using this specific workflow version tool_provider = ( session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.version == workflow.version, @@ -671,36 +686,30 @@ def _setup_variable_pool( ): # Only inject system variables for START node type. if node_type == NodeType.START: - # Create a variable pool. - system_inputs: dict[SystemVariableKey, Any] = { - # From inputs: - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - # From workflow model - SystemVariableKey.APP_ID: workflow.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - # Randomly generated. - SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), - } + system_variable = SystemVariable( + user_id=user_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + files=files or [], + workflow_execution_id=str(uuid.uuid4()), + ) # Only add chatflow-specific variables for non-workflow types if workflow.type != WorkflowType.WORKFLOW.value: - system_inputs.update( - { - SystemVariableKey.QUERY: query, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.DIALOGUE_COUNT: 0, - } - ) + system_variable.query = query + system_variable.conversation_id = conversation_id + system_variable.dialogue_count = 0 else: - system_inputs = {} + system_variable = SystemVariable.empty() # init variable pool variable_pool = VariablePool( - system_variables=system_inputs, + system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. 
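
`_setup_variable_pool` now builds a `SystemVariable` object instead of a `dict[SystemVariableKey, Any]`. The sketch below is a simplified stand-in for that container (field names mirror the diff; the real definition lives in `core.workflow.system_variable`), showing the workflow vs. chatflow split the code above encodes.

```python
# Simplified stand-in for the SystemVariable container that replaces the
# dict[SystemVariableKey, Any] mapping. Field names mirror the diff; the
# class body itself is an illustrative sketch, not Dify's definition.
import uuid
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class SystemVariable:
    user_id: Optional[str] = None
    app_id: Optional[str] = None
    workflow_id: Optional[str] = None
    workflow_execution_id: Optional[str] = None
    files: list[Any] = field(default_factory=list)
    query: Optional[str] = None
    conversation_id: Optional[str] = None
    dialogue_count: Optional[int] = None

    @classmethod
    def empty(cls) -> "SystemVariable":
        return cls()


# Workflow-type apps only populate the base fields...
sys_var = SystemVariable(
    user_id="user-1",
    app_id="app-1",
    workflow_id="wf-1",
    workflow_execution_id=str(uuid.uuid4()),
)
# ...while chatflow-type apps additionally set the conversational fields.
sys_var.query = "hello"
sys_var.conversation_id = "conv-1"
sys_var.dialogue_count = 0
print(sys_var)
```

A typed container makes the optional chatflow-only fields explicit, where the old dict relied on callers knowing which `SystemVariableKey` entries to include.
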
+ conversation_variables=cast(list[VariableUnion], conversation_variables), # ) return variable_pool diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 125e0c1b1e..d4fc68a084 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -25,13 +25,13 @@ class WorkspaceService: # Get role of user tenant_account_join = ( db.session.query(TenantAccountJoin) - .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .first() ) assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo + can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): base_url = dify_config.FILES_URL diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 75d648e1b7..204c1a4f5b 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -25,7 +25,7 @@ def add_document_to_index_task(dataset_document_id: str): logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() if not dataset_document: logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red")) db.session.close() @@ -43,7 +43,7 @@ def add_document_to_index_task(dataset_document_id: str): segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == False, DocumentSegment.status == "completed", @@ -86,12 +86,10 @@ def add_document_to_index_task(dataset_document_id: str): index_processor.load(dataset, documents) # delete auto disable log - db.session.query(DatasetAutoDisableLog).filter( - DatasetAutoDisableLog.document_id == dataset_document.id - ).delete() + db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() # update segment to enable - db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update( + db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( { DocumentSegment.enabled: True, DocumentSegment.disabled_at: None, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 6144a4fe3e..6d48f5df89 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -29,7 +29,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: start_at = time.perf_counter() indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # get app info - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if app: try: @@ -48,7 +48,7 @@ 
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: documents.append(document) # if annotation reply is enabled , batch add annotations' index app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) if app_annotation_setting: diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 747fce5784..5d5d1d3ad8 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -19,16 +19,14 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() + app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() if not app: logging.info(click.style("App not found: {}".format(app_id), fg="red")) db.session.close() return - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if not app_annotation_setting: logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red")) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index c04f1be845..12d10df442 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -30,14 +30,14 @@ def enable_annotation_reply_task( logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: logging.info(click.style("App not found: {}".format(app_id), fg="red")) db.session.close() return - annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() + annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) @@ -46,9 +46,7 @@ def enable_annotation_reply_task( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_provider_name, embedding_model_name, "annotation" ) - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: if dataset_collection_binding.id != 
annotation_setting.collection_binding_id: old_dataset_collection_binding = ( diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 97efc47b33..49bff72a96 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -27,12 +27,12 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -42,7 +42,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() try: if image_file and image_file.key: storage.delete(image_file.key) @@ -56,7 +56,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form db.session.commit() if file_ids: - files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() + files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() for file in files: try: storage.delete(file.key) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 51b6343fdc..64df3175e1 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -81,7 +81,7 @@ def batch_create_segment_to_index_task( segment_hash = helper.generate_text_hash(content) # type: ignore max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == dataset_document.id) + .where(DocumentSegment.document_id == dataset_document.id) .scalar() ) segment_document = DocumentSegment( diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 6bac718395..fad090141a 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -53,8 +53,8 @@ def clean_dataset_task( index_struct=index_struct, collection_binding_id=collection_binding_id, ) - documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() + documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() if documents is None or len(documents) == 0: logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) @@ -72,7 +72,7 @@ def clean_dataset_task( for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = 
db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if image_file is None: continue try: @@ -85,12 +85,12 @@ def clean_dataset_task( db.session.delete(image_file) db.session.delete(segment) - db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() - db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() - db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() + db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() + db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() + db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() # delete dataset metadata - db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == dataset_id).delete() - db.session.query(DatasetMetadataBinding).filter(DatasetMetadataBinding.dataset_id == dataset_id).delete() + db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() + db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() # delete files if documents: for document in documents: @@ -102,7 +102,7 @@ def clean_dataset_task( file_id = data_source_info["upload_file_id"] file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .first() ) if not file: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 5824121e8f..dd7a544ff5 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -28,12 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -43,7 +43,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if image_file is None: continue try: @@ -58,7 +58,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() if file_id: - file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if file: try: storage.delete(file.key) @@ -68,10 +68,11 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() # delete dataset metadata binding - db.session.query(DatasetMetadataBinding).filter( + db.session.query(DatasetMetadataBinding).where( DatasetMetadataBinding.dataset_id == dataset_id, DatasetMetadataBinding.document_id == 
document_id, ).delete() + db.session.commit() end_at = time.perf_counter() logging.info( diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 1087a37761..0f72f87f15 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -24,17 +24,17 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() db.session.delete(document) - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index a3f811faa1..5eda24674a 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -24,7 +24,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) db.session.close() @@ -37,11 +37,12 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] try: # update segment status to indexing - update_params = { - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } - db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update( + { + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) db.session.commit() document = Document( page_content=segment.content, @@ -74,11 +75,12 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] index_processor.load(dataset, [document]) # update segment to completed - update_params = { - DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } - db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update( + { + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) db.session.commit() end_at = time.perf_counter() diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 
a27207f2f1..7478bf5a90 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -35,7 +35,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): elif action == "add": dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -46,7 +46,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -56,7 +56,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): # add from vector index segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .order_by(DocumentSegment.position.asc()) .all() ) @@ -76,19 +76,19 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() elif action == "update": dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -100,7 +100,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -113,7 +113,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): try: segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .order_by(DocumentSegment.position.asc()) .all() ) @@ -148,12 +148,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, 
synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index 52c884ca29..d3b33e3052 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_account_task(account_id): - account = db.session.query(Account).filter(Account.id == account_id).first() + account = db.session.query(Account).where(Account.id == account_id).first() try: BillingService.delete_account(account_id) except Exception as e: diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index a93babc310..66ff0f9a0a 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -22,11 +22,11 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume logging.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: return - dataset_document = db.session.query(Document).filter(Document.id == document_id).first() + dataset_document = db.session.query(Document).where(Document.id == document_id).first() if not dataset_document: return diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 327eed4721..e67ba5c76e 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -21,7 +21,7 @@ def disable_segment_from_index_task(segment_id: str): logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) db.session.close() diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 8b77b290c8..0c8b1aabc7 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -23,13 +23,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) db.session.close() return - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) @@ -44,7 +44,7 @@ def disable_segments_from_index_task(segment_ids: list, 
dataset_id: str, documen segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, @@ -64,7 +64,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) except Exception: # update segment error msg - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index b4848be192..dcc748ef18 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -25,7 +25,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="red")) @@ -46,7 +46,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_edited_time = data_source_info["last_edited_time"] data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, DataSourceOauthBinding.provider == "notion", @@ -77,13 +77,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # delete all document segment and index try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 55cac6a9af..ec6d10d93b 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -1,4 +1,3 @@ -import datetime import logging import time @@ -8,6 +7,7 @@ from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -24,7 +24,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow")) db.session.close() @@ 
-48,12 +48,12 @@ def document_indexing_task(dataset_id: str, document_ids: list): except Exception as e: for document_id in document_ids: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() db.session.close() @@ -63,12 +63,12 @@ def document_indexing_task(dataset_id: str, document_ids: list): logging.info(click.style("Start process document: {}".format(document_id), fg="green")) document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 167b928f5d..e53c38ddc3 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -23,7 +23,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logging.info(click.style("Start update document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="red")) @@ -36,14 +36,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str): # delete all document segment and index try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index a6c93e110e..b3ddface59 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -25,7 +25,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset is None: logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red")) db.session.close() @@ -50,7 +50,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): 
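The dominant change across the task modules above is mechanical: every legacy `Query.filter()` call is respelled as `Query.where()`, which SQLAlchemy provides as a synonym aligned with the 2.0-style `select()` API, so behaviour is unchanged. Below is a minimal, self-contained sketch of the before/after pattern; the `Document` model and the in-memory engine are illustrative stand-ins, not the project's own models.

```python
# Illustrative sketch (not part of the patch): the .filter() -> .where()
# respelling used throughout these task modules. Query.where() is a synonym
# for Query.filter(), matching the 2.0-style select() API.
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):  # stand-in model, not models.dataset.Document
    __tablename__ = "documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # legacy spelling
    doc_old = session.query(Document).filter(Document.id == "doc-1").first()
    # 2.0-style spelling used after this change; identical query
    doc_new = session.query(Document).where(Document.id == "doc-1").first()
    assert doc_old == doc_new  # both None here; same result either way
```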
except Exception as e: for document_id in document_ids: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "error" @@ -66,7 +66,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): logging.info(click.style("Start process document: {}".format(document_id), fg="green")) document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: @@ -74,7 +74,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 21f08f40a7..13822f078e 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -24,7 +24,7 @@ def enable_segment_to_index_task(segment_id: str): logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) db.session.close() diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 625a3b582e..e3fdf04d8c 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -25,12 +25,12 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) return - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) @@ -45,7 +45,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, @@ -95,7 +95,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i except Exception as e: logging.exception("enable segments to index failed") # update segment error msg - db.session.query(DocumentSegment).filter( + 
db.session.query(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py index 0c60ae53d5..a6f8ce2f0b 100644 --- a/api/tasks/mail_account_deletion_task.py +++ b/api/tasks/mail_account_deletion_task.py @@ -3,14 +3,20 @@ import time import click from celery import shared_task # type: ignore -from flask import render_template from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service @shared_task(queue="mail") -def send_deletion_success_task(to): - """Send email to user regarding account deletion.""" +def send_deletion_success_task(to: str, language: str = "en-US") -> None: + """ + Send account deletion success email with internationalization support. + + Args: + to: Recipient email address + language: Language code for email localization + """ if not mail.is_inited(): return @@ -18,12 +24,16 @@ def send_deletion_success_task(to): start_at = time.perf_counter() try: - html_content = render_template( - "delete_account_success_template_en-US.html", + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.ACCOUNT_DELETION_SUCCESS, + language_code=language, to=to, - email=to, + template_context={ + "to": to, + "email": to, + }, ) - mail.send(to=to, subject="Your Dify.AI Account Has Been Successfully Deleted", html=html_content) end_at = time.perf_counter() logging.info( @@ -36,12 +46,14 @@ def send_deletion_success_task(to): @shared_task(queue="mail") -def send_account_deletion_verification_code(to, code): - """Send email to user regarding account deletion verification code. +def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US") -> None: + """ + Send account deletion verification code email with internationalization support. Args: - to (str): Recipient email address - code (str): Verification code + to: Recipient email address + code: Verification code + language: Language code for email localization """ if not mail.is_inited(): return @@ -50,8 +62,16 @@ def send_account_deletion_verification_code(to, code): start_at = time.perf_counter() try: - html_content = render_template("delete_account_code_email_template_en-US.html", to=to, code=code) - mail.send(to=to, subject="Dify.AI Account Deletion and Verification", html=html_content) + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.ACCOUNT_DELETION_VERIFICATION, + language_code=language, + to=to, + template_context={ + "to": to, + "code": code, + }, + ) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py new file mode 100644 index 0000000000..ea1875901c --- /dev/null +++ b/api/tasks/mail_change_mail_task.py @@ -0,0 +1,42 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from extensions.ext_mail import mail +from libs.email_i18n import get_email_i18n_service + + +@shared_task(queue="mail") +def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None: + """ + Send change email notification with internationalization support. 
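The mail tasks in this patch collapse their per-language `render_template` branching into one call shape: choose an `EmailType`, pass the language code, the recipient, and a `template_context` mapping, and let the i18n service pick the subject and template. A usage sketch assuming the `libs.email_i18n` module this patch relies on (it only runs inside the Dify API codebase; the recipient and code values are placeholders):

```python
# Usage sketch, assuming the Dify API codebase and the libs.email_i18n module
# referenced throughout this patch; recipient and code are placeholder values.
from libs.email_i18n import EmailType, get_email_i18n_service

email_service = get_email_i18n_service()
email_service.send_email(
    email_type=EmailType.EMAIL_CODE_LOGIN,  # selects the subject/template family
    language_code="zh-Hans",                # localization handled by the service
    to="user@example.com",
    template_context={
        "to": "user@example.com",
        "code": "123456",
    },
)
```

The surrounding hunks show the same service also exposing `send_change_email` and `send_raw_email`, which the change-mail and enterprise mail tasks use in place of their previous manual template handling.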
+ + Args: + language: Language code for email localization + to: Recipient email address + code: Email verification code + phase: Change email phase ('old_email' or 'new_email') + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_change_email( + language_code=language, + to=to, + code=code, + phase=phase, + ) + + end_at = time.perf_counter() + logging.info( + click.style("Send change email mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send change email mail to {} failed".format(to)) diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index ddad331725..34220784e9 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -3,19 +3,20 @@ import time import click from celery import shared_task # type: ignore -from flask import render_template from extensions.ext_mail import mail -from services.feature_service import FeatureService +from libs.email_i18n import EmailType, get_email_i18n_service @shared_task(queue="mail") -def send_email_code_login_mail_task(language: str, to: str, code: str): +def send_email_code_login_mail_task(language: str, to: str, code: str) -> None: """ - Async Send email code login mail - :param language: Language in which the email should be sent (e.g., 'en', 'zh') - :param to: Recipient email address - :param code: Email code to be included in the email + Send email code login email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + code: Email verification code """ if not mail.is_inited(): return @@ -23,28 +24,17 @@ def send_email_code_login_mail_task(language: str, to: str, code: str): logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) start_at = time.perf_counter() - # send email code login mail using different languages try: - if language == "zh-Hans": - template = "email_code_login_mail_template_zh-CN.html" - system_features = FeatureService.get_system_features() - if system_features.branding.enabled: - application_title = system_features.branding.application_title - template = "without-brand/email_code_login_mail_template_zh-CN.html" - html_content = render_template(template, to=to, code=code, application_title=application_title) - else: - html_content = render_template(template, to=to, code=code) - mail.send(to=to, subject="邮箱验证码", html=html_content) - else: - template = "email_code_login_mail_template_en-US.html" - system_features = FeatureService.get_system_features() - if system_features.branding.enabled: - application_title = system_features.branding.application_title - template = "without-brand/email_code_login_mail_template_en-US.html" - html_content = render_template(template, to=to, code=code, application_title=application_title) - else: - html_content = render_template(template, to=to, code=code) - mail.send(to=to, subject="Email Code", html=html_content) + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.EMAIL_CODE_LOGIN, + language_code=language, + to=to, + template_context={ + "to": to, + "code": code, + }, + ) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_enterprise_task.py index b9d8fd55df..a1c2908624 100644 --- 
a/api/tasks/mail_enterprise_task.py +++ b/api/tasks/mail_enterprise_task.py @@ -1,15 +1,17 @@ import logging import time +from collections.abc import Mapping import click from celery import shared_task # type: ignore from flask import render_template_string from extensions.ext_mail import mail +from libs.email_i18n import get_email_i18n_service @shared_task(queue="mail") -def send_enterprise_email_task(to, subject, body, substitutions): +def send_enterprise_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]): if not mail.is_inited(): return @@ -19,11 +21,8 @@ def send_enterprise_email_task(to, subject, body, substitutions): try: html_content = render_template_string(body, **substitutions) - if isinstance(to, list): - for t in to: - mail.send(to=t, subject=subject, html=html_content) - else: - mail.send(to=to, subject=subject, html=html_content) + email_service = get_email_i18n_service() + email_service.send_raw_email(to=to, subject=subject, html_content=html_content) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 7ca85c7f2d..8c73de0111 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -3,24 +3,23 @@ import time import click from celery import shared_task # type: ignore -from flask import render_template from configs import dify_config from extensions.ext_mail import mail -from services.feature_service import FeatureService +from libs.email_i18n import EmailType, get_email_i18n_service @shared_task(queue="mail") -def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): +def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str) -> None: """ - Async Send invite member mail - :param language - :param to - :param token - :param inviter_name - :param workspace_name + Send invite member email with internationalization support. 
- Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name) + Args: + language: Language code for email localization + to: Recipient email address + token: Invitation token + inviter_name: Name of the person sending the invitation + workspace_name: Name of the workspace """ if not mail.is_inited(): return @@ -30,49 +29,20 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam ) start_at = time.perf_counter() - # send invite member mail using different languages try: url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" - if language == "zh-Hans": - template = "invite_member_mail_template_zh-CN.html" - system_features = FeatureService.get_system_features() - if system_features.branding.enabled: - application_title = system_features.branding.application_title - template = "without-brand/invite_member_mail_template_zh-CN.html" - html_content = render_template( - template, - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url, - application_title=application_title, - ) - mail.send(to=to, subject=f"立即加入 {application_title} 工作空间", html=html_content) - else: - html_content = render_template( - template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url - ) - mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) - else: - template = "invite_member_mail_template_en-US.html" - system_features = FeatureService.get_system_features() - if system_features.branding.enabled: - application_title = system_features.branding.application_title - template = "without-brand/invite_member_mail_template_en-US.html" - html_content = render_template( - template, - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url, - application_title=application_title, - ) - mail.send(to=to, subject=f"Join {application_title} Workspace Now", html=html_content) - else: - html_content = render_template( - template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url - ) - mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.INVITE_MEMBER, + language_code=language, + to=to, + template_context={ + "to": to, + "inviter_name": inviter_name, + "workspace_name": workspace_name, + "url": url, + }, + ) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py new file mode 100644 index 0000000000..e566a6bc56 --- /dev/null +++ b/api/tasks/mail_owner_transfer_task.py @@ -0,0 +1,129 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + + +@shared_task(queue="mail") +def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str) -> None: + """ + Send owner transfer confirmation email with internationalization support. 
+ + Args: + language: Language code for email localization + to: Recipient email address + code: Verification code + workspace: Workspace name + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start owner transfer confirm mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.OWNER_TRANSFER_CONFIRM, + language_code=language, + to=to, + template_context={ + "to": to, + "code": code, + "WorkspaceName": workspace, + }, + ) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str) -> None: + """ + Send old owner transfer notification email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + workspace: Workspace name + new_owner_email: New owner email address + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start old owner transfer notify mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.OWNER_TRANSFER_OLD_NOTIFY, + language_code=language, + to=to, + template_context={ + "to": to, + "WorkspaceName": workspace, + "NewOwnerEmail": new_owner_email, + }, + ) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send old owner transfer notify mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("old owner transfer notify email mail to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str) -> None: + """ + Send new owner transfer notification email with internationalization support. 
+ + Args: + language: Language code for email localization + to: Recipient email address + workspace: Workspace name + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start new owner transfer notify mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY, + language_code=language, + to=to, + template_context={ + "to": to, + "WorkspaceName": workspace, + }, + ) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send new owner transfer notify mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("new owner transfer notify email mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index d4f4482a48..e2482f2101 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -3,19 +3,20 @@ import time import click from celery import shared_task # type: ignore -from flask import render_template from extensions.ext_mail import mail -from services.feature_service import FeatureService +from libs.email_i18n import EmailType, get_email_i18n_service @shared_task(queue="mail") -def send_reset_password_mail_task(language: str, to: str, code: str): +def send_reset_password_mail_task(language: str, to: str, code: str) -> None: """ - Async Send reset password mail - :param language: Language in which the email should be sent (e.g., 'en', 'zh') - :param to: Recipient email address - :param code: Reset password code + Send reset password email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + code: Reset password code """ if not mail.is_inited(): return @@ -23,30 +24,17 @@ def send_reset_password_mail_task(language: str, to: str, code: str): logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) start_at = time.perf_counter() - # send reset password mail using different languages try: - if language == "zh-Hans": - template = "reset_password_mail_template_zh-CN.html" - system_features = FeatureService.get_system_features() - if system_features.branding.enabled: - application_title = system_features.branding.application_title - template = "without-brand/reset_password_mail_template_zh-CN.html" - html_content = render_template(template, to=to, code=code, application_title=application_title) - mail.send(to=to, subject=f"设置您的 {application_title} 密码", html=html_content) - else: - html_content = render_template(template, to=to, code=code) - mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) - else: - template = "reset_password_mail_template_en-US.html" - system_features = FeatureService.get_system_features() - if system_features.branding.enabled: - application_title = system_features.branding.application_title - template = "without-brand/reset_password_mail_template_en-US.html" - html_content = render_template(template, to=to, code=code, application_title=application_title) - mail.send(to=to, subject=f"Set Your {application_title} Password", html=html_content) - else: - html_content = render_template(template, to=to, code=code) - mail.send(to=to, subject="Set Your Dify Password", html=html_content) + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD, + language_code=language, + to=to, + 
template_context={ + "to": to, + "code": code, + }, + ) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py new file mode 100644 index 0000000000..6fcdad0525 --- /dev/null +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -0,0 +1,166 @@ +import traceback +import typing + +import click +from celery import shared_task # type: ignore + +from core.helper import marketplace +from core.helper.marketplace import MarketplacePluginDeclaration +from core.plugin.entities.plugin import PluginInstallationSource +from core.plugin.impl.plugin import PluginInstaller +from models.account import TenantPluginAutoUpgradeStrategy + +RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 + + +cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {} + + +def marketplace_batch_fetch_plugin_manifests( + plugin_ids_plain_list: list[str], +) -> list[MarketplacePluginDeclaration]: + global cached_plugin_manifests + # return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list) + not_included_plugin_ids = [ + plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests + ] + if not_included_plugin_ids: + manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids) + for manifest in manifests: + cached_plugin_manifests[manifest.plugin_id] = manifest + + if ( + len(manifests) == 0 + ): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check + for plugin_id in not_included_plugin_ids: + cached_plugin_manifests[plugin_id] = None + + result: list[MarketplacePluginDeclaration] = [] + for plugin_id in plugin_ids_plain_list: + final_manifest = cached_plugin_manifests.get(plugin_id) + if final_manifest is not None: + result.append(final_manifest) + + return result + + +@shared_task(queue="plugin") +def process_tenant_plugin_autoupgrade_check_task( + tenant_id: str, + strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting, + upgrade_time_of_day: int, + upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode, + exclude_plugins: list[str], + include_plugins: list[str], +): + try: + manager = PluginInstaller() + + click.echo( + click.style( + "Checking upgradable plugin for tenant: {}".format(tenant_id), + fg="green", + ) + ) + + if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED: + return + + # get plugin_ids to check + plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier + click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green")) + + if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins: + all_plugins = manager.list_plugins(tenant_id) + + for plugin in all_plugins: + if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins: + plugin_ids.append( + ( + plugin.plugin_id, + plugin.version, + plugin.plugin_unique_identifier, + ) + ) + + elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE: + # get all plugins and remove excluded plugins + all_plugins = manager.list_plugins(tenant_id) + plugin_ids = [ + (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier) + for plugin in all_plugins + if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins + ] + elif upgrade_mode == 
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
+            all_plugins = manager.list_plugins(tenant_id)
+            plugin_ids = [
+                (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
+                for plugin in all_plugins
+                if plugin.source == PluginInstallationSource.Marketplace
+            ]
+
+            if not plugin_ids:
+                return
+
+            plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
+
+            manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
+
+            if not manifests:
+                return
+
+            for manifest in manifests:
+                for plugin_id, version, original_unique_identifier in plugin_ids:
+                    if manifest.plugin_id != plugin_id:
+                        continue
+
+                    try:
+                        current_version = version
+                        latest_version = manifest.latest_version
+
+                        def fix_only_checker(latest_version, current_version):
+                            latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
+                            current_version_tuple = tuple(int(val) for val in current_version.split("."))
+
+                            if (
+                                latest_version_tuple[0] == current_version_tuple[0]
+                                and latest_version_tuple[1] == current_version_tuple[1]
+                            ):
+                                return latest_version_tuple[2] != current_version_tuple[2]
+                            return False
+
+                        version_checker = {
+                            TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
+                            current_version: latest_version != current_version,
+                            TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
+                        }
+
+                        if version_checker[strategy_setting](latest_version, current_version):
+                            # execute upgrade
+                            new_unique_identifier = manifest.latest_package_identifier
+
+                            marketplace.record_install_plugin_event(new_unique_identifier)
+                            click.echo(
+                                click.style(
+                                    "Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier),
+                                    fg="green",
+                                )
+                            )
+                            task_start_resp = manager.upgrade_plugin(
+                                tenant_id,
+                                original_unique_identifier,
+                                new_unique_identifier,
+                                PluginInstallationSource.Marketplace,
+                                {
+                                    "plugin_unique_identifier": new_unique_identifier,
+                                },
+                            )
+                    except Exception as e:
+                        click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
+                        traceback.print_exc()
+                    break
+
+    except Exception as e:
+        click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
+        traceback.print_exc()
+        return
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index e7d49c78dc..dfb2389579 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -21,7 +21,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
     logging.info(click.style("Recover document: {}".format(document_id), fg="green"))
     start_at = time.perf_counter()

-    document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+    document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

     if not document:
         logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index d366efd6f2..1619f8c546 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -6,6 +6,7 @@ import click
 from celery import shared_task  # type: ignore
 from sqlalchemy import delete
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import sessionmaker

 from extensions.ext_database import db
 from models import (
@@ -13,6 +14,7 @@ from models import (
     AppAnnotationHitHistory,
     AppAnnotationSetting,
     AppDatasetJoin,
+    AppMCPServer,
     AppModelConfig,
     Conversation,
     EndUser,
@@ -30,7 +32,8 @@ from models import (
 )
 from models.tools import WorkflowToolProvider
 from models.web import PinnedConversation, SavedMessage
-from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
+from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
+from repositories.factory import DifyAPIRepositoryFactory


 @shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -41,6 +44,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
         # Delete related data
         _delete_app_model_configs(tenant_id, app_id)
         _delete_app_site(tenant_id, app_id)
+        _delete_app_mcp_servers(tenant_id, app_id)
         _delete_app_api_tokens(tenant_id, app_id)
         _delete_installed_apps(tenant_id, app_id)
         _delete_recommended_apps(tenant_id, app_id)
@@ -72,7 +76,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):

 def _delete_app_model_configs(tenant_id: str, app_id: str):
     def del_model_config(model_config_id: str):
-        db.session.query(AppModelConfig).filter(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
+        db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)

     _delete_records(
         """select id from app_model_configs where app_id=:app_id limit 1000""",
@@ -84,14 +88,26 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):

 def _delete_app_site(tenant_id: str, app_id: str):
     def del_site(site_id: str):
-        db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False)
+        db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)

     _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")


+def _delete_app_mcp_servers(tenant_id: str, app_id: str):
+    def del_mcp_server(mcp_server_id: str):
+        db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
+
+    _delete_records(
+        """select id from app_mcp_servers where app_id=:app_id limit 1000""",
+        {"app_id": app_id},
+        del_mcp_server,
+        "app mcp server",
+    )
+
+
 def _delete_app_api_tokens(tenant_id: str, app_id: str):
     def del_api_token(api_token_id: str):
-        db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)
+        db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)

     _delete_records(
         """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token"
@@ -100,7 +116,7 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):

 def _delete_installed_apps(tenant_id: str, app_id: str):
     def del_installed_app(installed_app_id: str):
-        db.session.query(InstalledApp).filter(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
+        db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)

     _delete_records(
         """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -112,7 +128,7 @@ def _delete_installed_apps(tenant_id: str, app_id: str):

 def _delete_recommended_apps(tenant_id: str, app_id: str):
     def del_recommended_app(recommended_app_id: str):
-        db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete(
+        db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
             synchronize_session=False
         )

@@ -126,9 +142,9 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):

 def _delete_app_annotation_data(tenant_id: str, app_id: str):
     def del_annotation_hit_history(annotation_hit_history_id: str):
-        db.session.query(AppAnnotationHitHistory).filter(
-            AppAnnotationHitHistory.id == annotation_hit_history_id
-        ).delete(synchronize_session=False)
+        db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
+            synchronize_session=False
+        )

     _delete_records(
         """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""",
@@ -138,7 +154,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
     )

     def del_annotation_setting(annotation_setting_id: str):
-        db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete(
+        db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
             synchronize_session=False
         )

@@ -152,7 +168,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):

 def _delete_app_dataset_joins(tenant_id: str, app_id: str):
     def del_dataset_join(dataset_join_id: str):
-        db.session.query(AppDatasetJoin).filter(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
+        db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)

     _delete_records(
         """select id from app_dataset_joins where app_id=:app_id limit 1000""",
@@ -164,7 +180,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):

 def _delete_app_workflows(tenant_id: str, app_id: str):
     def del_workflow(workflow_id: str):
-        db.session.query(Workflow).filter(Workflow.id == workflow_id).delete(synchronize_session=False)
+        db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)

     _delete_records(
         """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -175,34 +191,36 @@ def _delete_app_workflows(tenant_id: str, app_id: str):


 def _delete_app_workflow_runs(tenant_id: str, app_id: str):
-    def del_workflow_run(workflow_run_id: str):
-        db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False)
+    """Delete all workflow runs for an app using the service repository."""
+    session_maker = sessionmaker(bind=db.engine)
+    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

-    _delete_records(
-        """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
-        {"tenant_id": tenant_id, "app_id": app_id},
-        del_workflow_run,
-        "workflow run",
+    deleted_count = workflow_run_repo.delete_runs_by_app(
+        tenant_id=tenant_id,
+        app_id=app_id,
+        batch_size=1000,
     )
+    logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}")
+

 def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
-    def del_workflow_node_execution(workflow_node_execution_id: str):
-        db.session.query(WorkflowNodeExecutionModel).filter(
-            WorkflowNodeExecutionModel.id == workflow_node_execution_id
-        ).delete(synchronize_session=False)
+    """Delete all workflow node executions for an app using the service repository."""
+    session_maker = sessionmaker(bind=db.engine)
+    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)

-    _delete_records(
-        """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
-        {"tenant_id": tenant_id, "app_id": app_id},
-        del_workflow_node_execution,
-        "workflow node execution",
+    deleted_count = node_execution_repo.delete_executions_by_app(
+        tenant_id=tenant_id,
+        app_id=app_id,
+        batch_size=1000,
     )
+    logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}")
+

 def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
     def del_workflow_app_log(workflow_app_log_id: str):
-        db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(
+        db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
             synchronize_session=False
         )

@@ -216,10 +234,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):

 def _delete_app_conversations(tenant_id: str, app_id: str):
     def del_conversation(conversation_id: str):
-        db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete(
+        db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
             synchronize_session=False
         )
-        db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False)
+        db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)

     _delete_records(
         """select id from conversations where app_id=:app_id limit 1000""",
@@ -239,19 +257,19 @@ def _delete_conversation_variables(*, app_id: str):

 def _delete_app_messages(tenant_id: str, app_id: str):
     def del_message(message_id: str):
-        db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete(
+        db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
             synchronize_session=False
         )
-        db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete(
+        db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
             synchronize_session=False
         )
-        db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False)
-        db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete(
+        db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
+        db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
             synchronize_session=False
         )
-        db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False)
-        db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False)
-        db.session.query(Message).filter(Message.id == message_id).delete()
+        db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
+        db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
+        db.session.query(Message).where(Message.id == message_id).delete()

     _delete_records(
         """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message"
@@ -260,7 +278,7 @@ def _delete_app_messages(tenant_id: str, app_id: str):

 def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
     def del_tool_provider(tool_provider_id: str):
-        db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete(
+        db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
             synchronize_session=False
         )

@@ -274,7 +292,7 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):

 def _delete_app_tag_bindings(tenant_id: str, app_id: str):
     def del_tag_binding(tag_binding_id: str):
-        db.session.query(TagBinding).filter(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
+        db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)

     _delete_records(
         """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@@ -286,7 +304,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):

 def _delete_end_users(tenant_id: str, app_id: str):
     def del_end_user(end_user_id: str):
-        db.session.query(EndUser).filter(EndUser.id == end_user_id).delete(synchronize_session=False)
+        db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)

     _delete_records(
         """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -298,7 +316,7 @@ def _delete_end_users(tenant_id: str, app_id: str):

 def _delete_trace_app_configs(tenant_id: str, app_id: str):
     def del_trace_app_config(trace_app_config_id: str):
-        db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete(
+        db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
             synchronize_session=False
         )
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index 0e2960788d..3f73cc7b40 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -22,7 +22,7 @@ def remove_document_from_index_task(document_id: str):
     logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green"))
     start_at = time.perf_counter()

-    document = db.session.query(Document).filter(Document.id == document_id).first()
+    document = db.session.query(Document).where(Document.id == document_id).first()
     if not document:
         logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
         db.session.close()
@@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str):

         index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

-        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
+        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
             try:
@@ -51,7 +51,7 @@ def remove_document_from_index_task(document_id: str):
             except Exception:
                 logging.exception(f"clean dataset {dataset.id} from index failed")
         # update segment to disable
-        db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update(
+        db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
             {
                 DocumentSegment.enabled: False,
                 DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 8f8c3f9d81..58f0156afb 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -25,7 +25,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
     documents: list[Document] = []
     start_at = time.perf_counter()

-    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
     if not dataset:
         logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
         db.session.close()
@@ -45,7 +45,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
             )
         except Exception as e:
             document = (
-                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+                db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
             )
             if document:
                 document.indexing_status = "error"
@@ -59,7 +59,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
         logging.info(click.style("Start retry document: {}".format(document_id), fg="green"))
         document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
         )
         if not document:
             logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
@@ -69,7 +69,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
             # clean old data
             index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

-            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
             if segments:
                 index_node_ids = [segment.index_node_id for segment in segments]
                 # delete from vector index
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index dba0a39c2d..539c2db80f 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -24,7 +24,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     """
     start_at = time.perf_counter()

-    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
     if dataset is None:
         raise ValueError("Dataset not found")

@@ -41,7 +41,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
        )
     except Exception as e:
         document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
         )
         if document:
             document.indexing_status = "error"
@@ -53,7 +53,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
         return

     logging.info(click.style("Start sync website document: {}".format(document_id), fg="green"))
-    document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+    document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
     if not document:
         logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
         return
@@ -61,7 +61,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     # clean old data
     index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

-    segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+    segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
     if segments:
         index_node_ids = [segment.index_node_id for segment in segments]
         # delete from vector index
diff --git a/api/templates/change_mail_confirm_new_template_en-US.html b/api/templates/change_mail_confirm_new_template_en-US.html
new file mode 100644
index 0000000000..88721e787c
--- /dev/null
+++ b/api/templates/change_mail_confirm_new_template_en-US.html
@@ -0,0 +1,125 @@
[125-line HTML email template; layout markup not preserved here, visible copy only]
+Dify Logo
+Confirm Your New Email Address
+You’re updating the email address linked to your Dify account.
+To confirm this action, please use the verification code below.
+This code will only be valid for the next 5 minutes:
+{{code}}
+If you didn’t make this request, please ignore this email or contact support immediately.
diff --git a/api/templates/change_mail_confirm_new_template_zh-CN.html b/api/templates/change_mail_confirm_new_template_zh-CN.html
new file mode 100644
index 0000000000..25336ea1a1
--- /dev/null
+++ b/api/templates/change_mail_confirm_new_template_zh-CN.html
@@ -0,0 +1,125 @@
[125-line HTML email template, zh-CN variant of the new-address confirmation; visible copy only]
+Dify Logo
+确认您的邮箱地址变更
+您正在更新与您的 Dify 账户关联的邮箱地址。
+为了确认此操作,请使用以下验证码。
+此验证码仅在接下来的5分钟内有效:
+{{code}}
+如果您没有请求变更邮箱地址,请忽略此邮件或立即联系支持。
diff --git a/api/templates/change_mail_confirm_old_template_en-US.html b/api/templates/change_mail_confirm_old_template_en-US.html
new file mode 100644
index 0000000000..b20306aa87
--- /dev/null
+++ b/api/templates/change_mail_confirm_old_template_en-US.html
@@ -0,0 +1,125 @@
[125-line HTML email template sent to the old address; visible copy only]
+Dify Logo
+Verify Your Request to Change Email
+We received a request to change the email address associated with your Dify account.
+To confirm this action, please use the verification code below.
+This code will only be valid for the next 5 minutes:
+{{code}}
+If you didn’t make this request, please ignore this email or contact support immediately.
diff --git a/api/templates/change_mail_confirm_old_template_zh-CN.html b/api/templates/change_mail_confirm_old_template_zh-CN.html
new file mode 100644
index 0000000000..23c9e46652
--- /dev/null
+++ b/api/templates/change_mail_confirm_old_template_zh-CN.html
@@ -0,0 +1,124 @@
[124-line HTML email template, zh-CN variant sent to the old address; visible copy only]
+Dify Logo
+验证您的邮箱变更请求
+我们收到了一个变更您 Dify 账户关联邮箱地址的请求。
+此验证码仅在接下来的5分钟内有效:
+{{code}}
+如果您没有请求变更邮箱地址,请忽略此邮件或立即联系支持。
diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html
index 2d8f78b46a..97f3997c93 100644
--- a/api/templates/clean_document_job_mail_template-US.html
+++ b/api/templates/clean_document_job_mail_template-US.html
@@ -6,94 +6,135 @@ Documents Disabled Notification
-