diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a9580a3ba3..d684fe9144 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -8,13 +8,15 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report; otherwise it will be closed. required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :) required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true @@ -42,20 +44,22 @@ body: attributes: label: Steps to reproduce description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks. - placeholder: Having detailed steps helps us reproduce the bug. + placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them. validations: required: true - type: textarea attributes: label: ✔️ Expected Behavior - placeholder: What were you expecting? + description: Describe what you expected to happen. + placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here. validations: - required: false + required: true - type: textarea attributes: label: ❌ Actual Behavior - placeholder: What happened instead? + description: Describe what actually happened. + placeholder: What happened instead? Please do not copy and paste the steps to reproduce here. validations: required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 6877c382c4..c1666d24cf 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,11 @@ blank_issues_enabled: false contact_links: + - name: "\U0001F4A1 Model Providers & Plugins" + url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" + about: Report issues with official plugins or model providers. Please provide the plugin version and other relevant details. + - name: "\U0001F4AC Documentation Issues" + url: "https://github.com/langgenius/dify-docs/issues/new" + about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue. 
- name: "\U0001F4E7 Discussions" url: https://github.com/langgenius/dify/discussions/categories/general - about: General discussions and request help from the community + about: General discussions and seek help from the community diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml deleted file mode 100644 index 8fdbc0fb9a..0000000000 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: "📚 Documentation Issue" -description: Report issues in our documentation -labels: - - documentation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: textarea - attributes: - label: Provide a description of requested docs changes - placeholder: Briefly describe which document needs to be corrected and why. - validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index b1952c63a9..bd293e2442 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -8,11 +8,11 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml deleted file mode 100644 index f9c2dfb7d2..0000000000 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: "🌐 Localization/Translation issue" -description: Report incorrect translations. [please use English :)] -labels: - - translation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). 
- required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: input - attributes: - label: Dify version - description: Hover over system tray icon or look at Settings - validations: - required: true - - type: input - attributes: - label: Utility with translation issue - placeholder: Some area - description: Please input here the utility with the translation issue - validations: - required: true - - type: input - attributes: - label: 🌐 Language affected - placeholder: "German" - validations: - required: true - - type: textarea - attributes: - label: ❌ Actual phrase(s) - placeholder: What is there? Please include a screenshot as that is extremely helpful. - validations: - required: true - - type: textarea - attributes: - label: ✔️ Expected phrase(s) - placeholder: What was expected? - validations: - required: true - - type: textarea - attributes: - label: ℹ Why is the current translation wrong - placeholder: Why do you feel this is incorrect? - validations: - required: true diff --git a/README.md b/README.md index b8bd6d0725..2909e0e6cf 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
-The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: +The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: ```bash cd dify @@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Using Terraform for Deployment @@ -261,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media ## Security disclosure -To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. +To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with a detailed answer. ## License -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. +This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions. diff --git a/README_AR.md b/README_AR.md index d93bca8646..e959ca0f78 100644 --- a/README_AR.md +++ b/README_AR.md @@ -188,6 +188,7 @@ docker compose up -d - [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts) - [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### استخدام Terraform للتوزيع diff --git a/README_BN.md b/README_BN.md index 3efee3684d..29d7374ea5 100644 --- a/README_BN.md +++ b/README_BN.md @@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 নতুন! 
YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + #### টেরাফর্ম ব্যবহার করে ডিপ্লয় diff --git a/README_CN.md b/README_CN.md index 21e27429ec..486a368c09 100644 --- a/README_CN.md +++ b/README_CN.md @@ -194,9 +194,9 @@ docker compose up -d 如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 -#### 使用 Helm Chart 部署 +#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 -使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。 +使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) @@ -204,6 +204,10 @@ docker compose up -d - [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + + + #### 使用 Terraform 部署 使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 diff --git a/README_DE.md b/README_DE.md index 20c313035e..fce52c34c2 100644 --- a/README_DE.md +++ b/README_DE.md @@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform für die Bereitstellung verwenden diff --git a/README_ES.md b/README_ES.md index e4b7df6686..6fd6dfcee8 100644 --- a/README_ES.md +++ b/README_ES.md @@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop - [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts) - [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uso de Terraform para el despliegue diff --git a/README_FR.md b/README_FR.md index 8fd17fb7c3..b2209fb495 100644 --- a/README_FR.md +++ b/README_FR.md @@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau - [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts) - [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOUVEAU ! 
Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Utilisation de Terraform pour le déploiement diff --git a/README_JA.md b/README_JA.md index a3ee81e1f2..c658225f90 100644 --- a/README_JA.md +++ b/README_JA.md @@ -202,6 +202,7 @@ docker compose up -d - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraformを使用したデプロイ diff --git a/README_KL.md b/README_KL.md index 3e5ab1a74f..bfafcc7407 100644 --- a/README_KL.md +++ b/README_KL.md @@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform atorlugu pilersitsineq diff --git a/README_KR.md b/README_KR.md index 3c504900e1..282117e776 100644 --- a/README_KR.md +++ b/README_KR.md @@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform을 사용한 배포 diff --git a/README_PT.md b/README_PT.md index fb5f3662ae..576f6b48f7 100644 --- a/README_PT.md +++ b/README_PT.md @@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts] - [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts) - [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Usando o Terraform para Implantação diff --git a/README_SI.md b/README_SI.md index 647069a220..7ded001d86 100644 --- a/README_SI.md +++ b/README_SI.md @@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases. - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uporaba Terraform za uvajanje diff --git a/README_TR.md b/README_TR.md index f52335646a..6e94e54fa0 100644 --- a/README_TR.md +++ b/README_TR.md @@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' - [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes) - [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s) +- [🚀 YENİ! 
YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Dağıtım için Terraform Kullanımı diff --git a/README_TW.md b/README_TW.md index 71082ff893..6e3e22b5c1 100644 --- a/README_TW.md +++ b/README_TW.md @@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 -如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。 +如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 - [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify) - [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes) - [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) ### 使用 Terraform 進行部署 diff --git a/README_VI.md b/README_VI.md index 58d8434fff..51314e6de5 100644 --- a/README_VI.md +++ b/README_VI.md @@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có - [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Sử dụng Terraform để Triển khai diff --git a/api/.env.example b/api/.env.example index a7ea6cf937..3fe95c44b5 100644 --- a/api/.env.example +++ b/api/.env.example @@ -449,6 +449,19 @@ MAX_VARIABLE_SIZE=204800 # hybrid: Save new data to object storage, read from both object storage and RDBMS WORKFLOW_NODE_EXECUTION_STORAGE=rdbms +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 @@ -482,6 +495,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 +OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 CREATE_TIDB_SERVICE_JOB_ENABLED=false @@ -492,6 +507,8 @@ LOGIN_LOCKOUT_DURATION=86400 # Enable OpenTelemetry ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= OTEL_EXPORTER_OTLP_PROTOCOL= diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 963fcbedf9..f1d529355d 100644 --- 
a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -31,6 +31,15 @@ class SecurityConfig(BaseSettings): description="Duration in minutes for which a password reset token remains valid", default=5, ) + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a change email token remains valid", + default=5, + ) + + OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which an owner transfer token remains valid", + default=5, + ) LOGIN_DISABLED: bool = Field( description="Whether to disable login checks", @@ -537,6 +546,33 @@ class WorkflowNodeExecutionConfig(BaseSettings): ) +class RepositoryConfig(BaseSettings): + """ + Configuration for repository implementations + """ + + CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", + ) + + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " + "Specify as a module path", + default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_RUN_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path", + default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository", + ) + + class AuthConfig(BaseSettings): """ Configuration for authentication and OAuth @@ -587,6 +623,16 @@ class AuthConfig(BaseSettings): default=86400, ) + CHANGE_EMAIL_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying change email after exceeding the rate limit.", + default=86400, + ) + + OWNER_TRANSFER_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying owner transfer after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ @@ -903,6 +949,7 @@ class FeatureConfig( MultiModalTransferConfig, PositionConfig, RagEtlConfig, + RepositoryConfig, SecurityConfig, ToolConfig, UpdateConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 427602676f..0c0c06dd46 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -162,6 +162,11 @@ class DatabaseConfig(BaseSettings): default=3600, ) + SQLALCHEMY_POOL_USE_LIFO: bool = Field( + description="If True, SQLAlchemy will use a last-in-first-out (LIFO) strategy to retrieve connections from the pool.", + default=False, + ) + SQLALCHEMY_POOL_PRE_PING: bool = Field( description="If True, enables connection pool pre-ping feature to check connections.", default=False, @@ -199,6 +204,7 @@ class DatabaseConfig(BaseSettings): "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, "connect_args": connect_args, + "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO, } diff --git a/api/configs/observability/otel/otel_config.py 
b/api/configs/observability/otel/otel_config.py index 1b88ddcfe6..7572a696ce 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -12,6 +12,16 @@ class OTelConfig(BaseSettings): default=False, ) + OTLP_TRACE_ENDPOINT: str = Field( + description="OTLP trace endpoint", + default="", + ) + + OTLP_METRIC_ENDPOINT: str = Field( + description="OTLP metric endpoint", + default="", + ) + OTLP_BASE_ENDPOINT: str = Field( description="OTLP base endpoint", default="http://localhost:4318", diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 860166a61a..9fe32dde6d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -151,6 +151,7 @@ class AppApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 0f53860f56..503393f264 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -35,16 +35,20 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_fields) def post(self, app_model): - # The role of the current user in the ta table must be editor, admin, or owner if not current_user.is_editor: raise NotFound() parser = reqparse.RequestParser() - parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") args = parser.parse_args() + + description = args.get("description") + if not description: + description = app_model.description or "" + server = AppMCPServer( name=app_model.name, - description=args["description"], + description=description, parameters=json.dumps(args["parameters"], ensure_ascii=False), status=AppMCPServerStatus.ACTIVE, app_id=app_model.id, @@ -65,14 +69,22 @@ class AppMCPServerController(Resource): raise NotFound() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, location="json") - parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("status", type=str, required=False, location="json") args = parser.parse_args() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() if not server: raise NotFound() - server.description = args["description"] + + description = args.get("description") + if description is None: + pass + elif not description: + server.description = app_model.description or "" + else: + server.description = description + server.parameters = json.dumps(args["parameters"], ensure_ascii=False) if args["status"]: if args["status"] not in [status.value for status in AppMCPServerStatus]: diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 86aed77412..32b64d10c5 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -2,6 +2,7 @@ from datetime import datetime from 
decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user from flask_restful import Resource, reqparse @@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode +from models import AppMode, Message class DailyMessageStatistic(Resource): @@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource): parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT - DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - COUNT(DISTINCT messages.conversation_id) AS conversation_count -FROM - messages -WHERE - app_id = :app_id""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc + stmt = ( + sa.select( + sa.func.date( + sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz")) + ).label("date"), + sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), + ) + .select_from(Message) + .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) + ) + if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - - sql_query += " AND created_at >= :start" - arg_dict["start"] = start_datetime_utc + stmt = stmt.where(Message.created_at >= start_datetime_utc) if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + stmt = stmt.where(Message.created_at < end_datetime_utc) - sql_query += " AND created_at < :end" - arg_dict["end"] = end_datetime_utc - - sql_query += " GROUP BY date ORDER BY date" + stmt = stmt.group_by("date").order_by("date") response_data = [] - with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) - for i in rs: - response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) + rs = conn.execute(stmt, {"tz": account.timezone}) + for row in rs: + response_data.append({"date": str(row.date), "conversation_count": row.conversation_count}) return jsonify({"data": response_data}) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 00d6fa3cbf..ba93f82756 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -68,13 +68,18 @@ def _create_pagination_parser(): return parser +def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: + value_type = workflow_draft_var.value_type + return value_type.exposed_type().value + + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda model: model.get_variable_type()), 
"name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.value, + "value_type": v.value_type.exposed_type().value, "value": v.value, # Do not track edited for env vars. "edited": False, diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 03b60610aa..3322350e25 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ raise AppNotFoundError() app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.CHANNEL: - raise AppNotFoundError() if mode is not None: if isinstance(mode, list): diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index b40934dbf5..8c5e23de58 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,7 +27,19 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Too many password reset emails have been sent. Please try again in 1 minutes." + description = "Too many password reset emails have been sent. Please try again in 1 minute." + code = 429 + + +class EmailChangeRateLimitExceededError(BaseHTTPException): + error_code = "email_change_rate_limit_exceeded" + description = "Too many email change emails have been sent. Please try again in 1 minute." + code = 429 + + +class OwnerTransferRateLimitExceededError(BaseHTTPException): + error_code = "owner_transfer_rate_limit_exceeded" + description = "Too many owner transfer emails have been sent. Please try again in 1 minute." code = 429 @@ -65,3 +77,39 @@ class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" description = "Too many failed password reset attempts. Please try again in 24 hours." code = 429 + + +class EmailChangeLimitError(BaseHTTPException): + error_code = "email_change_limit" + description = "Too many failed email change attempts. Please try again in 24 hours." + code = 429 + + +class EmailAlreadyInUseError(BaseHTTPException): + error_code = "email_already_in_use" + description = "A user with this email already exists." + code = 400 + + +class OwnerTransferLimitError(BaseHTTPException): + error_code = "owner_transfer_limit" + description = "Too many failed owner transfer attempts. Please try again in 24 hours." + code = 429 + + +class NotOwnerError(BaseHTTPException): + error_code = "not_owner" + description = "You are not the owner of the workspace." 
+ code = 400 + + +class CannotTransferOwnerToSelfError(BaseHTTPException): + error_code = "cannot_transfer_owner_to_self" + description = "You cannot transfer ownership to yourself." + code = 400 + + +class MemberNotInTenantError(BaseHTTPException): + error_code = "member_not_in_tenant" + description = "The member is not in the workspace." + code = 400 diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index b80c4402cd..1fa6057a39 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = "high_quality_dataset_only" - description = "Current operation only supports 'high-quality' datasets." - code = 400 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index a9dbf44456..1f22e3fd01 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -4,10 +4,20 @@ import pytz from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language from controllers.console import api +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailChangeLimitError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, +) +from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, @@ -18,15 +28,17 @@ from controllers.console.workspace.error import ( from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_enabled, + enable_change_email, enterprise_license_required, only_edition_cloud, setup_required, ) from extensions.ext_database import db from fields.member_fields import account_fields -from libs.helper import TimestampField, timezone +from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.login import login_required from models import AccountIntegrate, InvitationCode +from models.account import Account from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -369,6 +381,134 @@ class EducationAutoCompleteApi(Resource): return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) +class ChangeEmailSendEmailApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + parser.add_argument("phase", type=str, required=False, location="json") + parser.add_argument("token", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if 
AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + account = None + user_email = args["email"] + if args["phase"] is not None and args["phase"] == "new_email": + if args["token"] is None: + raise InvalidTokenError() + + reset_data = AccountService.get_change_email_data(args["token"]) + if reset_data is None: + raise InvalidTokenError() + user_email = reset_data.get("email", "") + + if user_email != current_user.email: + raise InvalidEmailError() + else: + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + if account is None: + raise AccountNotFound() + + token = AccountService.send_change_email_email( + account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"] + ) + return {"result": "success", "data": token} + + +class ChangeEmailCheckApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"]) + if is_change_email_error_rate_limit: + raise EmailChangeLimitError() + + token_data = AccountService.get_change_email_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_change_email_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_change_email_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_change_email_token( + user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={} + ) + + AccountService.reset_change_email_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class ChangeEmailResetApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("new_email", type=email, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + reset_data = AccountService.get_change_email_data(args["token"]) + if not reset_data: + raise InvalidTokenError() + + AccountService.revoke_change_email_token(args["token"]) + + if not AccountService.check_email_unique(args["new_email"]): + raise EmailAlreadyInUseError() + + old_email = reset_data.get("old_email", "") + if current_user.email != old_email: + raise AccountNotFound() + + updated_account = AccountService.update_account(current_user, email=args["new_email"]) + + return updated_account + + +class CheckEmailUnique(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, 
location="json") + args = parser.parse_args() + if not AccountService.check_email_unique(args["email"]): + raise EmailAlreadyInUseError() + return {"result": "success"} + + # Register API resources api.add_resource(AccountInitApi, "/account/init") api.add_resource(AccountProfileApi, "/account/profile") @@ -385,5 +525,10 @@ api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") api.add_resource(EducationVerifyApi, "/account/education/verify") api.add_resource(EducationApi, "/account/education") api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") +# Change email +api.add_resource(ChangeEmailSendEmailApi, "/account/change-email") +api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity") +api.add_resource(ChangeEmailResetApi, "/account/change-email/reset") +api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 8b70ca62b9..4427d1ff72 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -13,12 +13,6 @@ class CurrentPasswordIncorrectError(BaseHTTPException): code = 400 -class ProviderRequestFailedError(BaseHTTPException): - error_code = "provider_request_failed" - description = None - code = 400 - - class InvalidInvitationCodeError(BaseHTTPException): error_code = "invalid_invitation_code" description = "Invalid invitation code." diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 48225ac90d..b1f79ffdec 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,22 +1,34 @@ from urllib import parse +from flask import request from flask_login import current_user from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api -from controllers.console.error import WorkspaceMembersLimitExceeded +from controllers.console.auth.error import ( + CannotTransferOwnerToSelfError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, + MemberNotInTenantError, + NotOwnerError, + OwnerTransferLimitError, +) +from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + is_allow_transfer_owner, setup_required, ) from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields +from libs.helper import extract_remote_ip from libs.login import login_required from models.account import Account, TenantAccountRole -from services.account_service import RegisterService, TenantService +from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService @@ -156,8 +168,146 @@ class DatasetOperatorMemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +class SendOwnerTransferEmailApi(Resource): + """Send owner transfer email.""" + + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("language", type=str, 
required=False, location="json") + args = parser.parse_args() + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + email = current_user.email + + token = AccountService.send_owner_transfer_email( + account=current_user, + email=email, + language=language, + workspace_name=current_user.current_tenant.name, + ) + + return {"result": "success", "data": token} + + +class OwnerTransferCheckApi(Resource): + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + user_email = current_user.email + + is_owner_transfer_error_rate_limit = AccountService.is_owner_transfer_error_rate_limit(user_email) + if is_owner_transfer_error_rate_limit: + raise OwnerTransferLimitError() + + token_data = AccountService.get_owner_transfer_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_owner_transfer_error_rate_limit(user_email) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_owner_transfer_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={}) + + AccountService.reset_owner_transfer_error_rate_limit(user_email) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class OwnerTransfer(Resource): + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self, member_id): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + if current_user.id == str(member_id): + raise CannotTransferOwnerToSelfError() + + transfer_token_data = AccountService.get_owner_transfer_data(args["token"]) + if not transfer_token_data: + raise InvalidTokenError() + + if transfer_token_data.get("email") != current_user.email: + raise InvalidEmailError() + + AccountService.revoke_owner_transfer_token(args["token"]) + + member = db.session.get(Account, str(member_id)) + if not member: + abort(404) + else: + member_account = member + if not TenantService.is_member(member_account, current_user.current_tenant): + raise MemberNotInTenantError() + + try: + assert member is not None, "Member not found" + TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user) + + AccountService.send_new_owner_transfer_notify_email( + account=member, + 
email=member.email, + workspace_name=current_user.current_tenant.name, + ) + + AccountService.send_old_owner_transfer_notify_email( + account=current_user, + email=current_user.email, + workspace_name=current_user.current_tenant.name, + new_owner_email=member.email, + ) + + except Exception as e: + raise ValueError(str(e)) + + return {"result": "success"} + + api.add_resource(MemberListApi, "/workspaces/current/members") api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") +# owner transfer +api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email") +api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check") +api.add_resource(OwnerTransfer, "/workspaces/current/members//owner-transfer") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index ca122772de..d862dac373 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -235,3 +235,29 @@ def email_password_login_enabled(view): abort(403) return decorated + + +def enable_change_email(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_change_email: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated + + +def is_allow_transfer_owner(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_features(current_user.current_tenant_id) + if features.is_allow_transfer_workspace: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index efb4acc5fb..ac2ebf2b09 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -3,7 +3,7 @@ import logging from dateutil.parser import isoparse from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser -from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService @@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource): if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: raise NotWorkflowAppError() - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + # Use repository to get workflow run + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + workflow_run = workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + 
run_id=workflow_run_id, + ) return workflow_run diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index 5ff5e08c72..ecc47b40a1 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = "high_quality_dataset_only" - description = "Current operation only supports 'high-quality' datasets." - code = 400 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0d304de97a..28bf4a9a23 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -3,6 +3,8 @@ import logging import uuid from typing import Optional, Union, cast +from sqlalchemy import select + from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig @@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner): if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = ( - db.session.query(Message) - .filter( - Message.conversation_id == self.message.conversation_id, + messages = ( + ( + db.session.execute( + select(Message) + .where(Message.conversation_id == self.message.conversation_id) + .order_by(Message.created_at.desc()) + ) ) - .order_by(Message.created_at.desc()) + .scalars() .all() ) diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 3b48288710..a3438fc2c7 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value # deprecated, should not use. 
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 7877408cef..4b8f5ebe27 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, 
app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 840a3c9d3b..af15324f46 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -16,9 +16,10 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.moderation.base import ModerationError +from core.variables.variables import VariableUnion from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not workflow: raise ValueError("Workflow not initialized") - user_id = None + user_id: str | None = None if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: @@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() # Create a variable pool. - system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: self.conversation.id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, - } + system_inputs = SystemVariable( + query=query, + files=files, + conversation_id=self.conversation.id, + user_id=user_id, + dialogue_count=self._dialogue_count, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_run_id, + ) # init variable pool variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. 
+ conversation_variables=cast(list[VariableUnion], conversation_variables), ) # init graph diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 4c52fc3e83..1dc9796d5b 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db @@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, - }, + workflow_system_variables=SystemVariable( + query=message.query, + files=application_generate_entity.files, + conversation_id=conversation.id, + user_id=user_session_id, + dialogue_count=dialogue_count, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_run_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index a3f0cf7f9f..428db607fa 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -38,69 +38,6 @@ _logger = logging.getLogger(__name__) class AppRunner: - def get_pre_calculate_rest_tokens( - self, - app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: Mapping[str, str], - files: Sequence["File"], - query: Optional[str] = None, - ) -> int: - """ - Get pre calculate rest tokens - :param app_record: app record - :param model_config: model config entity - :param prompt_template_entity: prompt template entity - :param inputs: inputs - :param files: files - :param query: query - :return: - """ - # Invoke model - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - - model_context_tokens = 
model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - - max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") - ) or 0 - - if model_context_tokens is None: - return -1 - - if max_tokens is None: - max_tokens = 0 - - # get prompt messages without memory and context - prompt_messages, stop = self.organize_prompt_messages( - app_record=app_record, - model_config=model_config, - prompt_template_entity=prompt_template_entity, - inputs=inputs, - files=files, - query=query, - ) - - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens - if rest_tokens < 0: - raise InvokeBadRequestError( - "Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size." - ) - - return rest_tokens - def recalc_llm_max_tokens( self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] ): diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 40a1e272a7..2f9632e97d 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + 
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 07aeb57fa3..3a66ffa578 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. 
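
Note on the variable-pool changes: the advanced-chat hunks above and the workflow runner hunk that follows replace the ad-hoc dict keyed by `SystemVariableKey` with a typed `SystemVariable` object, which presumably validates the system-input fields at construction time. A minimal sketch of the new construction (placeholder values, field names as they appear in the hunks):

system_inputs = SystemVariable(
    files=files,
    user_id=user_id,
    app_id=app_config.app_id,
    workflow_id=app_config.workflow_id,
    workflow_execution_id=workflow_execution_id,
)
variable_pool = VariablePool(
    system_variables=system_inputs,   # previously a dict keyed by SystemVariableKey
    user_inputs=inputs,
    environment_variables=workflow.environment_variables,
)
# When there are no system inputs, the later workflow_app_runner.py hunk uses:
#   VariablePool(system_variables=SystemVariable.empty(), user_inputs={}, ...)
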
- system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, - } + + system_inputs = SystemVariable( + files=files, + user_id=user_id, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2a85cd5e3d..7adc03e9c3 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,7 +3,6 @@ import time from collections.abc import Generator from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -55,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models.account import Account @@ -68,7 +67,6 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowRun, ) logger = logging.getLogger(__name__) @@ -109,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, - }, + workflow_system_variables=SystemVariable( + files=application_generate_entity.files, + user_id=user_session_id, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_execution_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), @@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline: tts_publisher.publish(None) def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: - workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) - assert workflow_run is not None invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = 
WorkflowAppLogCreatedFrom.SERVICE_API @@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline: return workflow_app_log = WorkflowAppLog() - workflow_app_log.tenant_id = workflow_run.tenant_id - workflow_app_log.app_id = workflow_run.app_id - workflow_app_log.workflow_id = workflow_run.workflow_id - workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id + workflow_app_log.app_id = self._application_generate_entity.app_config.app_id + workflow_app_log.workflow_id = workflow_execution.workflow_id + workflow_app_log.workflow_run_id = workflow_execution.id_ workflow_app_log.created_from = created_from.value workflow_app_log.created_by_role = self._created_by_role workflow_app_log.created_by = self._user_id diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 17b9ac5827..2f4d234ecd 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) @@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py index e4b4168d08..df62776977 100644 --- a/api/core/app/task_pipeline/exc.py +++ b/api/core/app/task_pipeline/exc.py @@ -10,8 +10,3 @@ class RecordNotFoundError(TaskPipilineError): class WorkflowRunNotFoundError(RecordNotFoundError): def __init__(self, workflow_run_id: str): super().__init__("WorkflowRun", workflow_run_id) - - -class WorkflowNodeExecutionNotFoundError(RecordNotFoundError): - def __init__(self, workflow_node_execution_id: str): - super().__init__("WorkflowNodeExecution", workflow_node_execution_id) diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 2fa347c204..fbd62437e6 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -14,6 +14,7 @@ class CommonParameterType(StrEnum): APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" + ANY = "any" # Dynamic select parameter # Once you are not sure about the available options until authorization is done diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 656c9d48ed..fac68beb0f 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -7,13 +7,6 @@ if TYPE_CHECKING: _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None -class ToolFileParser: - @staticmethod - def get_tool_file_manager() -> "ToolFileManager": - assert 
_tool_file_manager_factory is not None - return _tool_file_manager_factory() - - def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: global _tool_file_manager_factory _tool_file_manager_factory = factory diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 84f212a9c1..b416e48ce4 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,6 +5,8 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any +from core.variables.utils import SegmentJSONEncoder + class TemplateTransformer(ABC): _code_placeholder: str = "{{code}}" @@ -43,17 +45,13 @@ class TemplateTransformer(ABC): result_str = cls.extract_result_str_from_response(response) result = json.loads(result_str) except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...") + raise ValueError(f"Failed to parse JSON response: {str(e)}.") except ValueError as e: # Re-raise ValueError from extract_result_str_from_response raise e except Exception as e: raise ValueError(f"Unexpected error during response transformation: {str(e)}") - # Check if the result contains an error - if isinstance(result, dict) and "error" in result: - raise ValueError(f"JavaScript execution error: {result['error']}") - if not isinstance(result, dict): raise ValueError(f"Result must be a dict, got {type(result).__name__}") if not all(isinstance(k, str) for k in result): @@ -95,7 +93,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() + inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py deleted file mode 100644 index dfb143f4c4..0000000000 --- a/api/core/helper/url_signer.py +++ /dev/null @@ -1,52 +0,0 @@ -import base64 -import hashlib -import hmac -import os -import time - -from pydantic import BaseModel, Field - -from configs import dify_config - - -class SignedUrlParams(BaseModel): - sign_key: str = Field(..., description="The sign key") - timestamp: str = Field(..., description="Timestamp") - nonce: str = Field(..., description="Nonce") - sign: str = Field(..., description="Signature") - - -class UrlSigner: - @classmethod - def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: - signed_url_params = cls.get_signed_url_params(sign_key, prefix) - return ( - f"{url}?timestamp={signed_url_params.timestamp}" - f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" - ) - - @classmethod - def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - sign = cls._sign(sign_key, timestamp, nonce, prefix) - - return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) - - @classmethod - def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: - recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) - - return sign == recalculated_sign - - @classmethod - def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: - if not dify_config.SECRET_KEY: - raise 
Exception("SECRET_KEY is not set") - - data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return encoded_sign diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e01896a491..f7fd93be4a 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -148,9 +148,11 @@ class LLMGenerator: model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( + model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index b63478e822..bcb31a816f 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -240,7 +240,7 @@ def refresh_authorization( response = requests.post(token_url, data=params) if not response.ok: raise ValueError(f"Token refresh failed: HTTP {response.status_code}") - return OAuthTokens.parse_obj(response.json()) + return OAuthTokens.model_validate(response.json()) def register_client( diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 1c2cf570e2..20ff7e7524 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -148,9 +148,7 @@ class MCPServerStreamableHTTPRequestHandler: if not self.end_user: raise ValueError("User not found") request = cast(types.CallToolRequest, self.request.root) - args = request.params.arguments - if not args: - raise ValueError("No arguments provided") + args = request.params.arguments or {} if self.app.mode in {AppMode.WORKFLOW.value}: args = {"inputs": args} elif self.app.mode in {AppMode.COMPLETION.value}: diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 1c0f582501..7734b8fdd9 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -1,7 +1,7 @@ import logging import queue from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from contextlib import ExitStack from datetime import timedelta from types import TracebackType @@ -171,23 +171,41 @@ class BaseSession( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._exit_stack = ExitStack() + # Initialize executor and future to None for proper cleanup checks + self._executor: ThreadPoolExecutor | None = None + self._receiver_future: Future | None = None def __enter__(self) -> Self: - self._executor = ThreadPoolExecutor() + # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1 + # ensures no unnecessary threads are created. 
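
Note on the BaseSession changes above and below: the receive loop is pinned to a single-worker executor, and failures in the loop are surfaced through the stored Future. A standalone sketch of that lifecycle, with `_receive_loop` as a stand-in body:

from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError

class SketchSession:
    def __init__(self):
        # Initialized to None so __exit__ can run safely even if __enter__ never did.
        self._executor: ThreadPoolExecutor | None = None
        self._receiver_future: Future | None = None

    def __enter__(self):
        # One worker only: the pool exists solely to run the receive loop.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._receiver_future = self._executor.submit(self._receive_loop)
        return self

    def check_receiver_status(self) -> None:
        # done() + result() re-raises any exception the loop died with.
        if self._receiver_future and self._receiver_future.done():
            self._receiver_future.result()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._receiver_future:
            try:
                self._receiver_future.result(timeout=5.0)
            except TimeoutError:
                pass  # loop still running; fall through to shutdown
        if self._executor:
            self._executor.shutdown(wait=True)

    def _receive_loop(self):
        ...  # placeholder
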
+ self._executor = ThreadPoolExecutor(max_workers=1) self._receiver_future = self._executor.submit(self._receive_loop) return self def check_receiver_status(self) -> None: - if self._receiver_future.done(): + """`check_receiver_status` ensures that any exceptions raised during the + execution of `_receive_loop` are retrieved and propagated.""" + if self._receiver_future and self._receiver_future.done(): self._receiver_future.result() def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: - self._exit_stack.close() self._read_stream.put(None) self._write_stream.put(None) + # Wait for the receiver loop to finish + if self._receiver_future: + try: + self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds + except TimeoutError: + # If the receiver loop is still running after timeout, we'll force shutdown + pass + + # Shutdown the executor + if self._executor: + self._executor.shutdown(wait=True) + def send_request( self, request: SendRequestT, diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2254b3d4d5..a9f0a92e5d 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,6 +1,8 @@ from collections.abc import Sequence from typing import Optional +from sqlalchemy import select + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file import file_manager from core.model_manager import ModelInstance @@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile -from models.workflow import WorkflowRun +from models.workflow import Workflow, WorkflowRun class TokenBufferMemory: - def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: + def __init__( + self, + conversation: Conversation, + model_instance: ModelInstance, + ) -> None: self.conversation = conversation self.model_instance = model_instance @@ -36,20 +42,8 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = ( - db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id, - Message.parent_message_id, - Message.answer_tokens, - ) - .filter( - Message.conversation_id == self.conversation.id, - ) - .order_by(Message.created_at.desc()) + stmt = ( + select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc()) ) if message_limit and message_limit > 0: @@ -57,7 +51,9 @@ class TokenBufferMemory: else: message_limit = 500 - messages = query.limit(message_limit).all() + stmt = stmt.limit(message_limit) + + messages = db.session.scalars(stmt).all() # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message @@ -74,18 +70,20 @@ class TokenBufferMemory: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + 
workflow_run = db.session.scalar( + select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id) + ) + if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) else: - if message.workflow_run_id: - workflow_run = ( - db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() - ) - - if workflow_run and workflow_run.workflow: - file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, is_vision=False - ) + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index b18a6905fe..db8fec4ee9 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -284,7 +284,8 @@ class AliyunDataTrace(BaseTraceInstance): else: node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) return node_span - except Exception: + except Exception as e: + logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True) return None def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: @@ -306,7 +307,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value, GEN_AI_FRAMEWORK: "dify", INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False), @@ -381,7 +382,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, GEN_AI_FRAMEWORK: "dify", GEN_AI_MODEL_NAME: process_data.get("model_name", ""), @@ -415,7 +416,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_USER_ID: str(user_id), GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, GEN_AI_FRAMEWORK: "dify", diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index a3dbce0e59..4a7e66d27c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from 
core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f94e5e49d7..8a559c4929 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8bedea20fb..be4997a5bf 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -241,7 +241,7 @@ class OpikDataTrace(BaseTraceInstance): "trace_id": opik_trace_id, "id": prepare_opik_uuid(created_at, node_execution_id), "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), - "name": node_type, + "name": node_name, "type": run_type, "start_time": created_at, "end_time": finished_at, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 3917348a91..445c6a8741 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ 
b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 2be65d67a0..47290ee613 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.entities.parameter_entities import CommonParameterType from core.tools.entities.common_entities import I18nObject +from core.workflow.nodes.base.entities import NumberType class PluginParameterOption(BaseModel): @@ -38,6 +39,7 @@ class PluginParameterType(enum.StrEnum): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value # deprecated, should not use. 
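
Note on the tracing changes: the four trace integrations touched above (LangFuse, LangSmith, Opik, Weave) now obtain their node-execution repository from DifyCoreRepositoryFactory instead of constructing SQLAlchemyWorkflowNodeExecutionRepository directly, and they pass the already-resolved app_id rather than re-reading it from trace metadata. A sketch of the shared call-site pattern (placeholder values; keyword arguments mirror the hunks):

session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
    session_factory=session_factory,
    user=service_account,
    app_id=app_id,  # resolved once, no longer trace_info.metadata.get("app_id")
    triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
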
@@ -151,6 +153,10 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): if value and not isinstance(value, list): raise ValueError("The tools selector must be a list.") return value + case PluginParameterType.ANY: + if value and not isinstance(value, str | dict | list | NumberType): + raise ValueError("The var selector must be a string, dictionary, list or number.") + return value case PluginParameterType.ARRAY: if not isinstance(value, list): # Try to parse JSON string for arrays diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 871fbb9624..de368f7ccd 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -141,17 +141,6 @@ class PluginEntity(PluginInstallation): return self -class GithubPackage(BaseModel): - repo: str - version: str - package: str - - -class GithubVersion(BaseModel): - repo: str - version: str - - class GenericProviderID: organization: str plugin_name: str diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index b7f7b31655..04ac8c9649 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": 1, "page_size": 256}, + params={"page": 1, "page_size": 256, "response_type": "paged"}, ) return result.list @@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": page, "page_size": page_size}, + params={"page": page, "page_size": page_size, "response_type": "paged"}, ) def upload_pkg( diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 25964ae063..0f0fe65f27 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform): if prompt_item.edition_type == "basic" or not prompt_item.edition_type: if self.with_variable_tmpl: - vp = VariablePool() + vp = VariablePool.empty() for k, v in inputs.items(): if k.startswith("#"): vp.add(k[1:-1].split("."), v) diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index f7aef76c87..4b883622a7 100644 --- a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,10 +1,11 @@ -from typing import Any +from collections.abc import Sequence from constants import UUID_NIL +from models import Message -def extract_thread_messages(messages: list[Any]): - thread_messages = [] +def extract_thread_messages(messages: Sequence[Message]): + thread_messages: list[Message] = [] next_message = None for message in messages: diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py index f49466db6d..de64c27a73 100644 --- a/api/core/prompt/utils/get_thread_messages_length.py +++ b/api/core/prompt/utils/get_thread_messages_length.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from models.model import Message @@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int: Get the number of thread messages based on the parent message id. 
""" # Fetch all messages related to the conversation - query = ( - db.session.query( - Message.id, - Message.parent_message_id, - Message.answer, - ) - .filter( - Message.conversation_id == conversation_id, - ) - .order_by(Message.created_at.desc()) - ) + stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc()) - messages = query.all() + messages = db.session.scalars(stmt).all() # Extract thread messages thread_messages = extract_thread_messages(messages) diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py deleted file mode 100644 index 167a919e69..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.core import clean_extra_whitespace - - # Returns "ITEM 1A: RISK FACTORS" - return clean_extra_whitespace(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py deleted file mode 100644 index 9c682d29db..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - import re - - from unstructured.cleaners.core import group_broken_paragraphs - - para_split_re = re.compile(r"(\s*\n\s*){3}") - - return group_broken_paragraphs(content, paragraph_split=para_split_re) diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py deleted file mode 100644 index 0cdbb171e1..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.core import clean_non_ascii_chars - - # Returns "This text contains non-ascii characters!" 
- return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py deleted file mode 100644 index 9f42044a2d..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """Replaces unicode quote characters, such as the \x91 character in a string.""" - - from unstructured.cleaners.core import replace_unicode_quotes - - return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py deleted file mode 100644 index 32ae7217e8..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.translate import translate_text - - return translate_text(content) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2c5178241c..5a6903d3d5 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app -from sqlalchemy.orm import load_only +from sqlalchemy.orm import Session, load_only from configs import dify_config from core.rag.data_post_processor.data_post_processor import DataPostProcessor @@ -144,7 +144,8 @@ class RetrievalService: @classmethod def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: - return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + return session.query(Dataset).filter(Dataset.id == dataset_id).first() @classmethod def keyword_search( diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index a124faa503..552068c99e 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -4,6 +4,7 @@ from typing import Any, Optional import tablestore # type: ignore from pydantic import BaseModel, model_validator +from tablestore import BatchGetRowRequest, TableInBatchGetRowItem from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -50,6 +51,29 @@ class TableStoreVector(BaseVector): self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" + def create_collection(self, embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + + def get_by_ids(self, ids: list[str]) -> list[Document]: + docs = [] + request = BatchGetRowRequest() + columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] + rows_to_get = [[("id", _id)] for _id in ids] + request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 
1)) + + result = self._tablestore_client.batch_get_row(request) + table_result = result.get_result_by_table(self._table_name) + for item in table_result: + if item.is_ok and item.row: + kv = {k: v for k, v, t in item.row.attribute_columns} + docs.append( + Document( + page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) + ) + ) + return docs + def get_type(self) -> str: return VectorType.TABLESTORE diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py deleted file mode 100644 index 1e62b3c589..0000000000 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class ClusterEntity(BaseModel): - """ - Model Config Entity. - """ - - name: str - cluster_id: str - displayName: str - region: str - spendingLimit: Optional[int] = 1000 - version: str - createdBy: str diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index e46ab8b7fd..01003a13b6 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -9,8 +9,7 @@ from __future__ import annotations import contextlib import mimetypes -from abc import ABC, abstractmethod -from collections.abc import Generator, Iterable, Mapping +from collections.abc import Generator, Mapping from io import BufferedReader, BytesIO from pathlib import Path, PurePath from typing import Any, Optional, Union @@ -143,21 +142,3 @@ class Blob(BaseModel): if self.source: str_repr += f" {self.source}" return str_repr - - -class BlobLoader(ABC): - """Abstract interface for blob loaders implementation. - - Implementer should be able to load raw content from a datasource system according - to some criteria and return the raw content lazily as a stream of blobs. - """ - - @abstractmethod - def yield_blobs( - self, - ) -> Iterable[Blob]: - """A lazy loader for raw data represented by Blob object. - - Returns: - A generator over blobs - """ diff --git a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py deleted file mode 100644 index dd8a979e70..0000000000 --- a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging - -from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.models.document import Document - -logger = logging.getLogger(__name__) - - -class UnstructuredPDFExtractor(BaseExtractor): - """Load pdf files. - - - Args: - file_path: Path to the file to load. 
- - api_url: Unstructured API URL - - api_key: Unstructured API Key - """ - - def __init__(self, file_path: str, api_url: str, api_key: str): - """Initialize with file path.""" - self._file_path = file_path - self._api_url = api_url - self._api_key = api_key - - def extract(self) -> list[Document]: - if self._api_url: - from unstructured.partition.api import partition_via_api - - elements = partition_via_api( - filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto" - ) - else: - from unstructured.partition.pdf import partition_pdf - - elements = partition_pdf(filename=self._file_path, strategy="auto") - - from unstructured.chunking.title import chunk_by_title - - chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) - documents = [] - for chunk in chunks: - text = chunk.text.strip() - documents.append(Document(page_content=text)) - - return documents diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py deleted file mode 100644 index 22dfdd2075..0000000000 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.models.document import Document - -logger = logging.getLogger(__name__) - - -class UnstructuredTextExtractor(BaseExtractor): - """Load msg files. - - - Args: - file_path: Path to the file to load. - """ - - def __init__(self, file_path: str, api_url: str): - """Initialize with file path.""" - self._file_path = file_path - self._api_url = api_url - - def extract(self) -> list[Document]: - from unstructured.partition.text import partition_text - - elements = partition_text(filename=self._file_path) - from unstructured.chunking.title import chunk_by_title - - chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) - documents = [] - for chunk in chunks: - text = chunk.text.strip() - documents.append(Document(page_content=text)) - - return documents diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3fca48be22..5c0360b064 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast from flask import Flask, current_app from sqlalchemy import Float, and_, or_, text from sqlalchemy import cast as sqlalchemy_cast +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -598,7 +599,8 @@ class DatasetRetrieval: metadata_condition: Optional[MetadataCondition] = None, ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index b711e8434a..529d8ccd27 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -10,7 +10,6 @@ from typing import ( Any, Literal, Optional, - TypedDict, TypeVar, Union, ) @@ -168,167 +167,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): raise NotImplementedError -class CharacterTextSplitter(TextSplitter): - """Splitting text that looks at characters.""" - - def __init__(self, separator: str = "\n\n", 
**kwargs: Any) -> None: - """Create a new TextSplitter.""" - super().__init__(**kwargs) - self._separator = separator - - def split_text(self, text: str) -> list[str]: - """Split incoming text and return chunks.""" - # First we naively split the large input into a bunch of smaller ones. - splits = _split_text_with_regex(text, self._separator, self._keep_separator) - _separator = "" if self._keep_separator else self._separator - _good_splits_lengths = [] # cache the lengths of the splits - if splits: - _good_splits_lengths.extend(self._length_function(splits)) - return self._merge_splits(splits, _separator, _good_splits_lengths) - - -class LineType(TypedDict): - """Line type as typed dict.""" - - metadata: dict[str, str] - content: str - - -class HeaderType(TypedDict): - """Header type as typed dict.""" - - level: int - name: str - data: str - - -class MarkdownHeaderTextSplitter: - """Splitting markdown files based on specified headers.""" - - def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): - """Create a new MarkdownHeaderTextSplitter. - - Args: - headers_to_split_on: Headers we want to track - return_each_line: Return each line w/ associated headers - """ - # Output line-by-line or aggregated into chunks w/ common headers - self.return_each_line = return_each_line - # Given the headers we want to split on, - # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) - - def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: - """Combine lines with common metadata into chunks - Args: - lines: Line of text / associated header metadata - """ - aggregated_chunks: list[LineType] = [] - - for line in lines: - if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: - # If the last line in the aggregated list - # has the same metadata as the current line, - # append the current content to the last lines's content - aggregated_chunks[-1]["content"] += " \n" + line["content"] - else: - # Otherwise, append the current line to the aggregated list - aggregated_chunks.append(line) - - return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] - - def split_text(self, text: str) -> list[Document]: - """Split markdown file - Args: - text: Markdown file""" - - # Split the input text by newline character ("\n"). 
- lines = text.split("\n") - # Final output - lines_with_metadata: list[LineType] = [] - # Content and metadata of the chunk currently being processed - current_content: list[str] = [] - current_metadata: dict[str, str] = {} - # Keep track of the nested header structure - # header_stack: List[Dict[str, Union[int, str]]] = [] - header_stack: list[HeaderType] = [] - initial_metadata: dict[str, str] = {} - - for line in lines: - stripped_line = line.strip() - # Check each line against each of the header types (e.g., #, ##) - for sep, name in self.headers_to_split_on: - # Check if line starts with a header that we intend to split on - if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " - ): - # Ensure we are tracking the header as metadata - if name is not None: - # Get the current header level - current_header_level = sep.count("#") - - # Pop out headers of lower or same level from the stack - while header_stack and header_stack[-1]["level"] >= current_header_level: - # We have encountered a new header - # at the same or higher level - popped_header = header_stack.pop() - # Clear the metadata for the - # popped header in initial_metadata - if popped_header["name"] in initial_metadata: - initial_metadata.pop(popped_header["name"]) - - # Push the current header to the stack - header: HeaderType = { - "level": current_header_level, - "name": name, - "data": stripped_line[len(sep) :].strip(), - } - header_stack.append(header) - # Update initial_metadata with the current header - initial_metadata[name] = header["data"] - - # Add the previous line to the lines_with_metadata - # only if current_content is not empty - if current_content: - lines_with_metadata.append( - { - "content": "\n".join(current_content), - "metadata": current_metadata.copy(), - } - ) - current_content.clear() - - break - else: - if stripped_line: - current_content.append(stripped_line) - elif current_content: - lines_with_metadata.append( - { - "content": "\n".join(current_content), - "metadata": current_metadata.copy(), - } - ) - current_content.clear() - - current_metadata = initial_metadata.copy() - - if current_content: - lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) - - # lines_with_metadata has each line with associated header metadata - # aggregate these into chunks based on common metadata - if not self.return_each_line: - return self.aggregate_lines_to_chunks(lines_with_metadata) - else: - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata - ] - - -# should be in newer Python versions (3.10+) # @dataclass(frozen=True, kw_only=True, slots=True) @dataclass(frozen=True) class Tokenizer: diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6452317120..052ba1c2cb 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. 
""" +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "DifyCoreRepositoryFactory", + "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py new file mode 100644 index 0000000000..4118aa61c7 --- /dev/null +++ b/api/core/repositories/factory.py @@ -0,0 +1,224 @@ +""" +Repository factory for dynamically creating repository instances based on configuration. + +This module provides a Django-like settings system for repository implementations, +allowing users to configure different repository backends through string paths. +""" + +import importlib +import inspect +import logging +from typing import Protocol, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class RepositoryImportError(Exception): + """Raised when a repository implementation cannot be imported or instantiated.""" + + pass + + +class DifyCoreRepositoryFactory: + """ + Factory for creating repository instances based on configuration. + + This factory supports Django-like settings where repository implementations + are specified as module paths (e.g., 'module.submodule.ClassName'). + """ + + @staticmethod + def _import_class(class_path: str) -> type: + """ + Import a class from a module path string. + + Args: + class_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class + + Raises: + RepositoryImportError: If the class cannot be imported + """ + try: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + repo_class = getattr(module, class_name) + assert isinstance(repo_class, type) + return repo_class + except (ValueError, ImportError, AttributeError) as e: + raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e + + @staticmethod + def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore + """ + Validate that a class implements the expected repository interface. 
+ + Args: + repository_class: The class to validate + expected_interface: The expected interface/protocol + + Raises: + RepositoryImportError: If the class doesn't implement the interface + """ + # Check if the class has all required methods from the protocol + required_methods = [ + method + for method in dir(expected_interface) + if not method.startswith("_") and callable(getattr(expected_interface, method, None)) + ] + + missing_methods = [] + for method_name in required_methods: + if not hasattr(repository_class, method_name): + missing_methods.append(method_name) + + if missing_methods: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' does not implement required methods " + f"{missing_methods} from interface '{expected_interface.__name__}'" + ) + + @staticmethod + def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: + """ + Validate that a repository class constructor accepts required parameters. + + Args: + repository_class: The class to validate + required_params: List of required parameter names + + Raises: + RepositoryImportError: If the constructor doesn't accept required parameters + """ + + try: + # MyPy may flag the line below with the following error: + # + # > Accessing "__init__" on an instance is unsound, since + # > instance.__init__ could be from an incompatible subclass. + # + # Despite this, we need to ensure that the constructor of `repository_class` + # has a compatible signature. + signature = inspect.signature(repository_class.__init__) # type: ignore[misc] + param_names = list(signature.parameters.keys()) + + # Remove 'self' parameter + if "self" in param_names: + param_names.remove("self") + + missing_params = [param for param in required_params if param not in param_names] + if missing_params: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " + f"{missing_params}. Expected parameters: {required_params}" + ) + except Exception as e: + raise RepositoryImportError( + f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" + ) from e + + @classmethod + def create_workflow_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + ) -> WorkflowExecutionRepository: + """ + Create a WorkflowExecutionRepository instance based on configuration. 
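+
+        A minimal usage sketch (`engine`, `account` and `triggered_from` are
+        placeholders supplied by the caller, not defined in this module):
+
+            repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
+                session_factory=sessionmaker(bind=engine),
+                user=account,
+                app_id="app-id",
+                triggered_from=triggered_from,
+            )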
+ + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowExecutionRepository") + raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e + + @classmethod + def create_workflow_node_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + ) -> WorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b5148e245f..64568a8eda 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -16,6 +16,7 @@ from core.plugin.entities.parameters import ( cast_parameter_value, init_frontend_parameter, ) +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY @@ -179,6 +180,10 @@ class ToolInvokeMessage(BaseModel): data: Mapping[str, Any] = Field(..., description="Detailed log data") metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + class RetrieverResourceMessage(BaseModel): + retriever_resources: 
list[RetrievalSourceMetadata] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + class MessageType(Enum): TEXT = "text" IMAGE = "image" @@ -191,13 +196,22 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" + RETRIEVER_RESOURCES = "retriever_resources" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ message: ( - JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage ) meta: dict[str, Any] | None = None @@ -243,6 +257,7 @@ class ToolParameter(PluginParameter): FILES = PluginParameterType.FILES.value APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + ANY = PluginParameterType.ANY.value DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value # MCP object and array type parameters diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3f844e8234..a3c84615ca 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,5 +1,4 @@ import re -import uuid from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: - path = str(uuid.uuid4()) + path = "" interface["operation"]["operationId"] = f"{path}_{interface['method']}" diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6cf09e0372..13274f4e0e 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,9 +1,9 @@ import json import sys from collections.abc import Mapping, Sequence -from typing import Any +from typing import Annotated, Any, TypeAlias -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator from core.file import File @@ -11,6 +11,11 @@ from .types import SegmentType class Segment(BaseModel): + """Segment is runtime type used during the execution of workflow. + + Note: this class is abstract, you should use subclasses of this class instead. + """ + model_config = ConfigDict(frozen=True) value_type: SegmentType @@ -73,7 +78,7 @@ class StringSegment(Segment): class FloatSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.FLOAT value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. 
@@ -92,7 +97,7 @@ class FloatSegment(Segment): class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.INTEGER value: int @@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment): @property def text(self) -> str: return "" + + +def get_segment_discriminator(v: Any) -> SegmentType | None: + if isinstance(v, Segment): + return v.value_type + elif isinstance(v, dict): + value_type = v.get("value_type") + if value_type is None: + return None + try: + seg_type = SegmentType(value_type) + except ValueError: + return None + return seg_type + else: + # return None if the discriminator value isn't found + return None + + +# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Segment` for type hinting when serialization is not required. +# +# Note: +# - All variants in `SegmentUnion` must inherit from the `Segment` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +# - `SegmentGroup`, which is not added to the variable pool. +# - `Variable` and its subclasses, which are handled by `VariableUnion`. +SegmentUnion: TypeAlias = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 68d3d82883..e79b2410bf 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,8 +1,27 @@ +from collections.abc import Mapping from enum import StrEnum +from typing import Any, Optional + +from core.file.models import File + + +class ArrayValidation(StrEnum): + """Strategy for validating array elements""" + + # Skip element validation (only check array container) + NONE = "none" + + # Validate the first element (if array is non-empty) + FIRST = "first" + + # Validate all elements in the array. + ALL = "all" class SegmentType(StrEnum): NUMBER = "number" + INTEGER = "integer" + FLOAT = "float" STRING = "string" OBJECT = "object" SECRET = "secret" @@ -19,16 +38,139 @@ class SegmentType(StrEnum): GROUP = "group" - def is_array_type(self): + def is_array_type(self) -> bool: return self in _ARRAY_TYPES + @classmethod + def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + """ + Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. + + Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. + For example, this may occur if the input is a generic Python object of type `object`. 
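+
+        A few illustrative calls (a sketch of the rules described above):
+
+            SegmentType.infer_segment_type(1)           # SegmentType.INTEGER
+            SegmentType.infer_segment_type(2.5)         # SegmentType.FLOAT
+            SegmentType.infer_segment_type("a")         # SegmentType.STRING
+            SegmentType.infer_segment_type([1, 2.5])    # SegmentType.ARRAY_NUMBER
+            SegmentType.infer_segment_type(object())    # None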
+        """
+
+        if isinstance(value, list):
+            elem_types: set[SegmentType] = set()
+            for i in value:
+                segment_type = cls.infer_segment_type(i)
+                if segment_type is None:
+                    return None
+
+                elem_types.add(segment_type)
+
+            if len(elem_types) != 1:
+                if elem_types.issubset(_NUMERICAL_TYPES):
+                    return SegmentType.ARRAY_NUMBER
+                return SegmentType.ARRAY_ANY
+            elif all(i.is_array_type() for i in elem_types):
+                return SegmentType.ARRAY_ANY
+            match elem_types.pop():
+                case SegmentType.STRING:
+                    return SegmentType.ARRAY_STRING
+                case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
+                    return SegmentType.ARRAY_NUMBER
+                case SegmentType.OBJECT:
+                    return SegmentType.ARRAY_OBJECT
+                case SegmentType.FILE:
+                    return SegmentType.ARRAY_FILE
+                case SegmentType.NONE:
+                    return SegmentType.ARRAY_ANY
+                case _:
+                    # This should be unreachable.
+                    raise ValueError(f"not supported value {value}")
+        if value is None:
+            return SegmentType.NONE
+        elif isinstance(value, int) and not isinstance(value, bool):
+            return SegmentType.INTEGER
+        elif isinstance(value, float):
+            return SegmentType.FLOAT
+        elif isinstance(value, str):
+            return SegmentType.STRING
+        elif isinstance(value, dict):
+            return SegmentType.OBJECT
+        elif isinstance(value, File):
+            return SegmentType.FILE
+        else:
+            return None
+
+    def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
+        if not isinstance(value, list):
+            return False
+        # Skip element validation if array is empty
+        if len(value) == 0:
+            return True
+        if self == SegmentType.ARRAY_ANY:
+            return True
+        element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
+
+        if array_validation == ArrayValidation.NONE:
+            return True
+        elif array_validation == ArrayValidation.FIRST:
+            return element_type.is_valid(value[0])
+        else:
+            # Validate every element; each element check skips nested array validation.
+            return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
+
+    def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
+        """
+        Check if a value matches the segment type.
+        Users of `SegmentType` should call this method instead of using
+        `isinstance` manually.
+
+        Args:
+            value: The value to validate
+            array_validation: Validation strategy for array types (ignored for non-array types)
+
+        Returns:
+            True if the value matches the type under the given validation strategy
+        """
+        if self.is_array_type():
+            return self._validate_array(value, array_validation)
+        elif self == SegmentType.NUMBER:
+            return isinstance(value, (int, float))
+        elif self == SegmentType.STRING:
+            return isinstance(value, str)
+        elif self == SegmentType.OBJECT:
+            return isinstance(value, dict)
+        elif self == SegmentType.SECRET:
+            return isinstance(value, str)
+        elif self == SegmentType.FILE:
+            return isinstance(value, File)
+        elif self == SegmentType.NONE:
+            return value is None
+        else:
+            raise AssertionError("this statement should be unreachable.")
+
+    def exposed_type(self) -> "SegmentType":
+        """Returns the type exposed to the frontend.
+
+        The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
+        """
+        if self in (SegmentType.INTEGER, SegmentType.FLOAT):
+            return SegmentType.NUMBER
+        return self
+
+
+_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
+    # ARRAY_ANY does not have a corresponding element type.
+ SegmentType.ARRAY_STRING: SegmentType.STRING, + SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, + SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, + SegmentType.ARRAY_FILE: SegmentType.FILE, +} _ARRAY_TYPES = frozenset( - [ + list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) + + [ SegmentType.ARRAY_ANY, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_FILE, + ] +) + + +_NUMERICAL_TYPES = frozenset( + [ + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, ] ) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index e5dc226571..d9ce91990b 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -3,6 +3,10 @@ from typing import Any, cast from uuid import uuid4 from pydantic import BaseModel, Field +from typing import Annotated, TypeAlias, cast +from uuid import uuid4 + +from pydantic import Discriminator, Field, Tag from core.helper import encrypter @@ -20,6 +24,7 @@ from .segments import ( ObjectSegment, Segment, StringSegment, + get_segment_discriminator, ) from .types import SegmentType @@ -27,6 +32,10 @@ from .types import SegmentType class Variable(Segment): """ A variable is a segment that has a name. + + It is mainly used to store segments and their selector in VariablePool. + + Note: this class is abstract, you should use subclasses of this class instead. """ id: str = Field( @@ -122,3 +131,26 @@ class RAGPipelineVariable(BaseModel): class RAGPipelineVariableInput(BaseModel): variable: RAGPipelineVariable value: Any +# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Variable` for type hinting when serialization is not required. +# +# Note: +# - All variants in `VariableUnion` must inherit from the `Variable` class. 
+# - The union must include all non-abstract subclasses of `Segment`, except: +VariableUnion: TypeAlias = Annotated[ + ( + Annotated[NoneVariable, Tag(SegmentType.NONE)] + | Annotated[StringVariable, Tag(SegmentType.STRING)] + | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] + | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] + | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] + | Annotated[FileVariable, Tag(SegmentType.FILE)] + | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] + | Annotated[SecretVariable, Tag(SegmentType.SECRET)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 3a68f45f61..742a50c01f 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,7 +1,7 @@ import re from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Union +from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field @@ -17,6 +17,9 @@ from core.workflow.constants import ( SYSTEM_VARIABLE_NODE_ID, ) from core.workflow.enums import SystemVariableKey +from core.variables.variables import VariableUnion +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.system_variable import SystemVariable from factories import variable_factory VariableValue = Union[str, int, float, dict, list, File] @@ -29,24 +32,24 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: dict[str, dict[int, Segment]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. + + # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. 
user_inputs: Mapping[str, Any] = Field( description="User inputs", default_factory=dict, ) - system_variables: Mapping[SystemVariableKey, Any] = Field( + system_variables: SystemVariable = Field( description="System variables", - default_factory=dict, ) - environment_variables: Sequence[Variable] = Field( + environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", default_factory=list, ) - conversation_variables: Sequence[Variable] = Field( + conversation_variables: Sequence[VariableUnion] = Field( description="Conversation variables.", default_factory=list, ) @@ -56,8 +59,8 @@ class VariablePool(BaseModel): ) def model_post_init(self, context: Any, /) -> None: - for key, value in self.system_variables.items(): - self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) + # Create a mapping from field names to SystemVariableKey enum values + self._add_system_variables(self.system_variables) # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) @@ -96,8 +99,22 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]][hash_key] = variable + key, hash_key = self._selector_to_keys(selector) + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + + @classmethod + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: + return selector[0], hash(tuple(selector[1:])) + + def _has(self, selector: Sequence[str]) -> bool: + key, hash_key = self._selector_to_keys(selector) + if key not in self.variable_dictionary: + return False + if hash_key not in self.variable_dictionary[key]: + return False + return True def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -115,8 +132,8 @@ class VariablePool(BaseModel): if len(selector) < MIN_SELECTORS_LENGTH: return None - hash_key = hash(tuple(selector[1:])) - value = self.variable_dictionary[selector[0]].get(hash_key) + key, hash_key = self._selector_to_keys(selector) + value: Segment | None = self.variable_dictionary[key].get(hash_key) if value is None: selector, attr = selector[:-1], selector[-1] @@ -149,8 +166,8 @@ class VariablePool(BaseModel): if len(selector) == 1: self.variable_dictionary[selector[0]] = {} return - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]].pop(hash_key, None) + key, hash_key = self._selector_to_keys(selector) + self.variable_dictionary[key].pop(hash_key, None) def convert_template(self, template: str, /): parts = VARIABLE_PATTERN.split(template) @@ -167,3 +184,20 @@ class VariablePool(BaseModel): if isinstance(segment, FileSegment): return segment return None + + def _add_system_variables(self, system_variable: SystemVariable): + sys_var_mapping = system_variable.to_dict() + for key, value in sys_var_mapping.items(): + if value is None: + continue + selector = (SYSTEM_VARIABLE_NODE_ID, key) + # If the system variable already exists, do not add it again. + # This ensures that we can keep the id of the system variables intact. 
+ if self._has(selector): + continue + self.add(selector, value) # type: ignore + + @classmethod + def empty(cls) -> "VariablePool": + """Create an empty variable pool.""" + return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py deleted file mode 100644 index 8896416f12..0000000000 --- a/api/core/workflow/entities/workflow_entities.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode -from models.enums import UserFrom -from models.workflow import Workflow, WorkflowType - -from .node_entities import NodeRunResult -from .variable_pool import VariablePool - - -class WorkflowNodeAndResult: - node: BaseNode - result: Optional[NodeRunResult] = None - - def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): - self.node = node - self.result = result - - -class WorkflowRunState: - tenant_id: str - app_id: str - workflow_id: str - workflow_type: WorkflowType - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int - - start_at: float - variable_pool: VariablePool - - total_tokens: int = 0 - - workflow_nodes_and_results: list[WorkflowNodeAndResult] - - class NodeRun(BaseModel): - node_id: str - iteration_node_id: str - loop_node_id: str - - workflow_node_runs: list[NodeRun] - workflow_node_steps: int - - current_iteration_state: Optional[BaseIterationState] - current_loop_state: Optional[BaseLoopState] - - def __init__( - self, - workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int, - ): - self.workflow_id = workflow.id - self.tenant_id = workflow.tenant_id - self.app_id = workflow.app_id - self.workflow_type = WorkflowType.value_of(workflow.type) - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth - - self.start_at = start_at - self.variable_pool = variable_pool - - self.total_tokens = 0 - - self.workflow_node_steps = 1 - self.workflow_node_runs = [] - self.current_iteration_state = None - self.current_loop_state = None diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index afc09bfac5..a62ffe46c9 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel): """total tokens""" llm_usage: LLMUsage = LLMUsage.empty_usage() """llm usage info""" + + # The `outputs` field stores the final output values generated by executing workflows or chatflows. + # + # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent + # after a serialization and deserialization round trip. 
outputs: dict[str, Any] = {} - """outputs""" node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 22ed9e2651..1adabf7247 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,5 @@ from collections.abc import Mapping, Sequence +from decimal import Decimal from typing import Any, Optional from configs import dify_config @@ -114,8 +115,10 @@ class CodeNode(BaseNode[CodeNodeData]): ) if isinstance(value, float): + decimal_value = Decimal(str(value)).normalize() + precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] # raise error if precision is too high - if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + if precision > dify_config.CODE_MAX_PRECISION: raise OutputValidationError( f"Output variable `{variable}` has too high precision," f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c447f433aa..8b566c83cd 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -521,18 +521,52 @@ class IterationNode(BaseNode[IterationNodeData]): ) return elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": None}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), ) + outputs[current_index] = None + + # clean nodes resources + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + + # stop the iterator + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return yield metadata_event current_output_segment = variable_pool.get(self.node_data.output_selector) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b34d62d669..f05d93d83e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -144,6 +144,8 @@ class 
KnowledgeRetrievalNode(LLMNode): error=str(e), error_type=type(e).__name__, ) + finally: + db.session.close() def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] @@ -171,6 +173,9 @@ class KnowledgeRetrievalNode(LLMNode): .all() ) + # avoid blocking at retrieval + db.session.close() + for dataset in results: # pass if dataset is not available if not dataset: diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3f4a5edab9..d04e0bfae1 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,11 +1,29 @@ from collections.abc import Mapping -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import AfterValidator, BaseModel, Field +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.utils.condition.entities import Condition +_VALID_VAR_TYPE = frozenset( + [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + ] +) + + +def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: + if seg_type not in _VALID_VAR_TYPE: + raise ValueError(...) + return seg_type + class LoopVariableData(BaseModel): """ @@ -13,7 +31,7 @@ class LoopVariableData(BaseModel): """ label: str - var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] value: Optional[Any | list[str]] = None diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 11fd7b6c2d..20501d0317 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast from configs import dify_config from core.variables import ( - ArrayNumberSegment, - ArrayObjectSegment, - ArrayStringSegment, IntegerSegment, - ObjectSegment, Segment, SegmentType, - StringSegment, ) from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor +from factories.variable_factory import TypeMismatchError, build_segment_with_type if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]): return variable_mapping @staticmethod - def _get_segment_for_constant(var_type: str, value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" - segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = { - "string": (StringSegment, SegmentType.STRING), - "number": (IntegerSegment, SegmentType.NUMBER), - "object": (ObjectSegment, SegmentType.OBJECT), - "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING), - "array[number]": (ArrayNumberSegment, 
SegmentType.ARRAY_NUMBER), - "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT), - } if var_type in ["array[string]", "array[number]", "array[object]"]: - if value: + if value and isinstance(value, str): value = json.loads(value) else: value = [] - segment_info = segment_mapping.get(var_type) - if not segment_info: - raise ValueError(f"Invalid variable type: {var_type}") - segment_class, value_type = segment_info - return segment_class(value=value, value_type=value_type) + try: + return build_segment_with_type(var_type, value) + except TypeMismatchError as type_exc: + # Attempt to parse the value as a JSON-encoded string, if applicable. + if not isinstance(value, str): + raise + try: + value = json.loads(value) + except ValueError: + raise type_exc + return build_segment_with_type(var_type, value) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 5ee9bc331f..e215591888 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]): def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables + system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() # TODO: System variables should be directly accessible, no need for special handling # Set system variables as node outputs. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 48627a229d..3853a5d920 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -22,7 +22,7 @@ from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory @@ -373,6 +373,12 @@ class ToolNode(BaseNode[ToolNodeData]): agent_logs.append(agent_log) yield agent_log + elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: + assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage) + yield RunRetrieverResourceEvent( + retriever_resources=message.message.retriever_resources, + context=message.message.context, + ) # Add agent_logs to outputs['json'] to ensure frontend can access thinking process json_output: list[dict[str, Any]] = [] diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index be5083c9c1..1864b13784 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): def get_zero_value(t: SegmentType): + # TODO(QuantumGhost): this should be a method of `SegmentType`. 
match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: return variable_factory.build_segment([]) @@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType): return variable_factory.build_segment({}) case SegmentType.STRING: return variable_factory.build_segment("") + case SegmentType.INTEGER: + return variable_factory.build_segment(0) + case SegmentType.FLOAT: + return variable_factory.build_segment(0.0) case SegmentType.NUMBER: return variable_factory.build_segment(0) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 3797bfa77a..7f760e5baa 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -1,5 +1,6 @@ from core.variables import SegmentType +# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy. EMPTY_VALUE_MAPPING = { SegmentType.STRING: "", SegmentType.NUMBER: 0, diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index 8fb2a27388..7a20975b15 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation): case Operation.OVER_WRITE | Operation.CLEAR: return True case Operation.SET: - return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} + return variable_type in { + SegmentType.OBJECT, + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, + } case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: # Only number variable can be added, subtracted, multiplied or divided - return variable_type == SegmentType.NUMBER + return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} case Operation.APPEND | Operation.EXTEND: # Only array variable can be appended or extended return variable_type in { @@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat match variable_type: case SegmentType.STRING | SegmentType.OBJECT: return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return operation in { Operation.OVER_WRITE, Operation.SET, @@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va case SegmentType.STRING: return isinstance(value, str) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: if not isinstance(value, int | float): return False if operation == Operation.DIVIDE and value == 0: diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py new file mode 100644 index 0000000000..df90c16596 --- /dev/null +++ b/api/core/workflow/system_variable.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from core.file.models import File +from core.workflow.enums import SystemVariableKey + + +class SystemVariable(BaseModel): + """A model for managing system variables. + + Fields with a value of `None` are treated as absent and will not be included + in the variable pool. 
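+
+    For example (illustrative), `SystemVariable(user_id="user-1")` contributes only
+    `sys.user_id` and `sys.files` (which defaults to an empty list) to the variable
+    pool; every other system variable is omitted.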
+ """ + + model_config = ConfigDict( + extra="forbid", + serialize_by_alias=True, + validate_by_alias=True, + ) + + user_id: str | None = None + + # Ideally, `app_id` and `workflow_id` should be required and not `None`. + # However, there are scenarios in the codebase where these fields are not set. + # To maintain compatibility, they are marked as optional here. + app_id: str | None = None + workflow_id: str | None = None + + files: Sequence[File] = Field(default_factory=list) + + # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. + # To maintain compatibility with existing workflows, it must be serialized + # as `workflow_run_id` in dictionaries or JSON objects, and also referenced + # as `workflow_run_id` in the variable pool. + workflow_execution_id: str | None = Field( + validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), + serialization_alias="workflow_run_id", + default=None, + ) + # Chatflow related fields. + query: str | None = None + conversation_id: str | None = None + dialogue_count: int | None = None + + @model_validator(mode="before") + @classmethod + def validate_json_fields(cls, data): + if isinstance(data, dict): + # For JSON validation, only allow workflow_run_id + if "workflow_execution_id" in data and "workflow_run_id" not in data: + # This is likely from direct instantiation, allow it + return data + elif "workflow_execution_id" in data and "workflow_run_id" in data: + # Both present, remove workflow_execution_id + data = data.copy() + data.pop("workflow_execution_id") + return data + return data + + @classmethod + def empty(cls) -> "SystemVariable": + return cls() + + def to_dict(self) -> dict[SystemVariableKey, Any]: + # NOTE: This method is provided for compatibility with legacy code. + # New code should use the `SystemVariable` object directly instead of converting + # it to a dictionary, as this conversion results in the loss of type information + # for each key, making static analysis more difficult. 
+ + d: dict[SystemVariableKey, Any] = { + SystemVariableKey.FILES: self.files, + } + if self.user_id is not None: + d[SystemVariableKey.USER_ID] = self.user_id + if self.app_id is not None: + d[SystemVariableKey.APP_ID] = self.app_id + if self.workflow_id is not None: + d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id + if self.workflow_execution_id is not None: + d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id + if self.query is not None: + d[SystemVariableKey.QUERY] = self.query + if self.conversation_id is not None: + d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id + if self.dialogue_count is not None: + d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count + return d diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 0aab2426af..50ff733979 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.enums import SystemVariableKey from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from libs.datetime_utils import naive_utc_now @@ -43,7 +44,7 @@ class WorkflowCycleManager: self, *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], - workflow_system_variables: dict[SystemVariableKey, Any], + workflow_system_variables: SystemVariable, workflow_info: CycleManagerWorkflowInfo, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, @@ -56,17 +57,22 @@ class WorkflowCycleManager: def handle_workflow_run_start(self) -> WorkflowExecution: inputs = {**self._application_generate_entity.inputs} - for key, value in (self._workflow_system_variables or {}).items(): - if key.value == "conversation": - continue - inputs[f"sys.{key.value}"] = value + + # Iterate over SystemVariable fields using Pydantic's model_fields + if self._workflow_system_variables: + for field_name, value in self._workflow_system_variables.to_dict().items(): + if field_name == SystemVariableKey.CONVERSATION_ID: + continue + inputs[f"sys.{field_name}"] = value # handle special values inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) + execution_id = str( + self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None + ) or str(uuid4()) execution = WorkflowExecution.new( id_=execution_id, workflow_id=self._workflow_info.workflow_id, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2868dcb7de..1399efcdb1 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from 
core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom @@ -254,7 +255,7 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=[], ) diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 0123fdac18..2c634d25ec 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -1,4 +1,3 @@ -import json from collections.abc import Mapping from typing import Any @@ -8,18 +7,6 @@ from core.file.models import File from core.variables import Segment -class WorkflowRuntimeTypeEncoder(json.JSONEncoder): - def default(self, o: Any): - if isinstance(o, Segment): - return o.value - elif isinstance(o, File): - return o.to_dict() - elif isinstance(o, BaseModel): - return o.model_dump(mode="json") - else: - return super().default(o) - - class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: result = self._to_json_encodable_recursive(value) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b62b0b60d6..b027a165f9 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -193,13 +193,22 @@ def init_app(app: DifyApp): insecure=True, ) else: + headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None + + trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT + if not trace_endpoint: + trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces" exporter = HTTPSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=trace_endpoint, + headers=headers, ) + + metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT + if not metric_endpoint: + metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics" metric_exporter = HTTPMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=metric_endpoint, + headers=headers, ) else: exporter = ConsoleSpanExporter() diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 8dd4ec2e4a..33d7ad29b4 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -100,9 +100,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = StringVariable.model_validate(mapping) case SegmentType.SECRET: result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): + case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.INTEGER result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): + case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") @@ -128,6 +132,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType: def build_segment(value: Any, /) -> Segment: + # NOTE: If you have runtime type information 
available, consider using the `build_segment_with_type` + # below if value is None: return NoneSegment() if isinstance(value, str): @@ -143,12 +149,17 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): + if all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) + elif len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + match types.pop(): case SegmentType.STRING: return ArrayStringSegment(value=value) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return ArrayNumberSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) @@ -162,6 +173,22 @@ def build_segment(value: Any, /) -> Segment: raise ValueError(f"not supported value {value}") +_segment_factory: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.OBJECT: ObjectSegment, + # Array types + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, +} + + def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. @@ -199,7 +226,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: if segment_type == SegmentType.NONE: return NoneSegment() else: - raise TypeMismatchError(f"Expected {segment_type}, but got None") + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: @@ -214,21 +241,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: elif segment_type == SegmentType.ARRAY_FILE: return ArrayFileSegment(value=value) else: - raise TypeMismatchError(f"Expected {segment_type}, but got empty list") - - # Build segment using existing logic to infer actual type - inferred_segment = build_segment(value) - inferred_type = inferred_segment.value_type + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + inferred_type = SegmentType.infer_segment_type(value) # Type compatibility checking + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) if inferred_type == segment_type: - return inferred_segment - - # Type mismatch - raise error with descriptive message - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but value '{value}' " - f"(type: {type(value).__name__}) corresponds to {inferred_type}" - ) + segment_class = _segment_factory[segment_type] + return segment_class(value_type=segment_type, value=value) + elif segment_type == SegmentType.NUMBER and inferred_type in ( + SegmentType.INTEGER, + SegmentType.FLOAT, + ): + segment_class = _segment_factory[inferred_type] + return segment_class(value_type=inferred_type, value=value) + else: + raise TypeMismatchError(f"Type 
mismatch: expected {segment_type}, but got {inferred_type}, value={value}") def segment_to_variable( @@ -256,6 +287,6 @@ def segment_to_variable( name=name, description=description, value=segment.value, - selector=selector, + selector=list(selector), ), ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py new file mode 100644 index 0000000000..8288bd54a3 --- /dev/null +++ b/api/fields/_value_type_serializer.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from core.variables.segments import Segment +from core.variables.types import SegmentType + + +class _VarTypedDict(TypedDict, total=False): + value_type: SegmentType + + +def serialize_value_type(v: _VarTypedDict | Segment) -> str: + if isinstance(v, Segment): + return v.value_type.exposed_type().value + else: + return v["value_type"].exposed_type().value diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 73c224542a..b6d85e0e24 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -188,6 +188,7 @@ app_detail_fields_with_site = { "site": fields.Nested(site_fields), "api_base_url": fields.String, "use_icon_as_answer_icon": fields.Boolean, + "max_active_requests": fields.Integer, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 71785e7d67..c5a0c9a49d 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -2,10 +2,12 @@ from flask_restful import fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.String, "description": fields.String, "created_at": TimestampField, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 57758bf86c..9207ad7ab7 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) @@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.value, + "value_type": value.value_type.exposed_type().value, "description": value.description, } if isinstance(value, dict): - value_type = value.get("value_type") + value_type_str = value.get("value_type") + if not isinstance(value_type_str, str): + raise TypeError( + f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}" + ) + value_type = SegmentType(value_type_str).exposed_type() if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: raise ValueError(f"Unsupported environment variable value type: {value_type}") return value @@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw): conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.Raw, "description": fields.String, } diff 
--git a/api/libs/helper.py b/api/libs/helper.py index 48126461a3..00772d530a 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -148,25 +148,6 @@ class StrLen: return value -class FloatRange: - """Restrict input to an float in a range (inclusive)""" - - def __init__(self, low, high, argument="argument"): - self.low = low - self.high = high - self.argument = argument - - def __call__(self, value): - value = _get_float(value) - if value < self.low or value > self.high: - error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( - arg=self.argument, val=value, lo=self.low, hi=self.high - ) - raise ValueError(error) - - return value - - class DatetimeString: def __init__(self, format, argument="argument"): self.format = format diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py deleted file mode 100644 index fa29671034..0000000000 --- a/api/libs/jsonutil.py +++ /dev/null @@ -1,11 +0,0 @@ -import json - -from pydantic import BaseModel - - -class PydanticModelEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, BaseModel): - return o.model_dump() - else: - super().default(o) diff --git a/api/libs/uuid_utils.py b/api/libs/uuid_utils.py new file mode 100644 index 0000000000..a8190011ed --- /dev/null +++ b/api/libs/uuid_utils.py @@ -0,0 +1,164 @@ +import secrets +import struct +import time +import uuid + +# Reference for UUIDv7 specification: +# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7 + +# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian). +# +# For details on the `struct.pack` format, refer to: +# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment +_PACK_TIMESTAMP = ">Q" + +# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7) +# into an unsigned 16-bit integer (big-endian). +_PACK_RAND_A = ">H" + + +def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes: + """Create UUIDv7 byte structure with given timestamp and random bytes. + + This is a private helper function that handles the common logic for creating + UUIDv7 byte structure according to RFC 9562 specification. + + UUIDv7 Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + The function performs the following operations: + 1. Creates a 128-bit (16-byte) UUID structure + 2. Packs the timestamp into the first 48 bits (6 bytes) + 3. Sets the version bits to 7 (0111) in the correct position + 4. Sets the variant bits to 10 (binary) in the correct position + 5. Fills the remaining bits with the provided random bytes + + Args: + timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits). + random_bytes: Random bytes to use for the random portions (must be 10 bytes). + First 2 bytes are used for random data A (12 bits after version). + Last 8 bytes are used for random data B (62 bits after variant). + + Returns: + A 16-byte bytes object representing the complete UUIDv7 structure. + + Note: + This function assumes the random_bytes parameter is exactly 10 bytes. + The caller is responsible for providing appropriate random data. 
+ """ + # Create the 128-bit UUID structure + uuid_bytes = bytearray(16) + + # Pack timestamp (48 bits) into first 6 bytes + uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian + + # Next 16 bits: random data A (12 bits) + version (4 bits) + # Take first 2 random bytes and set version to 7 + rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0] + # Clear the highest 4 bits to make room for the version field + # by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111). + rand_a = rand_a & 0x0FFF + # Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000). + rand_a = rand_a | 0x7000 + uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a) + + # Last 64 bits: random data B (62 bits) + variant (2 bits) + # Use remaining 8 random bytes and set variant to 10 (binary) + uuid_bytes[8:16] = random_bytes[2:10] + + # Set variant bits (first 2 bits of byte 8 should be '10') + uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx + + return bytes(uuid_bytes) + + +def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID: + """Generate a UUID version 7 according to RFC 9562 specification. + + UUIDv7 features a time-ordered value field derived from the widely + implemented and well known Unix Epoch timestamp source, the number of + milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded. + + Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + Args: + timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified. + Should be an integer representing milliseconds since Unix epoch. + + Returns: + A UUID object representing a UUIDv7. + + Example: + >>> import time + >>> # Generate UUIDv7 with current time + >>> uuid_current = uuidv7() + >>> # Generate UUIDv7 with specific timestamp + >>> uuid_specific = uuidv7(int(time.time() * 1000)) + """ + if timestamp_ms is None: + timestamp_ms = int(time.time() * 1000) + + # Generate 10 random bytes for the random portions + random_bytes = secrets.token_bytes(10) + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + return uuid.UUID(bytes=uuid_bytes) + + +def uuidv7_timestamp(id_: uuid.UUID) -> int: + """Extract the timestamp from a UUIDv7. + + UUIDv7 contains a 48-bit timestamp field representing milliseconds since + the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and + returns that timestamp as an integer representing milliseconds since the epoch. + + Args: + id_: A UUID object that should be a UUIDv7 (version 7). + + Returns: + The timestamp as an integer representing milliseconds since Unix epoch. + + Raises: + ValueError: If the provided UUID is not version 7. 
+ + Example: + >>> uuid_v7 = uuidv7() + >>> timestamp = uuidv7_timestamp(uuid_v7) + >>> print(f"UUID was created at: {timestamp} ms") + """ + # Verify this is a UUIDv7 + if id_.version != 7: + raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}") + + # Extract the UUID bytes + uuid_bytes = id_.bytes + + # Extract the first 48 bits (6 bytes) as the timestamp in milliseconds + # Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long) + timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6] + ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0] + + # Return timestamp directly in milliseconds as integer + assert isinstance(ts_in_ms, int) + return ts_in_ms + + +def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID: + """Generate a non-random uuidv7 with the given timestamp (first 48 bits) and + all random bits to 0. As the smallest possible uuidv7 for that timestamp, + it may be used as a boundary for partitions. + """ + # Use zero bytes for all random portions + zero_random_bytes = b"\x00" * 10 + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes) + + return uuid.UUID(bytes=uuid_bytes) diff --git a/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py new file mode 100644 index 0000000000..2bbbb3d28e --- /dev/null +++ b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py @@ -0,0 +1,86 @@ +"""add uuidv7 function in SQL + +Revision ID: 1c9ba48be8e4 +Revises: 58eb7bdb93fe +Create Date: 2025-07-02 23:32:38.484499 + +""" + +""" +The functions in this file come from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications. + +LICENSE: + +# Copyright and License + +Copyright (c) 2024, Daniel Vérité + +Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. + +In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage. + +Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications. +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1c9ba48be8e4' +down_revision = '58eb7bdb93fe' +branch_labels: None = None +depends_on: None = None + + +def upgrade(): + # This implementation differs slightly from the original uuidv7 function in + # https://github.com/dverite/postgres-uuidv7-sql/. + # The ability to specify source timestamp has been removed because its type signature is incompatible with + # PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be + # generated and controlled within the application layer.
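    # Illustrative usage (an assumption, not executed by this migration): the SQL function below is
    # intended to be interchangeable with the Python helpers added in api/libs/uuid_utils.py, e.g.
    #
    #   SELECT uuidv7();                -- time-ordered id generated inside the database
    #   SELECT uuidv7_boundary(now());  -- smallest uuidv7 for the given timestamp, usable as a partition bound
    #
    # while application code can produce compatible values with libs.uuid_utils.uuidv7() and recover the
    # 48-bit millisecond timestamp with libs.uuid_utils.uuidv7_timestamp(). These queries are only a sketch
    # of the expected behaviour, not part of this change.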
+ op.execute(sa.text(r""" +/* Main function to generate a uuidv7 value with millisecond precision */ +CREATE FUNCTION uuidv7() RETURNS uuid +AS +$$ + -- Replace the first 48 bits of a uuidv4 with the current + -- number of milliseconds since 1970-01-01 UTC + -- and set the "ver" field to 7 by setting additional bits +SELECT encode( + set_bit( + set_bit( + overlay(uuid_send(gen_random_uuid()) placing + substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from + 3) + from 1 for 6), + 52, 1), + 53, 1), 'hex')::uuid; +$$ LANGUAGE SQL VOLATILE PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7 IS + 'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness'; +""")) + + op.execute(sa.text(r""" +CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid +AS +$$ + /* uuid fields: version=0b0111, variant=0b10 */ +SELECT encode( + overlay('\x00000000000070008000000000000000'::bytea + placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3) + from 1 for 6), + 'hex')::uuid; +$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS + 'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.'; +""" +)) + + +def downgrade(): + op.execute(sa.text("DROP FUNCTION uuidv7")) + op.execute(sa.text("DROP FUNCTION uuidv7_boundary")) diff --git a/api/models/model.py b/api/models/model.py index b30871f6ef..96b4b6fcac 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -612,14 +612,6 @@ class InstalledApp(Base): return tenant -class ConversationSource(StrEnum): - """This enumeration is designed for use with `Conversation.from_source`.""" - - # NOTE(QuantumGhost): The enumeration members may not cover all possible cases. - API = "api" - CONSOLE = "console" - - class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( @@ -935,7 +927,7 @@ class Message(Base): created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - workflow_run_id = db.Column(StringUUID) + workflow_run_id: Mapped[str] = db.Column(StringUUID) @property def inputs(self): diff --git a/api/models/workflow.py b/api/models/workflow.py index bc596cf4e2..e36b0f4ecf 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,6 +12,7 @@ from sqlalchemy import orm from core.file.constants import maybe_file_object from core.file.models import File from core.variables import utils as variable_utils +from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type @@ -359,7 +360,7 @@ class Workflow(Base): ) @property - def environment_variables(self) -> Sequence[Variable]: + def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: # TODO: find some way to init `self._environment_variables` when instance created. 
if self._environment_variables is None: self._environment_variables = "{}" @@ -379,11 +380,15 @@ class Workflow(Base): def decrypt_func(var): if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - else: + elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): return var + else: + raise AssertionError("this statement should be unreachable.") - results = list(map(decrypt_func, results)) - return results + decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( + map(decrypt_func, results) + ) + return decrypted_results @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..00a2d1f87d --- /dev/null +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -0,0 +1,197 @@ +""" +Service-layer repository protocol for WorkflowNodeExecutionModel operations. + +This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel +that abstracts database queries currently done directly in service classes. This repository is +specifically designed for service-layer needs and is separate from the core domain repository. + +The service repository handles operations that require access to database-specific fields like +tenant_id, app_id, triggered_from, etc., which are not part of the core domain model. +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models.workflow import WorkflowNodeExecutionModel + + +class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): + """ + Protocol for service-layer operations on WorkflowNodeExecutionModel. + + This repository provides database access patterns specifically needed by service classes, + handling queries that involve database-specific fields and multi-tenancy concerns. + + Key responsibilities: + - Manages database operations for workflow node executions + - Handles multi-tenant data isolation + - Provides batch processing capabilities + - Supports execution lifecycle management + + Implementation notes: + - Returns database models directly (WorkflowNodeExecutionModel) + - Handles tenant/app filtering automatically + - Provides service-specific query patterns + - Focuses on database operations without domain logic + - Supports cleanup and maintenance operations + """ + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method finds the latest execution of a specific node within a workflow, + ordered by creation time. Used primarily for debugging and inspection purposes. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + ... 
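    # Illustrative usage sketch (an assumption, not part of this change): a service class would typically
    # obtain an implementation of this protocol from the repository factory introduced later in this diff
    # (`db.engine` here stands for the application's SQLAlchemy engine) and call it roughly as:
    #
    #     session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
    #     repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
    #     last_run = repo.get_node_last_execution(
    #         tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id, node_id=node_id
    #     )
    #     if last_run is not None:
    #         ...  # e.g. inspect last_run.status or last_run.outputs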
+ + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method retrieves all node executions that belong to a specific workflow run, + ordered by index in descending order for proper trace visualization. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + ... + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method retrieves a specific execution by its unique identifier. + Tenant filtering is optional for cases where the execution ID is globally unique. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + ... + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + This method is used for cleanup operations to remove expired executions + in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + ... + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. + + This method is used when removing an app and all its related data. + Executions are deleted in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + ... + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + This method retrieves expired executions without deleting them, + allowing the caller to backup the data before deletion. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + ... + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. + + This method deletes specific executions by their IDs, + typically used after backing up the data. + + This method does not perform tenant isolation checks. The caller is responsible for ensuring proper + data isolation between tenants. 
When execution IDs come from untrusted sources (e.g., API requests), + additional tenant validation should be implemented to prevent unauthorized access. + + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + ... diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py new file mode 100644 index 0000000000..59e7baeb79 --- /dev/null +++ b/api/repositories/api_workflow_run_repository.py @@ -0,0 +1,181 @@ +""" +API WorkflowRun Repository Protocol + +This module defines the protocol for service-layer WorkflowRun operations. +The repository provides an abstraction layer for WorkflowRun database operations +used by service classes, separating service-layer concerns from core domain logic. + +Key Features: +- Paginated workflow run queries with filtering +- Bulk deletion operations with OSS backup support +- Multi-tenant data isolation +- Expired record cleanup with data retention +- Service-layer specific query patterns + +Usage: + This protocol should be used by service classes that need to perform + WorkflowRun database operations. It provides a clean interface that + hides implementation details and supports dependency injection. + +Example: + ```python + from repositories.dify_api_repository_factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + # Get paginated workflow runs + runs = repo.get_paginated_workflow_runs( + tenant_id="tenant-123", + app_id="app-456", + triggered_from="debugging", + limit=20 + ) + ``` +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun + + +class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): + """ + Protocol for service-layer WorkflowRun repository operations. + + This protocol defines the interface for WorkflowRun database operations + that are specific to service-layer needs, including pagination, filtering, + and bulk operations with data backup support. + """ + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Retrieves workflow runs for a specific app and trigger source with + cursor-based pagination support. Used primarily for debugging and + workflow run listing in the UI. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source (e.g., "debugging", "app-run") + limit: Maximum number of records to return (default: 20) + last_id: Cursor for pagination - ID of the last record from previous page + + Returns: + InfiniteScrollPagination object containing: + - data: List of WorkflowRun objects + - limit: Applied limit + - has_more: Boolean indicating if more records exist + + Raises: + ValueError: If last_id is provided but the corresponding record doesn't exist + """ + ... + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID. 
+ + Retrieves a single workflow run with tenant and app isolation. + Used for workflow run detail views and execution tracking. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + run_id: Workflow run identifier + + Returns: + WorkflowRun object if found, None otherwise + """ + ... + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup. + + Retrieves workflow runs created before the specified date for + cleanup operations. Used by scheduled tasks to remove old data + while maintaining data retention policies. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + before_date: Only return runs created before this date + batch_size: Maximum number of records to return + + Returns: + Sequence of WorkflowRun objects to be processed for cleanup + """ + ... + + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs. + + Performs bulk deletion of workflow runs by ID. This method should + be used after backing up the data to OSS storage for retention. + + Args: + run_ids: Sequence of workflow run IDs to delete + + Returns: + Number of records actually deleted + + Note: + This method performs hard deletion. Ensure data is backed up + to OSS storage before calling this method for compliance with + data retention policies. + """ + ... + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app. + + Performs bulk deletion of all workflow runs associated with an app. + Used during app cleanup operations. Processes records in batches + to avoid memory issues and long-running transactions. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + batch_size: Number of records to process in each batch + + Returns: + Total number of records deleted across all batches + + Note: + This method performs hard deletion without backup. Use with caution + and ensure proper data retention policies are followed. + """ + ... diff --git a/api/repositories/factory.py b/api/repositories/factory.py new file mode 100644 index 0000000000..0a0adbf2c2 --- /dev/null +++ b/api/repositories/factory.py @@ -0,0 +1,103 @@ +""" +DifyAPI Repository Factory for creating repository instances. + +This factory is specifically designed for DifyAPI repositories that handle +service-layer operations with dependency injection patterns. +""" + +import logging + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): + """ + Factory for creating DifyAPI repository instances based on configuration. + + This factory handles the creation of repositories that are specifically designed + for service-layer operations and use dependency injection with sessionmaker + for better testability and separation of concerns. 
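    Example (illustrative sketch; `db.engine` is assumed to be the application's SQLAlchemy engine):

        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        node_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
        run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)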
+ """ + + @classmethod + def create_api_workflow_node_execution_repository( + cls, session_maker: sessionmaker + ) -> DifyAPIWorkflowNodeExecutionRepository: + """ + Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration. + + This repository is designed for service-layer operations and uses dependency injection + with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes, handling queries + that involve database-specific fields and multi-tenancy concerns. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. + + Returns: + Configured DifyAPIWorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e + + @classmethod + def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository: + """ + Create an APIWorkflowRunRepository instance based on configuration. + + This repository is designed for service-layer WorkflowRun operations and uses dependency + injection with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes for workflow run management, + including pagination, filtering, and bulk operations. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. 
+ + Returns: + Configured APIWorkflowRunRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY + logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create APIWorkflowRunRepository") + raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..e6a23ddf9f --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,290 @@ +""" +SQLAlchemy implementation of WorkflowNodeExecutionServiceRepository. + +This module provides a concrete implementation of the service repository protocol +using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional + +from sqlalchemy import delete, desc, select +from sqlalchemy.orm import Session, sessionmaker + +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + + +class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): + """ + SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + + This repository provides service-layer database operations for WorkflowNodeExecutionModel + using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository + protocol with the following features: + + - Multi-tenancy data isolation through tenant_id filtering + - Direct database model operations without domain conversion + - Batch processing for efficient large-scale operations + - Optimized query patterns for common access patterns + - Dependency injection for better testability and maintainability + - Session management and transaction handling with proper cleanup + - Maintenance operations for data lifecycle management + - Thread-safe database operations using session-per-request pattern + """ + + def __init__(self, session_maker: sessionmaker[Session]): + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for creating database sessions + """ + self._session_maker = session_maker + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method replicates the query pattern from WorkflowService.get_node_last_run() + using SQLAlchemy 2.0 style syntax. 
+ + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_id == workflow_id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.created_at)) + .limit(1) + ) + + with self._session_maker() as session: + return session.scalar(stmt) + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions() + using SQLAlchemy 2.0 style syntax. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.index)) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method replicates the query pattern from WorkflowDraftVariableService + and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id) + + # Add tenant filtering if provided + if tenant_id is not None: + stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id) + + with self._session_maker() as session: + return session.scalar(stmt) + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. 
+ + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. 
+ + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + if not execution_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(stmt) + session.commit() + return result.rowcount diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py new file mode 100644 index 0000000000..ebd1d74b20 --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -0,0 +1,203 @@ +""" +SQLAlchemy API WorkflowRun Repository Implementation + +This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository +protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0 +style queries with proper session management and multi-tenant data isolation. + +Key Features: +- SQLAlchemy 2.0 style queries for modern database operations +- Cursor-based pagination for efficient large dataset handling +- Bulk operations with batch processing for performance +- Multi-tenant data isolation and security +- Proper session management with dependency injection + +Implementation Notes: +- Uses sessionmaker for consistent session management +- Implements cursor-based pagination using created_at timestamps +- Provides efficient bulk deletion with batch processing +- Maintains data consistency with proper transaction handling +""" + +import logging +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, cast + +from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): + """ + SQLAlchemy implementation of APIWorkflowRunRepository. + + Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0 + style queries. Supports dependency injection through sessionmaker and + maintains proper multi-tenant data isolation. + + Args: + session_maker: SQLAlchemy sessionmaker instance for database connections + """ + + def __init__(self, session_maker: sessionmaker[Session]) -> None: + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for database connections + """ + self._session_maker = session_maker + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Implements cursor-based pagination using created_at timestamps for + efficient handling of large datasets. Filters by tenant, app, and + trigger source for proper data isolation. 
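        Example (illustrative sketch of the cursor pattern):

            first_page = repo.get_paginated_workflow_runs(
                tenant_id=tenant_id, app_id=app_id, triggered_from="debugging", limit=20
            )
            if first_page.has_more:
                next_page = repo.get_paginated_workflow_runs(
                    tenant_id=tenant_id, app_id=app_id, triggered_from="debugging",
                    limit=20, last_id=first_page.data[-1].id,
                )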
+ """ + with self._session_maker() as session: + # Build base query with filters + base_stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.triggered_from == triggered_from, + ) + + if last_id: + # Get the last workflow run for cursor-based pagination + last_run_stmt = base_stmt.where(WorkflowRun.id == last_id) + last_workflow_run = session.scalar(last_run_stmt) + + if not last_workflow_run: + raise ValueError("Last workflow run not exists") + + # Get records created before the last run's timestamp + base_stmt = base_stmt.where( + WorkflowRun.created_at < last_workflow_run.created_at, + WorkflowRun.id != last_workflow_run.id, + ) + + # First page - get most recent records + workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all() + + # Check if there are more records for pagination + has_more = len(workflow_runs) > limit + if has_more: + workflow_runs = workflow_runs[:-1] + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID with tenant and app isolation. + """ + with self._session_maker() as session: + stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.id == run_id, + ) + return cast(Optional[WorkflowRun], session.scalar(stmt)) + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup operations. + """ + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.created_at < before_date, + ) + .limit(batch_size) + ) + return cast(Sequence[WorkflowRun], session.scalars(stmt).all()) + + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs using bulk deletion. + """ + if not run_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(stmt) + session.commit() + + deleted_count = cast(int, result.rowcount) + logger.info(f"Deleted {deleted_count} workflow runs by IDs") + return deleted_count + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app in batches. 
+ """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Get a batch of run IDs to delete + stmt = ( + select(WorkflowRun.id) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + ) + .limit(batch_size) + ) + run_ids = session.scalars(stmt).all() + + if not run_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(delete_stmt) + session.commit() + + batch_deleted = result.rowcount + total_deleted += batch_deleted + + logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}") + + # If we deleted fewer records than the batch size, we're done + if batch_deleted < batch_size: + break + + logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}") + return total_deleted diff --git a/api/services/account_service.py b/api/services/account_service.py index 2ba6f4345b..4d5366f47f 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -52,8 +52,14 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces from services.feature_service import FeatureService from tasks.delete_account_task import delete_account_task from tasks.mail_account_deletion_task import send_account_deletion_verification_code +from tasks.mail_change_mail_task import send_change_mail_task from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task +from tasks.mail_owner_transfer_task import ( + send_new_owner_transfer_notify_email_task, + send_old_owner_transfer_notify_email_task, + send_owner_transfer_confirm_task, +) from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -75,8 +81,13 @@ class AccountService: email_code_account_deletion_rate_limiter = RateLimiter( prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 ) + change_email_rate_limiter = RateLimiter(prefix="change_email_rate_limit", max_attempts=1, time_window=60 * 1) + owner_transfer_rate_limiter = RateLimiter(prefix="owner_transfer_rate_limit", max_attempts=1, time_window=60 * 1) + LOGIN_MAX_ERROR_LIMITS = 5 FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 + CHANGE_EMAIL_MAX_ERROR_LIMITS = 5 + OWNER_TRANSFER_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -419,6 +430,101 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def send_change_email_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + old_email: Optional[str] = None, + language: Optional[str] = "en-US", + phase: Optional[str] = None, + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.change_email_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import EmailChangeRateLimitExceededError + + raise EmailChangeRateLimitExceededError() + + code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + + send_change_mail_task.delay( + language=language, + to=account_email, + code=code, + phase=phase, + ) + cls.change_email_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def send_owner_transfer_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: 
Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.owner_transfer_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import OwnerTransferRateLimitExceededError + + raise OwnerTransferRateLimitExceededError() + + code, token = cls.generate_owner_transfer_token(account_email, account) + + send_owner_transfer_confirm_task.delay( + language=language, + to=account_email, + code=code, + workspace=workspace_name, + ) + cls.owner_transfer_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def send_old_owner_transfer_notify_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + new_owner_email: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + send_old_owner_transfer_notify_email_task.delay( + language=language, + to=account_email, + workspace=workspace_name, + new_owner_email=new_owner_email, + ) + + @classmethod + def send_new_owner_transfer_notify_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + send_new_owner_transfer_notify_email_task.delay( + language=language, + to=account_email, + workspace=workspace_name, + ) + @classmethod def generate_reset_password_token( cls, @@ -435,14 +541,64 @@ class AccountService: ) return code, token + @classmethod + def generate_change_email_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + old_email: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + additional_data["old_email"] = old_email + token = TokenManager.generate_token( + account=account, email=email, token_type="change_email", additional_data=additional_data + ) + return code, token + + @classmethod + def generate_owner_transfer_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="owner_transfer", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") + @classmethod + def revoke_change_email_token(cls, token: str): + TokenManager.revoke_token(token, "change_email") + + @classmethod + def revoke_owner_transfer_token(cls, token: str): + TokenManager.revoke_token(token, "owner_transfer") + @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") + @classmethod + def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "change_email") + + @classmethod + def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]: + 
return TokenManager.get_token_data(token, "owner_transfer") + @classmethod def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" @@ -552,6 +708,62 @@ class AccountService: key = f"forgot_password_error_rate_limit:{email}" redis_client.delete(key) + @staticmethod + @redis_fallback(default_return=None) + def add_change_email_error_rate_limit(email: str) -> None: + key = f"change_email_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.CHANGE_EMAIL_LOCKOUT_DURATION, count) + + @staticmethod + @redis_fallback(default_return=False) + def is_change_email_error_rate_limit(email: str) -> bool: + key = f"change_email_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.CHANGE_EMAIL_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_change_email_error_rate_limit(email: str): + key = f"change_email_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + @redis_fallback(default_return=None) + def add_owner_transfer_error_rate_limit(email: str) -> None: + key = f"owner_transfer_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.OWNER_TRANSFER_LOCKOUT_DURATION, count) + + @staticmethod + @redis_fallback(default_return=False) + def is_owner_transfer_error_rate_limit(email: str) -> bool: + key = f"owner_transfer_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.OWNER_TRANSFER_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_owner_transfer_error_rate_limit(email: str): + key = f"owner_transfer_error_rate_limit:{email}" + redis_client.delete(key) + @staticmethod @redis_fallback(default_return=False) def is_email_send_ip_limit(ip_address: str): @@ -593,6 +805,10 @@ class AccountService: return False + @staticmethod + def check_email_unique(email: str) -> bool: + return db.session.query(Account).filter_by(email=email).first() is None + class TenantService: @staticmethod @@ -865,6 +1081,15 @@ class TenantService: return cast(dict, tenant.custom_config_dict) + @staticmethod + def is_owner(account: Account, tenant: Tenant) -> bool: + return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER + + @staticmethod + def is_member(account: Account, tenant: Tenant) -> bool: + """Check if the account is a member of the tenant""" + return TenantService.get_user_role(account, tenant) is not None + class RegisterService: @classmethod diff --git a/api/services/app_service.py b/api/services/app_service.py index d08462d001..0a08f345df 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -47,8 +47,6 @@ class AppService: filters.append(App.mode == AppMode.ADVANCED_CHAT.value) elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args["mode"] == "channel": - filters.append(App.mode == AppMode.CHANNEL.value) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) @@ -235,6 +233,7 @@ class AppService: app.icon = args.get("icon") app.icon_background = args.get("icon_background") app.use_icon_as_answer_icon = 
args.get("use_icon_as_answer_icon", False) + app.max_active_requests = args.get("max_active_requests") app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1fd560d581..ddd16b2e0c 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder @@ -14,7 +14,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message -from models.workflow import WorkflowNodeExecutionModel, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -105,84 +105,99 @@ class ClearFreePlanTenantExpiredLogs: ) ) - while True: - with Session(db.engine).no_autoflush as session: - workflow_node_executions = ( - session.query(WorkflowNodeExecutionModel) - .filter( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.created_at - < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() - ) - - if len(workflow_node_executions) == 0: - break - - # save workflow node executions - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder(workflow_node_executions), - ).encode("utf-8"), - ) - - workflow_node_execution_ids = [ - workflow_node_execution.id for workflow_node_execution in workflow_node_executions - ] - - # delete workflow node executions - session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), - ).delete(synchronize_session=False) - session.commit() - - click.echo( - click.style( - f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" - f" workflow node executions for tenant {tenant_id}" - ) - ) + # Process expired workflow node executions with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 while True: - with Session(db.engine).no_autoflush as session: - workflow_runs = ( - session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == tenant_id, - WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() + # Get a batch of expired executions for backup + workflow_node_executions = node_execution_repo.get_expired_executions_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_node_executions) == 0: + break + + # Save workflow node executions to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + 
jsonable_encoder(workflow_node_executions), + ).encode("utf-8"), + ) + + # Extract IDs for deletion + workflow_node_execution_ids = [ + workflow_node_execution.id for workflow_node_execution in workflow_node_executions + ] + + # Delete the backed up executions + deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" + f" workflow node executions for tenant {tenant_id}" ) + ) - if len(workflow_runs) == 0: - break + # If we got fewer than the batch size, we're done + if len(workflow_node_executions) < batch: + break - # save workflow runs + # Process expired workflow runs with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder( - [workflow_run.to_dict() for workflow_run in workflow_runs], - ), - ).encode("utf-8"), + while True: + # Get a batch of expired workflow runs for backup + workflow_runs = workflow_run_repo.get_expired_runs_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_runs) == 0: + break + + # Save workflow runs to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_run.to_dict() for workflow_run in workflow_runs], + ), + ).encode("utf-8"), + ) + + # Extract IDs for deletion + workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] + + # Delete the backed up workflow runs + deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}" + f" workflow runs for tenant {tenant_id}" ) + ) - workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] - - # delete workflow runs - session.query(WorkflowRun).filter( - WorkflowRun.id.in_(workflow_run_ids), - ).delete(synchronize_session=False) - session.commit() + # If we got fewer than the batch size, we're done + if len(workflow_runs) < batch: + break @classmethod def process(cls, days: int, batch: int, tenant_ids: list[str]): diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 88d4224e97..344c67885e 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -4,13 +4,6 @@ from typing import Literal, Optional from pydantic import BaseModel -class SegmentUpdateEntity(BaseModel): - content: str - answer: Optional[str] = None - keywords: Optional[list[str]] = None - enabled: Optional[bool] = None - - class ParentMode(StrEnum): FULL_DOC = "full-doc" PARAGRAPH = "paragraph" @@ -153,10 +146,6 @@ class MetadataUpdateArgs(BaseModel): value: Optional[str | int | float] = None -class MetadataValueUpdateArgs(BaseModel): - fields: list[MetadataUpdateArgs] - - class MetadataDetail(BaseModel): id: str name: str diff --git 
a/api/services/feature_service.py b/api/services/feature_service.py index 188caf3505..1441e6ce16 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -123,7 +123,7 @@ class FeatureModel(BaseModel): dataset_operator_enabled: bool = False webapp_copyright_enabled: bool = False workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) - + is_allow_transfer_workspace: bool = True # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -149,6 +149,7 @@ class SystemFeatureModel(BaseModel): branding: BrandingModel = BrandingModel() webapp_auth: WebAppAuthModel = WebAppAuthModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() + enable_change_email: bool = True class FeatureService: @@ -186,6 +187,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: system_features.branding.enabled = True system_features.webapp_auth.enabled = True + system_features.enable_change_email = False cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: @@ -228,6 +230,8 @@ class FeatureService: if features.billing.subscription.plan != "sandbox": features.webapp_copyright_enabled = True + else: + features.is_allow_transfer_workspace = False if "members" in billing_info: features.members.size = billing_info["members"]["size"] diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 44fd72b5e4..f306e1f062 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any, ClassVar -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom @@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation from models.enums import DraftVariableType -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable +from repositories.factory import DifyAPIRepositoryFactory _logger = logging.getLogger(__name__) @@ -117,7 +118,24 @@ class WorkflowDraftVariableService: _session: Session def __init__(self, session: Session) -> None: + """ + Initialize the WorkflowDraftVariableService with a SQLAlchemy session. + + Args: + session (Session): The SQLAlchemy session used to execute database queries. + The provided session must be bound to an `Engine` object, not a specific `Connection`. + + Raises: + AssertionError: If the provided session is not bound to an `Engine` object. + """ self._session = session + engine = session.get_bind() + # Ensure the session is bound to a engine. 
+ assert isinstance(engine, Engine) + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() @@ -248,8 +266,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) - node_exec = self._session.scalars(query).first() + node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", @@ -298,6 +315,8 @@ class WorkflowDraftVariableService: def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name): + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") if variable_type == DraftVariableType.CONVERSATION: return self._reset_conv_var(workflow, variable) else: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 483c0d3086..e43999a8c9 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,9 +2,9 @@ import threading from collections.abc import Sequence from typing import Optional +from sqlalchemy.orm import sessionmaker + import contexts -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ( @@ -15,10 +15,18 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) -from models.workflow import WorkflowNodeExecutionTriggeredFrom +from repositories.factory import DifyAPIRepositoryFactory class WorkflowRunService: + def __init__(self): + """Initialize WorkflowRunService with repository dependencies.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -62,45 +70,16 @@ class WorkflowRunService: :param args: request args """ limit = int(args.get("limit", 20)) + last_id = args.get("last_id") - base_query = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + return self._workflow_run_repo.get_paginated_workflow_runs( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + limit=limit, + last_id=last_id, ) - if args.get("last_id"): - last_workflow_run = base_query.filter( - WorkflowRun.id == 
args.get("last_id"), - ).first() - - if not last_workflow_run: - raise ValueError("Last workflow run not exists") - - workflow_runs = ( - base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id - ) - .order_by(WorkflowRun.created_at.desc()) - .limit(limit) - .all() - ) - else: - workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() - - has_more = False - if len(workflow_runs) == limit: - current_page_first_workflow_run = workflow_runs[-1] - rest_count = base_query.filter( - WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id, - ).count() - - if rest_count > 0: - has_more = True - - return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail @@ -108,18 +87,12 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = ( - db.session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ) - .first() + return self._workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=run_id, ) - return workflow_run - def get_workflow_run_node_executions( self, app_model: App, @@ -137,17 +110,13 @@ class WorkflowRunService: if not workflow_run: return [] - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, - user=user, + # Get tenant_id from user + tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if tenant_id is None: + raise ValueError("User tenant_id cannot be None") + + return self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=tenant_id, app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=run_id, ) - - # Use the repository to get the database models directly - order_config = OrderConfig(order_by=["index"], order_direction="desc") - workflow_node_executions = repository.get_db_models_by_workflow_run( - workflow_run_id=run_id, order_config=order_config - ) - - return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 712a38d6f6..c16aba7f85 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,22 +3,22 @@ import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable +from core.variables.variables import VariableUnion from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import 
WorkflowNodeExecution, WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -28,6 +28,7 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -41,6 +42,7 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) +from repositories.factory import DifyAPIRepositoryFactory from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter @@ -57,21 +59,32 @@ class WorkflowService: Workflow Service """ + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize WorkflowService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: - # TODO(QuantumGhost): This query is not fully covered by index. - criteria = ( - WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_id == workflow.id, - WorkflowNodeExecutionModel.node_id == node_id, + """ + Get the most recent execution for a specific node. 
+ + Args: + app_model: The application model + workflow: The workflow model + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + return self._node_execution_service_repo.get_node_last_execution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=workflow.id, + node_id=node_id, ) - node_exec = ( - db.session.query(WorkflowNodeExecutionModel) - .filter(*criteria) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .first() - ) - return node_exec def is_workflow_exist(self, app_model: App) -> bool: return ( @@ -358,7 +371,7 @@ class WorkflowService: else: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -397,7 +410,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = SQLAlchemyWorkflowNodeExecutionRepository( + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -405,8 +418,9 @@ class WorkflowService: ) repository.save(node_execution) - # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution = repository.to_db_model(node_execution) + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + if workflow_node_execution is None: + raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( @@ -419,6 +433,7 @@ class WorkflowService: ) draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) session.commit() + return workflow_node_execution def run_free_workflow_node( @@ -430,7 +445,7 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( + node_execution = self._handle_node_run_result( invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, @@ -442,7 +457,7 @@ class WorkflowService: node_id=node_id, ) - return workflow_node_execution + return node_execution def _handle_node_run_result( self, @@ -672,36 +687,30 @@ def _setup_variable_pool( ): # Only inject system variables for START node type. if node_type == NodeType.START: - # Create a variable pool. - system_inputs: dict[SystemVariableKey, Any] = { - # From inputs: - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - # From workflow model - SystemVariableKey.APP_ID: workflow.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - # Randomly generated. 
- SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), - } + system_variable = SystemVariable( + user_id=user_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + files=files or [], + workflow_execution_id=str(uuid.uuid4()), + ) # Only add chatflow-specific variables for non-workflow types if workflow.type != WorkflowType.WORKFLOW.value: - system_inputs.update( - { - SystemVariableKey.QUERY: query, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.DIALOGUE_COUNT: 0, - } - ) + system_variable.query = query + system_variable.conversation_id = conversation_id + system_variable.dialogue_count = 0 else: - system_inputs = {} + system_variable = SystemVariable.empty() # init variable pool variable_pool = VariablePool( - system_variables=system_inputs, + system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), # ) return variable_pool diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py new file mode 100644 index 0000000000..da44040b7d --- /dev/null +++ b/api/tasks/mail_change_mail_task.py @@ -0,0 +1,78 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail +from services.feature_service import FeatureService + + +@shared_task(queue="mail") +def send_change_mail_task(language: str, to: str, code: str, phase: str): + """ + Async Send change email mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Change email code + :param phase: Change email phase (new_email, old_email) + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + email_config = { + "zh-Hans": { + "old_email": { + "subject": "检测您现在的邮箱", + "template_with_brand": "change_mail_confirm_old_template_zh-CN.html", + "template_without_brand": "without-brand/change_mail_confirm_old_template_zh-CN.html", + }, + "new_email": { + "subject": "确认您的邮箱地址变更", + "template_with_brand": "change_mail_confirm_new_template_zh-CN.html", + "template_without_brand": "without-brand/change_mail_confirm_new_template_zh-CN.html", + }, + }, + "en": { + "old_email": { + "subject": "Check your current email", + "template_with_brand": "change_mail_confirm_old_template_en-US.html", + "template_without_brand": "without-brand/change_mail_confirm_old_template_en-US.html", + }, + "new_email": { + "subject": "Confirm your new email address", + "template_with_brand": "change_mail_confirm_new_template_en-US.html", + "template_without_brand": "without-brand/change_mail_confirm_new_template_en-US.html", + }, + }, + } + + # send change email mail using different languages + try: + system_features = FeatureService.get_system_features() + lang_key = "zh-Hans" if language == "zh-Hans" else "en" + + if phase not in ["old_email", "new_email"]: + raise ValueError("Invalid phase") + + config = email_config[lang_key][phase] + subject = config["subject"] + + if system_features.branding.enabled: + template = config["template_without_brand"] + else: + template = config["template_with_brand"] + + 
html_content = render_template(template, to=to, code=code) + mail.send(to=to, subject=subject, html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style("Send change email mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send change email mail to {} failed".format(to)) diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py new file mode 100644 index 0000000000..8d05c6dc0f --- /dev/null +++ b/api/tasks/mail_owner_transfer_task.py @@ -0,0 +1,152 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail +from services.feature_service import FeatureService + + +@shared_task(queue="mail") +def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str): + """ + Async Send owner transfer confirm mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param workspace: Workspace name + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + # send change email mail using different languages + try: + if language == "zh-Hans": + template = "transfer_workspace_owner_confirm_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_owner_confirm_template_zh-CN.html" + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="验证您转移工作空间所有权的请求", html=html_content) + else: + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="验证您转移工作空间所有权的请求", html=html_content) + else: + template = "transfer_workspace_owner_confirm_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_owner_confirm_template_en-US.html" + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="Verify Your Request to Transfer Workspace Ownership", html=html_content) + else: + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="Verify Your Request to Transfer Workspace Ownership", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str): + """ + Async Send owner transfer confirm mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param workspace: Workspace name + :param new_owner_email: New owner email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + # send change email mail using different languages + try: + if language == "zh-Hans": + template = 
"transfer_workspace_old_owner_notify_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html" + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="工作区所有权已转移", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="工作区所有权已转移", html=html_content) + else: + template = "transfer_workspace_old_owner_notify_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_old_owner_notify_template_en-US.html" + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="Workspace ownership has been transferred", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="Workspace ownership has been transferred", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str): + """ + Async Send owner transfer confirm mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Change email code + :param workspace: Workspace name + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + # send change email mail using different languages + try: + if language == "zh-Hans": + template = "transfer_workspace_new_owner_notify_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html" + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"您现在是 {workspace} 的所有者", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"您现在是 {workspace} 的所有者", html=html_content) + else: + template = "transfer_workspace_new_owner_notify_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_new_owner_notify_template_en-US.html" + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"You are now the owner of {workspace}", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"You are now the owner of {workspace}", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) 
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4a62cb74b4..179adcbd6e 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,6 +6,7 @@ import click from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models import ( @@ -31,7 +32,8 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog +from repositories.factory import DifyAPIRepositoryFactory @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -189,30 +191,32 @@ def _delete_app_workflows(tenant_id: str, app_id: str): def _delete_app_workflow_runs(tenant_id: str, app_id: str): - def del_workflow_run(workflow_run_id: str): - db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False) + """Delete all workflow runs for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - _delete_records( - """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_run, - "workflow run", + deleted_count = workflow_run_repo.delete_runs_by_app( + tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}") + def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id == workflow_node_execution_id - ).delete(synchronize_session=False) + """Delete all workflow node executions for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", + deleted_count = node_execution_repo.delete_executions_by_app( + tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}") + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): diff --git a/api/templates/change_mail_confirm_new_template_en-US.html b/api/templates/change_mail_confirm_new_template_en-US.html new file mode 100644 index 0000000000..88721e787c --- /dev/null +++ b/api/templates/change_mail_confirm_new_template_en-US.html @@ -0,0 +1,125 @@ + + + + + + + + +
+ [125-line HTML email template; visible text content:]
+ Dify Logo
+ Confirm Your New Email Address
+ You’re updating the email address linked to your Dify account.
+ To confirm this action, please use the verification code below.
+ This code will only be valid for the next 5 minutes:
+ {{code}}
+ If you didn’t make this request, please ignore this email or contact support immediately.
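A small usage sketch (assumed, not part of this patch) of how the new-email phase of send_change_mail_task might be queued so that the template above gets rendered:

    # Hypothetical caller; only the task signature comes from the patch.
    from tasks.mail_change_mail_task import send_change_mail_task

    def request_new_email_verification(language: str, new_email: str, code: str) -> None:
        # phase="new_email" selects change_mail_confirm_new_template_*.html;
        # phase="old_email" would select the confirm_old templates instead.
        send_change_mail_task.delay(language=language, to=new_email, code=code, phase="new_email")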
diff --git a/api/templates/change_mail_confirm_new_template_zh-CN.html b/api/templates/change_mail_confirm_new_template_zh-CN.html new file mode 100644 index 0000000000..25336ea1a1 --- /dev/null +++ b/api/templates/change_mail_confirm_new_template_zh-CN.html @@ -0,0 +1,125 @@
+ [125-line HTML email template; visible text content:]
+ Dify Logo
+ 确认您的邮箱地址变更
+ 您正在更新与您的 Dify 账户关联的邮箱地址。
+ 为了确认此操作,请使用以下验证码。
+ 此验证码仅在接下来的5分钟内有效:
+ {{code}}
+ 如果您没有请求变更邮箱地址,请忽略此邮件或立即联系支持。
diff --git a/api/templates/change_mail_confirm_old_template_en-US.html b/api/templates/change_mail_confirm_old_template_en-US.html new file mode 100644 index 0000000000..b20306aa87 --- /dev/null +++ b/api/templates/change_mail_confirm_old_template_en-US.html @@ -0,0 +1,125 @@
+ [125-line HTML email template; visible text content:]
+ Dify Logo
+ Verify Your Request to Change Email
+ We received a request to change the email address associated with your Dify account.
+ To confirm this action, please use the verification code below.
+ This code will only be valid for the next 5 minutes:
+ {{code}}
+ If you didn’t make this request, please ignore this email or contact support immediately.
diff --git a/api/templates/change_mail_confirm_old_template_zh-CN.html b/api/templates/change_mail_confirm_old_template_zh-CN.html new file mode 100644 index 0000000000..23c9e46652 --- /dev/null +++ b/api/templates/change_mail_confirm_old_template_zh-CN.html @@ -0,0 +1,124 @@
+ [124-line HTML email template; visible text content:]
+ Dify Logo
+ 验证您的邮箱变更请求
+ 我们收到了一个变更您 Dify 账户关联邮箱地址的请求。
+ 此验证码仅在接下来的5分钟内有效:
+ {{code}}
+ 如果您没有请求变更邮箱地址,请忽略此邮件或立即联系支持。
diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html index 2d8f78b46a..b26e494f80 100644 --- a/api/templates/clean_document_job_mail_template-US.html +++ b/api/templates/clean_document_job_mail_template-US.html @@ -6,94 +6,136 @@ Documents Disabled Notification -